Posted to commits@spark.apache.org by me...@apache.org on 2014/11/01 02:58:06 UTC

[2/2] git commit: [MLLIB] SPARK-1547: Add Gradient Boosting to MLlib

[MLLIB] SPARK-1547: Add Gradient Boosting to MLlib

Given the popular demand for gradient boosting and AdaBoost in MLlib, I am creating a WIP branch for early feedback on gradient boosting, with AdaBoost to follow soon after this PR is accepted. This is based on work done along with hirakendu that was pending due to the decision tree optimizations and random forests work.

Ideally, boosting algorithms should work with any base learners.  This will soon be possible once the MLlib API is finalized -- we want to ensure we use a consistent interface for the underlying base learners. In the meantime, this PR uses decision trees as base learners for the gradient boosting algorithm. The current PR allows "pluggable" loss functions and provides least squares error and least absolute error by default.
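
For readers who want to try the pluggable interface, here is a hedged sketch of a user-defined loss (a Huber-style loss, which is not part of this PR) written against the Loss trait and WeightedEnsembleModel introduced in this patch; the object name and the delta value are illustrative only:

  import org.apache.spark.SparkContext._
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.mllib.tree.loss.Loss
  import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
  import org.apache.spark.rdd.RDD

  // Hypothetical Huber-style loss with a fixed delta, for illustration only.
  object HuberLoss extends Loss {
    private val delta = 1.0

    // Gradient of the Huber loss with respect to the prediction F:
    // (F - y) in the quadratic region, delta * sign(F - y) in the linear region.
    override def gradient(model: WeightedEnsembleModel, point: LabeledPoint): Double = {
      val diff = model.predict(point.features) - point.label
      if (math.abs(diff) <= delta) diff else delta * math.signum(diff)
    }

    // Mean Huber loss over the dataset (used for debugging, like the built-in losses).
    override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
      data.map { lp =>
        val diff = model.predict(lp.features) - lp.label
        if (math.abs(diff) <= delta) 0.5 * diff * diff
        else delta * (math.abs(diff) - 0.5 * delta)
      }.mean()
    }
  }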

Here is the task list:
- [x] Gradient boosting support
- [x] Pluggable loss functions
- [x] Stochastic gradient boosting support -- Re-use the BaggedPoint approach used for RandomForest.
- [x] Binary classification support
- [x] Support configurable checkpointing -- This approach will avoid long lineage chains.
- [x] Create classification and regression APIs
- [x] Weighted Ensemble Model -- created a WeightedEnsembleModel class that can be used by ensemble algorithms such as random forests and boosting.
- [x] Unit Tests

Future work:
+ Multi-class classification is currently not supported by this PR since it requires discussion on the best way to support "deviance" as a loss function.
+ BaggedRDD caching -- Avoid repeating feature to bin mapping for each tree estimator after standard API work is completed.
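
For context, a minimal usage sketch of the classification and regression APIs from the checklist above, assuming an existing `trainingData: RDD[LabeledPoint]` and a `testPoint: LabeledPoint` (both placeholders):

  import org.apache.spark.mllib.tree.GradientBoosting
  import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

  // Start from the defaults added in this PR (100 iterations, depth-3 regression
  // trees as weak learners, learning rate 0.1) and tweak a couple of knobs.
  val boostingStrategy = BoostingStrategy.defaultParams(Algo.Regression)
  boostingStrategy.numEstimators = 50
  boostingStrategy.learningRate = 0.2

  val model = GradientBoosting.trainRegressor(trainingData, boostingStrategy)
  val prediction = model.predict(testPoint.features)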

cc: jkbradley hirakendu mengxr etrain atalwalkar chouqin

Author: Manish Amde <ma...@gmail.com>
Author: manishamde <ma...@gmail.com>

Closes #2607 from manishamde/gbt and squashes the following commits:

991c7b5 [Manish Amde] public api
ff2a796 [Manish Amde] addressing comments
b4c1318 [Manish Amde] removing spaces
8476b6b [Manish Amde] fixing line length
0183cb9 [Manish Amde] fixed naming and formatting issues
1c40c33 [Manish Amde] add newline, removed spaces
e33ab61 [Manish Amde] minor comment
eadbf09 [Manish Amde] parameter renaming
035a2ed [Manish Amde] jkbradley formatting suggestions
9f7359d [Manish Amde] simplified gbt logic and added more tests
49ba107 [Manish Amde] merged from master
eff21fe [Manish Amde] Added gradient boosting tests
3fd0528 [Manish Amde] moved helper methods to new class
a32a5ab [Manish Amde] added test for subsampling without replacement
781542a [Manish Amde] added support for fractional subsampling with replacement
3a18cc1 [Manish Amde] cleaned up api for conversion to bagged point and moved tests to its own test suite
0e81906 [Manish Amde] improving caching unpersisting logic
d971f73 [Manish Amde] moved RF code to use WeightedEnsembleModel class
fee06d3 [Manish Amde] added weighted ensemble model
1b01943 [Manish Amde] add weights for base learners
9bc6e74 [Manish Amde] adding random seed as parameter
d2c8323 [Manish Amde] Merge branch 'master' into gbt
2ae97b7 [Manish Amde] added documentation for the loss classes
9366b8f [Manish Amde] minor: using numTrees instead of trees.size
3b43896 [Manish Amde] added learning rate for prediction
9b2e35e [Manish Amde] Merge branch 'master' into gbt
6a11c02 [manishamde] fixing formatting
823691b [Manish Amde] fixing RF test
1f47941 [Manish Amde] changing access modifier
5b67102 [Manish Amde] shortened parameter list
5ab3796 [Manish Amde] minor reformatting
9155a9d [Manish Amde] consolidated boosting configuration and added public API
631baea [Manish Amde] Merge branch 'master' into gbt
2cb1258 [Manish Amde] public API support
3b8ffc0 [Manish Amde] added documentation
8e10c63 [Manish Amde] modified unpersist strategy
f62bc48 [Manish Amde] added unpersist
bdca43a [Manish Amde] added timing parameters
2fbc9c7 [Manish Amde] fixing binomial classification prediction
6dd4dd8 [Manish Amde] added support for log loss
9af0231 [Manish Amde] classification attempt
62cc000 [Manish Amde] basic checkpointing
4784091 [Manish Amde] formatting
78ed452 [Manish Amde] added newline and fixed if statement
3973dd1 [Manish Amde] minor indicating subsample is double during comparison
aa8fae7 [Manish Amde] minor refactoring
1a8031c [Manish Amde] sampling with replacement
f1c9ef7 [Manish Amde] Merge branch 'master' into gbt
cdceeef [Manish Amde] added documentation
6251fd5 [Manish Amde] modified method name
5538521 [Manish Amde] disable checkpointing for now
0ae1c0a [Manish Amde] basic gradient boosting code from earlier branches


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/86021955
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/86021955
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/86021955

Branch: refs/heads/master
Commit: 8602195510f5821b37746bb7fa24902f43a1bd93
Parents: e07fb6a
Author: Manish Amde <ma...@gmail.com>
Authored: Fri Oct 31 18:57:55 2014 -0700
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Fri Oct 31 18:57:55 2014 -0700

----------------------------------------------------------------------
 .../examples/mllib/DecisionTreeRunner.scala     |   4 +-
 .../apache/spark/mllib/tree/DecisionTree.scala  |   2 +-
 .../spark/mllib/tree/GradientBoosting.scala     | 314 +++++++++++++++++++
 .../apache/spark/mllib/tree/RandomForest.scala  |  49 +--
 .../tree/configuration/BoostingStrategy.scala   | 109 +++++++
 .../EnsembleCombiningStrategy.scala             |  30 ++
 .../mllib/tree/configuration/Strategy.scala     |  23 +-
 .../spark/mllib/tree/impl/BaggedPoint.scala     |  69 +++-
 .../spark/mllib/tree/loss/AbsoluteError.scala   |  66 ++++
 .../apache/spark/mllib/tree/loss/LogLoss.scala  |  63 ++++
 .../org/apache/spark/mllib/tree/loss/Loss.scala |  52 +++
 .../apache/spark/mllib/tree/loss/Losses.scala   |  29 ++
 .../spark/mllib/tree/loss/SquaredError.scala    |  66 ++++
 .../mllib/tree/model/RandomForestModel.scala    | 115 -------
 .../tree/model/WeightedEnsembleModel.scala      | 158 ++++++++++
 .../spark/mllib/tree/DecisionTreeSuite.scala    |   6 +-
 .../spark/mllib/tree/EnsembleTestHelper.scala   |  94 ++++++
 .../mllib/tree/GradientBoostingSuite.scala      | 132 ++++++++
 .../spark/mllib/tree/RandomForestSuite.scala    | 117 +------
 .../mllib/tree/impl/BaggedPointSuite.scala      | 100 ++++++
 20 files changed, 1331 insertions(+), 267 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 0890e62..f987303 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
 import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
@@ -317,7 +317,7 @@ object DecisionTreeRunner {
   /**
    * Calculates the mean squared error for regression.
    */
-  private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+  private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
     data.map { y =>
       val err = tree.predict(y.features) - y.label
       err * err

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 6737a2f..752ed59 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -62,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // Note: random seed will not be used since numTrees = 1.
     val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
     val rfModel = rf.train(input)
-    rfModel.trees(0)
+    rfModel.weakHypotheses(0)
   }
 
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
new file mode 100644
index 0000000..1a84720
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
+import org.apache.spark.Logging
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.loss.Losses
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+
+/**
+ * :: Experimental ::
+ * A class that implements gradient boosting for regression and binary classification problems.
+ * @param boostingStrategy Parameters for the gradient boosting algorithm
+ */
+@Experimental
+class GradientBoosting (
+    private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+
+  /**
+   * Method to train a gradient boosting model
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return WeightedEnsembleModel that can be used for prediction
+   */
+  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
+    val algo = boostingStrategy.algo
+    algo match {
+      case Regression => GradientBoosting.boost(input, boostingStrategy)
+      case Classification =>
+        val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        GradientBoosting.boost(remappedInput, boostingStrategy)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+    }
+  }
+
+}
+
+
+object GradientBoosting extends Logging {
+
+  /**
+   * Method to train a gradient boosting model.
+   *
+   * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+   *       is recommended to clearly specify regression.
+   *       Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+   *       is recommended to clearly specify classification.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param boostingStrategy Configuration options for the boosting algorithm.
+   * @return WeightedEnsembleModel that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+    new GradientBoosting(boostingStrategy).train(input)
+  }
+
+  /**
+   * Method to train a gradient boosting classification model.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param boostingStrategy Configuration options for the boosting algorithm.
+   * @return WeightedEnsembleModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+    val algo = boostingStrategy.algo
+    require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
+    new GradientBoosting(boostingStrategy).train(input)
+  }
+
+  /**
+   * Method to train a gradient boosting regression model.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param boostingStrategy Configuration options for the boosting algorithm.
+   * @return WeightedEnsembleModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+    val algo = boostingStrategy.algo
+    require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
+    new GradientBoosting(boostingStrategy).train(input)
+  }
+
+  /**
+   * Method to train a gradient boosting binary classification model.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param numEstimators Number of estimators used in boosting stages. In other words,
+   *                      number of boosting iterations performed.
+   * @param loss Loss function used for minimization during gradient boosting.
+   * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+   *                     learning rate should be in the interval (0, 1].
+   * @param subsamplingRate  Fraction of the training data used for learning the decision tree.
+   * @param numClassesForClassification Number of classes for classification.
+   *                                    (Ignored for regression.)
+   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+   *                                the number of discrete values they take. For example,
+   *                                an entry (n -> k) implies the feature n is categorical with k
+   *                                categories 0, 1, 2, ... , k-1. It's important to note that
+   *                                features are zero-indexed.
+   * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
+   *                          supported.)
+   * @return WeightedEnsembleModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      numEstimators: Int,
+      loss: String,
+      learningRate: Double,
+      subsamplingRate: Double,
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: Map[Int, Int],
+      weakLearnerParams: Strategy): WeightedEnsembleModel = {
+    val lossType = Losses.fromString(loss)
+    val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
+      learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
+      weakLearnerParams)
+    new GradientBoosting(boostingStrategy).train(input)
+  }
+
+  /**
+   * Method to train a gradient boosting regression model.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              For classification, labels should take values {0, 1, ..., numClasses-1}.
+   *              For regression, labels are real numbers.
+   * @param numEstimators Number of estimators used in boosting stages. In other words,
+   *                      number of boosting iterations performed.
+   * @param loss Loss function used for minimization during gradient boosting.
+   * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+   *                     learning rate should be in the interval (0, 1].
+   * @param subsamplingRate  Fraction of the training data used for learning the decision tree.
+   * @param numClassesForClassification Number of classes for classification.
+   *                                    (Ignored for regression.)
+   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+   *                                the number of discrete values they take. For example,
+   *                                an entry (n -> k) implies the feature n is categorical with k
+   *                                categories 0, 1, 2, ... , k-1. It's important to note that
+   *                                features are zero-indexed.
+   * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
+   *                          supported.)
+   * @return WeightedEnsembleModel that can be used for prediction
+   */
+  def trainRegressor(
+       input: RDD[LabeledPoint],
+       numEstimators: Int,
+       loss: String,
+       learningRate: Double,
+       subsamplingRate: Double,
+       numClassesForClassification: Int,
+       categoricalFeaturesInfo: Map[Int, Int],
+       weakLearnerParams: Strategy): WeightedEnsembleModel = {
+    val lossType = Losses.fromString(loss)
+    val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
+      learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
+      weakLearnerParams)
+    new GradientBoosting(boostingStrategy).train(input)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      numEstimators: Int,
+      loss: String,
+      learningRate: Double,
+      subsamplingRate: Double,
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      weakLearnerParams: Strategy): WeightedEnsembleModel = {
+    trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
+      numClassesForClassification,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      weakLearnerParams)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      numEstimators: Int,
+      loss: String,
+      learningRate: Double,
+      subsamplingRate: Double,
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      weakLearnerParams: Strategy): WeightedEnsembleModel = {
+    trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
+      numClassesForClassification,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      weakLearnerParams)
+  }
+
+
+  /**
+   * Internal method for performing regression using trees as base learners.
+   * @param input Training dataset.
+   * @param boostingStrategy Boosting parameters.
+   * @return WeightedEnsembleModel trained with the given loss and weak learner.
+   */
+  private def boost(
+      input: RDD[LabeledPoint],
+      boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+
+    val timer = new TimeTracker()
+    timer.start("total")
+    timer.start("init")
+
+    // Initialize gradient boosting parameters
+    val numEstimators = boostingStrategy.numEstimators
+    val baseLearners = new Array[DecisionTreeModel](numEstimators)
+    val baseLearnerWeights = new Array[Double](numEstimators)
+    val loss = boostingStrategy.loss
+    val learningRate = boostingStrategy.learningRate
+    val strategy = boostingStrategy.weakLearnerParams
+
+    // Cache input
+    input.persist(StorageLevel.MEMORY_AND_DISK)
+
+    timer.stop("init")
+
+    logDebug("##########")
+    logDebug("Building tree 0")
+    logDebug("##########")
+    var data = input
+
+    // 1. Initialize tree
+    timer.start("building tree 0")
+    val firstTreeModel = new DecisionTree(strategy).train(data)
+    baseLearners(0) = firstTreeModel
+    baseLearnerWeights(0) = 1.0
+    // Note: A model of type regression is used since we require raw prediction
+    val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
+      Sum)
+    logDebug("error of gbt = " + loss.computeError(startingModel, input))
+    timer.stop("building tree 0")
+
+    // Pseudo-residuals for the second iteration: the negative loss gradient at the current model
+    data = input.map(point => LabeledPoint(-loss.gradient(startingModel, point),
+      point.features))
+
+    var m = 1
+    while (m < numEstimators) {
+      timer.start(s"building tree $m")
+      logDebug("###################################################")
+      logDebug("Gradient boosting tree iteration " + m)
+      logDebug("###################################################")
+      val model = new DecisionTree(strategy).train(data)
+      timer.stop(s"building tree $m")
+      // Create partial model
+      baseLearners(m) = model
+      baseLearnerWeights(m) = learningRate
+      // Note: A model of type regression is used since we require raw prediction
+      val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
+        baseLearnerWeights.slice(0, m + 1), Regression, Sum)
+      logDebug("error of gbt = " + loss.computeError(partialModel, input))
+      // Update data with pseudo-residuals
+      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+        point.features))
+      m += 1
+    }
+
+    timer.stop("total")
+
+    logInfo("Internal timing for GradientBoosting:")
+    logInfo(s"$timer")
+
+
+    // Output the final ensemble model
+    new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
+
+  }
+
+}
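
To make the control flow of boost() above easier to follow outside of Spark, here is a plain-Scala sketch of the same recurrence; `fitTree` is a hypothetical stand-in for DecisionTree.train, and squared error is assumed so that the pseudo-residuals are simply y - F(x):

  // Fit the first learner to the labels, then fit each later learner to the
  // pseudo-residuals (the negative loss gradient at the current predictions),
  // adding it to the ensemble with weight = learningRate.
  def boostSketch(
      xs: Seq[Double],
      ys: Seq[Double],
      numEstimators: Int,
      learningRate: Double,
      fitTree: (Seq[Double], Seq[Double]) => (Double => Double)): Double => Double = {
    var learners = Vector((fitTree(xs, ys), 1.0))
    def ensemble(x: Double): Double = learners.map { case (h, w) => w * h(x) }.sum
    var m = 1
    while (m < numEstimators) {
      val residuals = xs.zip(ys).map { case (x, y) => y - ensemble(x) }
      learners = learners :+ ((fitTree(xs, residuals), learningRate))
      m += 1
    }
    ensemble
  }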

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index ebbd8e0..1dcaf91 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -26,6 +26,7 @@ import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
 import org.apache.spark.mllib.tree.impurity.Impurities
@@ -59,7 +60,7 @@ import org.apache.spark.util.Utils
  *                                if numTrees == 1, set to "all";
  *                                if numTrees > 1 (forest) set to "sqrt" for classification and
  *                                  to "onethird" for regression.
- * @param seed  Random seed for bootstrapping and choosing feature subsets.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
  */
 @Experimental
 private class RandomForest (
@@ -78,9 +79,9 @@ private class RandomForest (
   /**
    * Method to train a decision tree model over an RDD
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
-   * @return RandomForestModel that can be used for prediction
+   * @return WeightedEnsembleModel that can be used for prediction
    */
-  def train(input: RDD[LabeledPoint]): RandomForestModel = {
+  def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
 
     val timer = new TimeTracker()
 
@@ -111,11 +112,20 @@ private class RandomForest (
     // Bin feature values (TreePoint representation).
     // Cache input RDD for speedup during multiple passes.
     val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
-    val baggedInput = if (numTrees > 1) {
-      BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
-    } else {
-      BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
-    }.persist(StorageLevel.MEMORY_AND_DISK)
+
+    val (subsample, withReplacement) = {
+      // TODO: Have a stricter check for RF in the strategy
+      val isRandomForest = numTrees > 1
+      if (isRandomForest) {
+        (1.0, true)
+      } else {
+        (strategy.subsamplingRate, false)
+      }
+    }
+
+    val baggedInput =
+      BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
+        .persist(StorageLevel.MEMORY_AND_DISK)
 
     // depth of the decision tree
     val maxDepth = strategy.maxDepth
@@ -184,7 +194,8 @@ private class RandomForest (
     logInfo(s"$timer")
 
     val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
-    RandomForestModel.build(trees)
+    val treeWeights = Array.fill[Double](numTrees)(1.0)
+    new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
   }
 
 }
@@ -205,14 +216,14 @@ object RandomForest extends Serializable with Logging {
    *                                if numTrees > 1 (forest) set to "sqrt" for classification and
    *                                  to "onethird" for regression.
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return RandomForestModel that can be used for prediction
+   * @return WeightedEnsembleModel that can be used for prediction
    */
   def trainClassifier(
       input: RDD[LabeledPoint],
       strategy: Strategy,
       numTrees: Int,
       featureSubsetStrategy: String,
-      seed: Int): RandomForestModel = {
+      seed: Int): WeightedEnsembleModel = {
     require(strategy.algo == Classification,
       s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
     val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -243,7 +254,7 @@ object RandomForest extends Serializable with Logging {
    * @param maxBins maximum number of bins used for splitting features
    *                 (suggested value: 100)
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return RandomForestModel that can be used for prediction
+   * @return WeightedEnsembleModel that can be used for prediction
    */
   def trainClassifier(
       input: RDD[LabeledPoint],
@@ -254,7 +265,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int = Utils.random.nextInt()): RandomForestModel = {
+      seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
     val impurityType = Impurities.fromString(impurity)
     val strategy = new Strategy(Classification, impurityType, maxDepth,
       numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -273,7 +284,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int): RandomForestModel = {
+      seed: Int): WeightedEnsembleModel = {
     trainClassifier(input.rdd, numClassesForClassification,
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
       numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -293,14 +304,14 @@ object RandomForest extends Serializable with Logging {
    *                                if numTrees > 1 (forest) set to "sqrt" for classification and
    *                                  to "onethird" for regression.
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return RandomForestModel that can be used for prediction
+   * @return WeightedEnsembleModel that can be used for prediction
    */
   def trainRegressor(
       input: RDD[LabeledPoint],
       strategy: Strategy,
       numTrees: Int,
       featureSubsetStrategy: String,
-      seed: Int): RandomForestModel = {
+      seed: Int): WeightedEnsembleModel = {
     require(strategy.algo == Regression,
       s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
     val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -330,7 +341,7 @@ object RandomForest extends Serializable with Logging {
    * @param maxBins maximum number of bins used for splitting features
    *                 (suggested value: 100)
    * @param seed  Random seed for bootstrapping and choosing feature subsets.
-   * @return RandomForestModel that can be used for prediction
+   * @return WeightedEnsembleModel that can be used for prediction
    */
   def trainRegressor(
       input: RDD[LabeledPoint],
@@ -340,7 +351,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int = Utils.random.nextInt()): RandomForestModel = {
+      seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
     val impurityType = Impurities.fromString(impurity)
     val strategy = new Strategy(Regression, impurityType, maxDepth,
       0, maxBins, Sort, categoricalFeaturesInfo)
@@ -358,7 +369,7 @@ object RandomForest extends Serializable with Logging {
       impurity: String,
       maxDepth: Int,
       maxBins: Int,
-      seed: Int): RandomForestModel = {
+      seed: Int): WeightedEnsembleModel = {
     trainRegressor(input.rdd,
       categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
       numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
new file mode 100644
index 0000000..501d9ff
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import scala.beans.BeanProperty
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
+import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
+
+/**
+ * :: Experimental ::
+ * Stores all the configuration options for the boosting algorithms
+ * @param algo  Learning goal.  Supported:
+ *              [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ *              [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @param numEstimators Number of estimators used in boosting stages. In other words,
+ *                      number of boosting iterations performed.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+ *                     learning rate should be in the interval (0, 1].
+ * @param subsamplingRate  Fraction of the training data used for learning the decision tree.
+ * @param numClassesForClassification Number of classes for classification.
+ *                                    (Ignored for regression.)
+ *                                    Default value is 2 (binary classification).
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ *                                number of discrete values they take. For example, an entry (n ->
+ *                                k) implies the feature n is categorical with k categories 0,
+ *                                1, 2, ... , k-1. It's important to note that features are
+ *                                zero-indexed.
+ * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
+ *                          supported.
+ */
+@Experimental
+case class BoostingStrategy(
+    // Required boosting parameters
+    algo: Algo,
+    @BeanProperty var numEstimators: Int,
+    @BeanProperty var loss: Loss,
+    // Optional boosting parameters
+    @BeanProperty var learningRate: Double = 0.1,
+    @BeanProperty var subsamplingRate: Double = 1.0,
+    @BeanProperty var numClassesForClassification: Int = 2,
+    @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+    @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
+
+  require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " +
+    s"$learningRate.")
+  require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " +
+    s"$learningRate.")
+
+  // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+  weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo
+  weakLearnerParams.numClassesForClassification = numClassesForClassification
+  weakLearnerParams.subsamplingRate = subsamplingRate
+
+}
+
+@Experimental
+object BoostingStrategy {
+
+  /**
+   * Returns default configuration for the boosting algorithm
+   * @param algo Learning goal.  Supported:
+   *             [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+   *             [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+   * @return Configuration for boosting algorithm
+   */
+  def defaultParams(algo: Algo): BoostingStrategy = {
+    val treeStrategy = defaultWeakLearnerParams(algo)
+    algo match {
+      case Classification =>
+        new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy)
+      case Regression =>
+        new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy)
+      case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by boosting.")
+    }
+  }
+
+  /**
+   * Returns default configuration for the weak learner (decision tree) algorithm
+   * @param algo Learning goal.  Supported:
+   *             [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+   *             [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+   * @return Configuration for weak learner
+   */
+  def defaultWeakLearnerParams(algo: Algo): Strategy = {
+    // Note: A regression tree is used even for classification in GBT,
+    // since each stage fits real-valued pseudo-residuals.
+    new Strategy(Regression, Variance, 3)
+  }
+
+}
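
Since the boosting and tree parameters are declared as @BeanProperty vars, Java callers (and Scala ones) can configure them through the generated getters and setters. A hedged sketch, using only accessors that @BeanProperty derives from the fields above:

  import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy}

  val boosting = BoostingStrategy.defaultParams(Algo.Classification)
  boosting.setNumEstimators(200)               // setter generated by @BeanProperty
  boosting.setLearningRate(0.05)
  boosting.getWeakLearnerParams.setMaxDepth(2) // Strategy fields are @BeanProperty vars too

Note the design choice in the constructor above: categoricalFeaturesInfo, numClassesForClassification, and subsamplingRate are copied into the weak learner's Strategy, so the boosting configuration and the tree configuration cannot silently disagree.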

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
new file mode 100644
index 0000000..82889dc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Enum to select ensemble combining strategy for base learners
+ */
+@DeveloperApi
+object EnsembleCombiningStrategy extends Enumeration {
+  type EnsembleCombiningStrategy = Value
+  val Sum, Average = Value
+}
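
The actual combination logic lives in WeightedEnsembleModel, which is not shown in this excerpt; as a hedged, plain-Scala reading of what the two values are intended to mean for per-learner predictions and weights:

  import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._

  // Sum: additive model, as used by gradient boosting (weights carry the learning rate).
  // Average: averaged predictions, as used by random forests (all weights are 1.0).
  def combine(predictions: Seq[Double], weights: Seq[Double],
              strategy: EnsembleCombiningStrategy): Double = {
    val weightedSum = predictions.zip(weights).map { case (p, w) => p * w }.sum
    strategy match {
      case Sum => weightedSum
      case Average => weightedSum / weights.sum
    }
  }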

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index caaccbf..2ed63cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.tree.configuration
 
+import scala.beans.BeanProperty
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
@@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                for choosing how to split on features at each node.
  *                More bins give higher granularity.
  * @param quantileCalculationStrategy Algorithm for calculating quantiles.  Supported:
-   *                             [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
+ *                             [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
  * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
  *                                number of discrete values they take. For example, an entry (n ->
  *                                k) implies the feature n is categorical with k categories 0,
@@ -58,19 +59,21 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                    this split will not be considered as a valid split.
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
  *                      256 MB.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
  */
 @Experimental
 class Strategy (
     val algo: Algo,
-    val impurity: Impurity,
-    val maxDepth: Int,
-    val numClassesForClassification: Int = 2,
-    val maxBins: Int = 32,
-    val quantileCalculationStrategy: QuantileStrategy = Sort,
-    val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val minInstancesPerNode: Int = 1,
-    val minInfoGain: Double = 0.0,
-    val maxMemoryInMB: Int = 256) extends Serializable {
+    @BeanProperty var impurity: Impurity,
+    @BeanProperty var maxDepth: Int,
+    @BeanProperty var numClassesForClassification: Int = 2,
+    @BeanProperty var maxBins: Int = 32,
+    @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
+    @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+    @BeanProperty var minInstancesPerNode: Int = 1,
+    @BeanProperty var minInfoGain: Double = 0.0,
+    @BeanProperty var maxMemoryInMB: Int = 256,
+    @BeanProperty var subsamplingRate: Double = 1.0) extends Serializable {
 
   if (algo == Classification) {
     require(numClassesForClassification >= 2)

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
index e7a2127..089010c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -21,13 +21,14 @@ import org.apache.commons.math3.distribution.PoissonDistribution
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * Internal representation of a datapoint which belongs to several subsamples of the same dataset,
  * particularly for bagging (e.g., for random forests).
  *
  * This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times which this instance appears in each subsample.
+ * number of times which this instance appears in each subsample.
  * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
  * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
  *
@@ -44,22 +45,65 @@ private[tree] object BaggedPoint {
 
   /**
    * Convert an input dataset into its BaggedPoint representation,
-   * choosing subsample counts for each instance.
-   * Each subsample has the same number of instances as the original dataset,
-   * and is created by subsampling with replacement.
-   * @param input     Input dataset.
-   * @param numSubsamples  Number of subsamples of this RDD to take.
-   * @param seed   Random seed.
-   * @return  BaggedPoint dataset representation
+   * choosing subsample counts for each instance.
+   * Each subsample has the same number of instances as the original dataset,
+   * and is created by subsampling with or without replacement.
+   * @param input Input dataset.
+   * @param subsamplingRate Fraction of the training data used for learning decision tree.
+   * @param numSubsamples Number of subsamples of this RDD to take.
+   * @param withReplacement Sampling with/without replacement.
+   * @param seed Random seed.
+   * @return BaggedPoint dataset representation.
    */
-  def convertToBaggedRDD[Datum](
+  def convertToBaggedRDD[Datum] (
       input: RDD[Datum],
+      subsamplingRate: Double,
       numSubsamples: Int,
+      withReplacement: Boolean,
       seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+    if (withReplacement) {
+      convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+    } else {
+      if (numSubsamples == 1 && subsamplingRate == 1.0) {
+        convertToBaggedRDDWithoutSampling(input)
+      } else {
+        convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+      }
+    }
+  }
+
+  private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
+      input: RDD[Datum],
+      subsamplingRate: Double,
+      numSubsamples: Int,
+      seed: Int): RDD[BaggedPoint[Datum]] = {
+    input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+      // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+      val rng = new XORShiftRandom
+      rng.setSeed(seed + partitionIndex + 1)
+      instances.map { instance =>
+        val subsampleWeights = new Array[Double](numSubsamples)
+        var subsampleIndex = 0
+        while (subsampleIndex < numSubsamples) {
+          val x = rng.nextDouble()
+          subsampleWeights(subsampleIndex) = {
+            if (x < subsamplingRate) 1.0 else 0.0
+          }
+          subsampleIndex += 1
+        }
+        new BaggedPoint(instance, subsampleWeights)
+      }
+    }
+  }
+
+  private def convertToBaggedRDDSamplingWithReplacement[Datum] (
+      input: RDD[Datum],
+      subsamplingRate: Double,
+      numSubsamples: Int,
+      seed: Int): RDD[BaggedPoint[Datum]] = {
     input.mapPartitionsWithIndex { (partitionIndex, instances) =>
-      // TODO: Support different sampling rates, and sampling without replacement.
       // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
-      val poisson = new PoissonDistribution(1.0)
+      val poisson = new PoissonDistribution(subsamplingRate)
       poisson.reseedRandomGenerator(seed + partitionIndex + 1)
       instances.map { instance =>
         val subsampleWeights = new Array[Double](numSubsamples)
@@ -73,7 +117,8 @@ private[tree] object BaggedPoint {
     }
   }
 
-  def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+  private def convertToBaggedRDDWithoutSampling[Datum] (
+      input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
     input.map(datum => new BaggedPoint(datum, Array(1.0)))
   }
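
As a standalone sketch of the two weighting schemes above (the helper name is hypothetical): with replacement each instance receives a Poisson(subsamplingRate) count per subsample, mirroring bootstrap sampling, while without replacement it receives a Bernoulli(subsamplingRate) 0/1 indicator:

  import org.apache.commons.math3.distribution.PoissonDistribution
  import scala.util.Random

  def subsampleWeights(subsamplingRate: Double, numSubsamples: Int,
                       withReplacement: Boolean, seed: Int): Array[Double] = {
    if (withReplacement) {
      // Bootstrap-style counts: how many copies of the instance land in each subsample.
      val poisson = new PoissonDistribution(subsamplingRate)
      poisson.reseedRandomGenerator(seed)
      Array.fill(numSubsamples)(poisson.sample().toDouble)
    } else {
      // Keep-or-drop indicator per subsample.
      val rng = new Random(seed)
      Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0)
    }
  }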
 

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
new file mode 100644
index 0000000..d111ffe
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least absolute error loss calculation.
+ * The label y is predicted from the features x using a function F.
+ * For each instance:
+ * Loss: |y - F|
+ * Negative gradient: sign(y - F)
+ */
+@DeveloperApi
+object AbsoluteError extends Loss {
+
+  /**
+   * Method to calculate the gradients for the gradient boosting calculation for least
+   * absolute error calculation.
+   * @param model Model of the weak learner
+   * @param point Instance of the training dataset
+   * @return Loss gradient
+   */
+  override def gradient(
+      model: WeightedEnsembleModel,
+      point: LabeledPoint): Double = {
+    if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+  }
+
+  /**
+   * Method to calculate error of the base learner for the gradient boosting calculation.
+   * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+   * purposes.
+   * @param model Model of the weak learner.
+   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Mean absolute error of the model on the given data.
+   */
+  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+    val sumOfAbsolutes = data.map { y =>
+      val err = model.predict(y.features) - y.label
+      math.abs(err)
+    }.sum()
+    sumOfAbsolutes / data.count()
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
new file mode 100644
index 0000000..6f3d434
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for log loss calculation (for binary classification).
+ *
+ * The label y is predicted from the features x using a function F.
+ * For each instance:
+ * Loss: log(1 + exp(-2yF)), y in {-1, 1}
+ * Negative gradient: 2y / (1 + exp(2yF))
+ */
+@DeveloperApi
+object LogLoss extends Loss {
+
+  /**
+   * Method to calculate the loss gradients for the gradient boosting calculation for binary
+   * classification
+   * @param model Model of the weak learner
+   * @param point Instance of the training dataset
+   * @return Loss gradient
+   */
+  override def gradient(
+      model: WeightedEnsembleModel,
+      point: LabeledPoint): Double = {
+    val prediction = model.predict(point.features)
+    1.0 / (1.0 + math.exp(-prediction)) - point.label
+  }
+
+  /**
+   * Method to calculate error of the base learner for the gradient boosting calculation.
+   * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+   * purposes.
+   * @param model Model of the weak learner.
+   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Fraction of misclassified instances.
+   */
+  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+    val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
+    wrongPredictions.toDouble / data.count()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
new file mode 100644
index 0000000..5580866
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
+ */
+@DeveloperApi
+trait Loss extends Serializable {
+
+  /**
+   * Method to calculate the gradients for the gradient boosting calculation.
+   * @param model Model of the weak learner.
+   * @param point Instance of the training dataset.
+   * @return Loss gradient.
+   */
+  def gradient(
+      model: WeightedEnsembleModel,
+      point: LabeledPoint): Double
+
+  /**
+   * Method to calculate error of the base learner for the gradient boosting calculation.
+   * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+   * purposes.
+   * @param model Model of the weak learner.
+   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Measure of model error on the data.
+   */
+  def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
new file mode 100644
index 0000000..42c9ead
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+object Losses {
+
+  def fromString(name: String): Loss = name match {
+    case "leastSquaresError" => SquaredError
+    case "leastAbsoluteError" => AbsoluteError
+    case "logLoss" => LogLoss
+    case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name")
+  }
+
+}
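
Usage is straightforward -- resolve a loss object from its configuration-string name,
e.g. (sketch):

import org.apache.spark.mllib.tree.loss.{Loss, Losses}

val loss: Loss = Losses.fromString("leastSquaresError")  // returns SquaredError
// An unrecognized name such as "quadratic" throws IllegalArgumentException.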

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
new file mode 100644
index 0000000..4349fef
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least squares error loss calculation.
+ *
+ * The label y for features x is predicted by the function F(x).
+ * For each instance:
+ * Loss: (y - F)**2 / 2
+ * Negative gradient: y - F
+ */
+@DeveloperApi
+object SquaredError extends Loss {
+
+  /**
+   * Method to calculate the gradient for the gradient boosting calculation for the least
+   * squares error loss.
+   * @param model Ensemble model containing the weak hypotheses trained so far
+   * @param point Instance of the training dataset
+   * @return Gradient of the loss with respect to the model's prediction: F - y
+   */
+  override def gradient(
+      model: WeightedEnsembleModel,
+      point: LabeledPoint): Double = {
+    model.predict(point.features) - point.label
+  }
+
+  /**
+   * Method to calculate the mean squared error of the model on a dataset.
+   * Note: This method is not used by the gradient boosting algorithm itself but is useful
+   * for debugging purposes.
+   * @param model Ensemble model containing the weak hypotheses trained so far.
+   * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Mean squared error of the model over the dataset.
+   */
+  override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+    data.map { point =>
+      val err = model.predict(point.features) - point.label
+      err * err
+    }.mean()
+  }
+
+}
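
The sign convention is easy to check in isolation: the class comment gives the negative
gradient y - F, so gradient() must return F - y. A standalone check using plain doubles
(no MLlib types involved):

def squaredLoss(y: Double, f: Double): Double = (y - f) * (y - f) / 2.0
def squaredLossGradient(y: Double, f: Double): Double = f - y  // d(loss)/dF

val (y, f) = (3.0, 2.5)
assert(squaredLossGradient(y, f) == -(y - f))  // gradient is the negative of (y - F)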

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
deleted file mode 100644
index 6a22e2a..0000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import scala.collection.mutable
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.rdd.RDD
-
-/**
- * :: Experimental ::
- * Random forest model for classification or regression.
- * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
- * aggregate predictions.
- * @param trees Trees which make up this forest.  This cannot be empty.
- * @param algo algorithm type -- classification or regression
- */
-@Experimental
-class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
-
-  require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
-
-  /**
-   * Predict values for a single data point.
-   *
-   * @param features array representing a single data point
-   * @return Double prediction from the trained model
-   */
-  def predict(features: Vector): Double = {
-    algo match {
-      case Classification =>
-        val predictionToCount = new mutable.HashMap[Int, Int]()
-        trees.foreach { tree =>
-          val prediction = tree.predict(features).toInt
-          predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
-        }
-        predictionToCount.maxBy(_._2)._1
-      case Regression =>
-        trees.map(_.predict(features)).sum / trees.size
-    }
-  }
-
-  /**
-   * Predict values for the given data set.
-   *
-   * @param features RDD representing data points to be predicted
-   * @return RDD[Double] where each entry contains the corresponding prediction
-   */
-  def predict(features: RDD[Vector]): RDD[Double] = {
-    features.map(x => predict(x))
-  }
-
-  /**
-   * Get number of trees in forest.
-   */
-  def numTrees: Int = trees.size
-
-  /**
-   * Get total number of nodes, summed over all trees in the forest.
-   */
-  def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
-
-  /**
-   * Print a summary of the model.
-   */
-  override def toString: String = algo match {
-    case Classification =>
-      s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
-    case Regression =>
-      s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
-    case _ => throw new IllegalArgumentException(
-      s"RandomForestModel given unknown algo parameter: $algo.")
-  }
-
-  /**
-   * Print the full model to a string.
-   */
-  def toDebugString: String = {
-    val header = toString + "\n"
-    header + trees.zipWithIndex.map { case (tree, treeIndex) =>
-      s"  Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
-    }.fold("")(_ + _)
-  }
-
-}
-
-private[tree] object RandomForestModel {
-
-  def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
-    require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
-    val algo: Algo = trees(0).algo
-    require(trees.forall(_.algo == algo),
-      "RandomForestModel cannot combine trees which have different output types" +
-      " (classification/regression).")
-    new RandomForestModel(trees, algo)
-  }
-
-}
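
With this deletion, a random forest is represented by the new WeightedEnsembleModel
below: the Average combining strategy reproduces the old majority-vote (classification)
and mean (regression) behavior, and the weights are ignored under Average. A sketch of
the equivalent construction (the toEnsemble helper is illustrative, not part of this
patch):

import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, WeightedEnsembleModel}

def toEnsemble(trees: Array[DecisionTreeModel]): WeightedEnsembleModel = {
  require(trees.nonEmpty, "cannot build an ensemble from zero trees")
  new WeightedEnsembleModel(trees, Array.fill(trees.size)(1.0), trees(0).algo, Average)
}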

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
new file mode 100644
index 0000000..7b052d9
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
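+/**
+ * :: Experimental ::
+ * Ensemble model that combines the predictions of a collection of weak hypotheses
+ * (currently [[DecisionTreeModel]] instances) according to a combining strategy.
+ * @param weakHypotheses Weak hypotheses (base learners) making up the ensemble.
+ *                       This cannot be empty.
+ * @param weakHypothesisWeights Weight assigned to each weak hypothesis's prediction.
+ * @param algo algorithm type -- classification or regression
+ * @param combiningStrategy strategy for combining predictions -- Sum or Average
+ */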
+@Experimental
+class WeightedEnsembleModel(
+    val weakHypotheses: Array[DecisionTreeModel],
+    val weakHypothesisWeights: Array[Double],
+    val algo: Algo,
+    val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+  require(numWeakHypotheses > 0, "WeightedEnsembleModel cannot be created without " +
+    s"weakHypotheses. Number of weakHypotheses = $numWeakHypotheses")
+
+  /**
+   * Compute the raw (unthresholded) prediction for a single data point.
+   *
+   * @param features vector representing a single data point
+   * @return weighted sum of the weak hypotheses' predictions
+   */
+  private def predictRaw(features: Vector): Double = {
+    val treePredictions = weakHypotheses.map(learner => learner.predict(features))
+    if (numWeakHypotheses == 1) {
+      treePredictions(0)
+    } else {
+      // Note: the first weak hypothesis contributes with an implicit weight of 1.0;
+      // weakHypothesisWeights(0) is not applied here.
+      var prediction = treePredictions(0)
+      var index = 1
+      while (index < numWeakHypotheses) {
+        prediction += weakHypothesisWeights(index) * treePredictions(index)
+        index += 1
+      }
+      prediction
+    }
+  }
+
+  /**
+   * Predict the value for a single data point by summing the weighted predictions.
+   *
+   * @param features vector representing a single data point
+   * @return raw prediction for regression; 0/1 label for classification
+   */
+  private def predictBySumming(features: Vector): Double = {
+    algo match {
+      case Regression => predictRaw(features)
+      case Classification => {
+        // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+        if (predictRaw(features) > 0) 1.0 else 0.0
+      }
+      case _ => throw new IllegalArgumentException(
+        s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+    }
+  }
+
+  /**
+   * Predict the value for a single data point by majority vote (classification) or
+   * averaging (regression).
+   *
+   * @param features vector representing a single data point
+   * @return Double prediction from the trained model
+   */
+  private def predictByAveraging(features: Vector): Double = {
+    algo match {
+      case Classification =>
+        val predictionToCount = new mutable.HashMap[Int, Int]()
+        weakHypotheses.foreach { learner =>
+          val prediction = learner.predict(features).toInt
+          predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
+        }
+        predictionToCount.maxBy(_._2)._1
+      case Regression =>
+        weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size
+    }
+  }
+
+  /**
+   * Predict the value for a single data point using the trained model.
+   *
+   * @param features vector representing a single data point
+   * @return predicted category (classification) or value (regression)
+   */
+  def predict(features: Vector): Double = {
+    combiningStrategy match {
+      case Sum => predictBySumming(features)
+      case Average => predictByAveraging(features)
+      case _ => throw new IllegalArgumentException(
+        s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.")
+    }
+  }
+
+  /**
+   * Predict values for the given data set.
+   *
+   * @param features RDD representing data points to be predicted
+   * @return RDD[Double] where each entry contains the corresponding prediction
+   */
+  def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+  /**
+   * Print a summary of the model.
+   */
+  override def toString: String = {
+    algo match {
+      case Classification =>
+        s"WeightedEnsembleModel classifier with $numWeakHypotheses trees"
+      case Regression =>
+        s"WeightedEnsembleModel regressor with $numWeakHypotheses trees"
+      case _ => throw new IllegalArgumentException(
+        s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+    }
+  }
+
+  /**
+   * Print the full model to a string.
+   */
+  def toDebugString: String = {
+    val header = toString + "\n"
+    header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) =>
+      s"  Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+    }.fold("")(_ + _)
+  }
+
+  /**
+   * Get number of weak hypotheses in the ensemble.
+   */
+  def numWeakHypotheses: Int = weakHypotheses.size
+
+  // TODO: Remove these helper methods once the class is generalized to support any base
+  // learning algorithm.
+
+  /**
+   * Get total number of nodes, summed over all trees in the ensemble.
+   */
+  def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
+
+}
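
To make the two combining strategies concrete, here is what predict computes on a toy
ensemble of three trees (plain numbers, mirroring predictBySumming and
predictByAveraging above; the values are illustrative):

val treePredictions = Array(0.8, -0.3, 0.5)
val weights = Array(1.0, 0.1, 0.1)  // e.g. 1.0 for the first tree, learningRate after

// Sum: the first tree is taken with an implicit weight of 1.0, later trees are weighted.
val rawScore = treePredictions(0) +
  (1 until treePredictions.length).map(i => weights(i) * treePredictions(i)).sum

// For classification under Sum, the raw score is thresholded at 0 into a 0/1 label.
val label = if (rawScore > 0) 1.0 else 0.0   // rawScore = 0.82, so label = 1.0

// Average (regression): unweighted mean of the tree predictions.
val avgPrediction = treePredictions.sum / treePredictions.length   // ~0.333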

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 8fc5e11..c579cb5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -493,7 +493,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(rootNode1.rightNode.nonEmpty)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
 
     // Single group second level tree construction.
     val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
@@ -786,7 +786,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
 
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
 
     val topNode = Node.emptyNode(nodeIndex = 1)
     assert(topNode.predict.predict === Double.MinValue)
@@ -829,7 +829,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
 
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
 
     val topNode = Node.emptyNode(nodeIndex = 1)
     assert(topNode.predict.predict === Double.MinValue)
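
In these hunks, the general-purpose BaggedPoint.convertToBaggedRDD replaces the removed
no-sampling helper. Annotated with the intended meaning of each positional argument
(a sketch; cf. the BaggedPoint changes elsewhere in this patch):

val baggedInput = BaggedPoint.convertToBaggedRDD(
  treeInput,
  1.0,    // subsample fraction: use the full dataset
  1,      // number of subsamples (bags) per point
  false)  // sample without replacement
// With these arguments every point appears once with weight 1, reproducing the
// behavior of the removed convertToBaggedRDDWithoutSampling helper.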

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
new file mode 100644
index 0000000..effb7b8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.mutable
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.util.StatCounter
+
+object EnsembleTestHelper {
+
+  /**
+   * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+   * epsilon of the expected values.
+   * @param data  Every element of the data should be an i.i.d. sample from some distribution.
+   */
+  def testRandomArrays(
+      data: Array[Array[Double]],
+      numCols: Int,
+      expectedMean: Double,
+      expectedStddev: Double,
+      epsilon: Double) {
+    val values = new mutable.ArrayBuffer[Double]()
+    data.foreach { row =>
+      assert(row.size == numCols)
+      values ++= row
+    }
+    val stats = new StatCounter(values)
+    assert(math.abs(stats.mean - expectedMean) < epsilon)
+    assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+  }
+
+  def validateClassifier(
+      model: WeightedEnsembleModel,
+      input: Seq[LabeledPoint],
+      requiredAccuracy: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+      prediction != expected.label
+    }
+    val accuracy = (input.length - numOffPredictions).toDouble / input.length
+    assert(accuracy >= requiredAccuracy,
+      s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+  }
+
+  def validateRegressor(
+      model: WeightedEnsembleModel,
+      input: Seq[LabeledPoint],
+      requiredMSE: Double) {
+    val predictions = input.map(x => model.predict(x.features))
+    val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+      val err = prediction - expected.label
+      err * err
+    }.sum
+    val mse = squaredError / input.length
+    assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+  }
+
+  def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = {
+    val arr = new Array[LabeledPoint](numInstances)
+    for (i <- 0 until numInstances) {
+      val label = if (i < numInstances / 10) {
+        0.0
+      } else if (i < numInstances / 2) {
+        1.0
+      } else if (i < numInstances * 0.9) {
+        0.0
+      } else {
+        1.0
+      }
+      val features = Array.fill[Double](numFeatures)(i.toDouble)
+      arr(i) = new LabeledPoint(label, Vectors.dense(features))
+    }
+    arr
+  }
+
+}
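
For reference, generateOrderedLabeledPoints lays labels out in four contiguous runs,
which makes the data trivially separable on the ordered features (every feature of
instance i equals i). For numInstances = 1000:

//   i in [0, 100)    -> label 0.0   (first 10%)
//   i in [100, 500)  -> label 1.0
//   i in [500, 900)  -> label 0.0
//   i in [900, 1000) -> label 1.0   (last 10%)
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, numInstances = 1000)
assert(arr(50).label == 0.0 && arr(250).label == 1.0 && arr(700).label == 0.0)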

http://git-wip-us.apache.org/repos/asf/spark/blob/86021955/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
new file mode 100644
index 0000000..970fff8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.impurity.Variance
+import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+
+import org.apache.spark.mllib.util.LocalSparkContext
+
+/**
+ * Test suite for [[GradientBoosting]].
+ */
+class GradientBoostingSuite extends FunSuite with LocalSparkContext {
+
+  test("Regression with continuous features: SquaredError") {
+
+    GradientBoostingSuite.testCombinations.foreach {
+      case (numEstimators, learningRate, subsamplingRate) =>
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+        val rdd = sc.parallelize(arr)
+        val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+          numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+          subsamplingRate = subsamplingRate)
+
+        val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+        val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
+          subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+        val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+        assert(gbt.weakHypotheses.size === numEstimators)
+        val gbtTree = gbt.weakHypotheses(0)
+
+        EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+        // Make sure trees are the same.
+        assert(gbtTree.toString == dt.toString)
+    }
+  }
+
+  test("Regression with continuous features: Absolute Error") {
+
+    GradientBoostingSuite.testCombinations.foreach {
+      case (numEstimators, learningRate, subsamplingRate) =>
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+        val rdd = sc.parallelize(arr)
+        val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+          numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+          subsamplingRate = subsamplingRate)
+
+        val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+        val boostingStrategy = new BoostingStrategy(Regression, numEstimators, AbsoluteError,
+          subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+        val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+        assert(gbt.weakHypotheses.size === numEstimators)
+        val gbtTree = gbt.weakHypotheses(0)
+
+        EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+        // Make sure trees are the same.
+        assert(gbtTree.toString == dt.toString)
+    }
+  }
+
+  test("Binary classification with continuous features: Log Loss") {
+
+    GradientBoostingSuite.testCombinations.foreach {
+      case (numEstimators, learningRate, subsamplingRate) =>
+        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+        val rdd = sc.parallelize(arr)
+        val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+          numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+          subsamplingRate = subsamplingRate)
+
+        val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+        val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss,
+          subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+        val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
+        assert(gbt.weakHypotheses.size === numEstimators)
+        val gbtTree = gbt.weakHypotheses(0)
+
+        EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+
+        // Make sure trees are the same.
+        assert(gbtTree.toString == dt.toString)
+    }
+  }
+
+}
+
+object GradientBoostingSuite {
+
+  // Combinations of (numEstimators, learningRate, subsamplingRate) to test.
+  val testCombinations =
+    Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+
+}
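
Putting the pieces together, a minimal end-to-end run looks like the tests above
(a sketch, not part of this patch; it assumes an existing SparkContext `sc`, reuses the
test data generator, and mirrors the constructor argument order used in this suite):

import org.apache.spark.SparkContext._
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoosting}
import org.apache.spark.mllib.tree.configuration.Algo.Regression
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.SquaredError

val data = sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(50, 1000))

val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
  numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty[Int, Int])
val boostingStrategy = new BoostingStrategy(Regression, 10, SquaredError,
  1.0, 0.1, 1, Map.empty[Int, Int], treeStrategy)

val model = GradientBoosting.trainRegressor(data, boostingStrategy)
println(s"Trained ${model.numWeakHypotheses} trees, ${model.totalNumNodes} nodes total")

// Training-set MSE, using the same squared-error definition as validateRegressor.
val mse = data.map { p =>
  val err = model.predict(p.features) - p.label
  err * err
}.mean()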


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org