You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2014/11/26 05:10:18 UTC

spark git commit: [SPARK-4583] [mllib] LogLoss for GradientBoostedTrees fix + doc updates

Repository: spark
Updated Branches:
  refs/heads/master 7eba0fbe4 -> c251fd740


[SPARK-4583] [mllib] LogLoss for GradientBoostedTrees fix + doc updates

Currently, the LogLoss used by GradientBoostedTrees has 2 issues:
* the gradient (and therefore loss) does not match that used by Friedman (1999)
* the error computation uses 0/1 accuracy, not log loss

This PR updates LogLoss.
It also adds some doc for boosting and forests.

I tested it on sample data and made sure the log loss is monotonically decreasing with each boosting iteration.

CC: mengxr manishamde codedeft

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #3439 from jkbradley/gbt-loss-fix and squashes the following commits:

cfec17e [Joseph K. Bradley] removed forgotten temp comments
a27eb6d [Joseph K. Bradley] corrections to last log loss commit
ed5da2c [Joseph K. Bradley] updated LogLoss (boosting) for numerical stability
5e52bff [Joseph K. Bradley] * Removed the 1/2 from SquaredError.  This also required updating the test suite since it effectively doubles the gradient and loss. * Added doc for developers within RandomForest. * Small cleanup in test suite (generating data only once)
e57897a [Joseph K. Bradley] Fixed LogLoss for GradientBoostedTrees, and updated doc for losses, forests, and boosting


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c251fd74
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c251fd74
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c251fd74

Branch: refs/heads/master
Commit: c251fd7405db57d3ab2686c38712601fd8f13ccd
Parents: 7eba0fb
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Tue Nov 25 20:10:15 2014 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Tue Nov 25 20:10:15 2014 -0800

----------------------------------------------------------------------
 .../spark/mllib/tree/GradientBoostedTrees.scala | 18 ++---
 .../apache/spark/mllib/tree/RandomForest.scala  | 44 +++++++++++-
 .../spark/mllib/tree/loss/AbsoluteError.scala   | 26 ++++---
 .../apache/spark/mllib/tree/loss/LogLoss.scala  | 34 +++++----
 .../spark/mllib/tree/loss/SquaredError.scala    | 22 +++---
 .../mllib/tree/GradientBoostedTreesSuite.scala  | 74 +++++++++++++-------
 6 files changed, 146 insertions(+), 72 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c251fd74/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index cb4ddfc..61f6b13 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -31,18 +31,20 @@ import org.apache.spark.storage.StorageLevel
 
 /**
  * :: Experimental ::
- * A class that implements Stochastic Gradient Boosting for regression and binary classification.
+ * A class that implements
+ * [[http://en.wikipedia.org/wiki/Gradient_boosting  Stochastic Gradient Boosting]]
+ * for regression and binary classification.
  *
  * The implementation is based upon:
  *   J.H. Friedman.  "Stochastic Gradient Boosting."  1999.
  *
- * Notes:
- *  - This currently can be run with several loss functions.  However, only SquaredError is
- *    fully supported.  Specifically, the loss function should be used to compute the gradient
- *    (to re-label training instances on each iteration) and to weight weak hypotheses.
- *    Currently, gradients are computed correctly for the available loss functions,
- *    but weak hypothesis weights are not computed correctly for LogLoss or AbsoluteError.
- *    Running with those losses will likely behave reasonably, but lacks the same guarantees.
+ * Notes on Gradient Boosting vs. TreeBoost:
+ *  - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.
+ *  - Both algorithms learn tree ensembles by minimizing loss functions.
+ *  - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes
+ *    based on the loss function, whereas the original gradient boosting method does not.
+ *     - When the loss is SquaredError, these methods give the same result, but they could differ
+ *       for other loss functions.
  *
  * @param boostingStrategy Parameters for the gradient boosting algorithm.
  */

http://git-wip-us.apache.org/repos/asf/spark/blob/c251fd74/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 3ae6fa2..482d339 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -37,7 +37,8 @@ import org.apache.spark.util.Utils
 
 /**
  * :: Experimental ::
- * A class which implements a random forest learning algorithm for classification and regression.
+ * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest  Random Forest]]
+ * learning algorithm for classification and regression.
  * It supports both continuous and categorical features.
  *
  * The settings for featureSubsetStrategy are based on the following references:
@@ -70,6 +71,47 @@ private class RandomForest (
     private val seed: Int)
   extends Serializable with Logging {
 
+  /*
+     ALGORITHM
+     This is a sketch of the algorithm to help new developers.
+
+     The algorithm partitions data by instances (rows).
+     On each iteration, the algorithm splits a set of nodes.  In order to choose the best split
+     for a given node, sufficient statistics are collected from the distributed data.
+     For each node, the statistics are collected to some worker node, and that worker selects
+     the best split.
+
+     This setup requires discretization of continuous features.  This binning is done in the
+     findSplitsBins() method during initialization, after which each continuous feature becomes
+     an ordered discretized feature with at most maxBins possible values.
+
+     The main loop in the algorithm operates on a queue of nodes (nodeQueue).  These nodes
+     lie at the periphery of the tree being trained.  If multiple trees are being trained at once,
+     then this queue contains nodes from all of them.  Each iteration works roughly as follows:
+       On the master node:
+         - Some number of nodes are pulled off of the queue (based on the amount of memory
+           required for their sufficient statistics).
+         - For random forests, if featureSubsetStrategy is not "all", then a subset of candidate
+           features is chosen for each node.  See method selectNodesToSplit().
+       On worker nodes, via method findBestSplits():
+         - The worker makes one pass over its subset of instances.
+         - For each (tree, node, feature, split) tuple, the worker collects statistics about
+           splitting.  Note that the set of (tree, node) pairs is limited to the nodes selected
+           from the queue for this iteration.  The set of features considered can also be limited
+           based on featureSubsetStrategy.
+         - For each node, the statistics for that node are aggregated to a particular worker
+           via reduceByKey().  The designated worker chooses the best (feature, split) pair,
+           or chooses to stop splitting if the stopping criteria are met.
+       On the master node:
+         - The master collects all decisions about splitting nodes and updates the model.
+         - The updated model is passed to the workers on the next iteration.
+     This process continues until the node queue is empty.
+
+     Most of the methods in this implementation support the statistics aggregation, which is
+     the heaviest part of the computation.  In general, this implementation is bound by either
+     the cost of statistics computation on workers or by communicating the sufficient statistics.
+   */
+
   strategy.assertValid()
   require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
   require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy),

http://git-wip-us.apache.org/repos/asf/spark/blob/c251fd74/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index e828866..d1bde15 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.tree.loss
 
-import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
@@ -25,11 +24,11 @@ import org.apache.spark.rdd.RDD
 
 /**
  * :: DeveloperApi ::
- * Class for least absolute error loss calculation.
- * The features x and the corresponding label y is predicted using the function F.
- * For each instance:
- * Loss: |y - F|
- * Negative gradient: sign(y - F)
+ * Class for absolute error loss calculation (for regression).
+ *
+ * The absolute (L1) error is defined as:
+ *  |y - F(x)|
+ * where y is the label and F(x) is the model prediction for features x.
  */
 @DeveloperApi
 object AbsoluteError extends Loss {
@@ -37,7 +36,8 @@ object AbsoluteError extends Loss {
   /**
    * Method to calculate the gradients for the gradient boosting calculation for least
    * absolute error calculation.
-   * @param model Model of the weak learner
+   * The gradient with respect to F(x) is: sign(F(x) - y)
+   * @param model Ensemble model
    * @param point Instance of the training dataset
    * @return Loss gradient
    */
@@ -48,19 +48,17 @@ object AbsoluteError extends Loss {
   }
 
   /**
-   * Method to calculate error of the base learner for the gradient boosting calculation.
+   * Method to calculate loss of the base learner for the gradient boosting calculation.
    * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
    * purposes.
-   * @param model Model of the weak learner.
+   * @param model Ensemble model
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return
+   * @return  Mean absolute error of model on data
    */
   override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
-    val sumOfAbsolutes = data.map { y =>
+    data.map { y =>
       val err = model.predict(y.features) - y.label
       math.abs(err)
-    }.sum()
-    sumOfAbsolutes / data.count()
+    }.mean()
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c251fd74/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 8b8adb4..7ce9fa6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -24,12 +24,12 @@ import org.apache.spark.rdd.RDD
 
 /**
  * :: DeveloperApi ::
- * Class for least squares error loss calculation.
+ * Class for log loss calculation (for classification).
+ * This uses twice the binomial negative log likelihood, called "deviance" in Friedman (1999).
  *
- * The features x and the corresponding label y is predicted using the function F.
- * For each instance:
- * Loss: log(1 + exp(-2yF)), y in {-1, 1}
- * Negative gradient: 2y / ( 1 + exp(2yF))
+ * The log loss is defined as:
+ *   2 log(1 + exp(-2 y F(x)))
+ * where y is a label in {-1, 1} and F(x) is the model prediction for features x.
  */
 @DeveloperApi
 object LogLoss extends Loss {
@@ -37,7 +37,8 @@ object LogLoss extends Loss {
   /**
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
    * classification
-   * @param model Model of the weak learner
+   * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
+   * @param model Ensemble model
    * @param point Instance of the training dataset
    * @return Loss gradient
    */
@@ -45,19 +46,28 @@ object LogLoss extends Loss {
       model: TreeEnsembleModel,
       point: LabeledPoint): Double = {
     val prediction = model.predict(point.features)
-    1.0 / (1.0 + math.exp(-prediction)) - point.label
+    - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
   }
 
   /**
-   * Method to calculate error of the base learner for the gradient boosting calculation.
+   * Method to calculate loss of the base learner for the gradient boosting calculation.
    * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
    * purposes.
-   * @param model Model of the weak learner.
+   * @param model Ensemble model
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return
+   * @return Mean log loss of model on data
    */
   override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
-    val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
-    wrongPredictions / data.count
+    data.map { case point =>
+      val prediction = model.predict(point.features)
+      val margin = 2.0 * point.label * prediction
+      // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically
+      // stable.
+      if (margin >= 0) {
+        2.0 * math.log1p(math.exp(-margin))
+      } else {
+        2.0 * (-margin + math.log1p(math.exp(margin)))
+      }
+    }.mean()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c251fd74/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index cfe395b..50ecaa2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -17,7 +17,6 @@
 
 package org.apache.spark.mllib.tree.loss
 
-import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.model.TreeEnsembleModel
@@ -25,12 +24,11 @@ import org.apache.spark.rdd.RDD
 
 /**
  * :: DeveloperApi ::
- * Class for least squares error loss calculation.
+ * Class for squared error loss calculation.
  *
- * The features x and the corresponding label y is predicted using the function F.
- * For each instance:
- * Loss: (y - F)**2/2
- * Negative gradient: y - F
+ * The squared (L2) error is defined as:
+ *   (y - F(x))**2
+ * where y is the label and F(x) is the model prediction for features x.
  */
 @DeveloperApi
 object SquaredError extends Loss {
@@ -38,23 +36,24 @@ object SquaredError extends Loss {
   /**
    * Method to calculate the gradients for the gradient boosting calculation for least
    * squares error calculation.
-   * @param model Model of the weak learner
+   * The gradient with respect to F(x) is: - 2 (y - F(x))
+   * @param model Ensemble model
    * @param point Instance of the training dataset
    * @return Loss gradient
    */
   override def gradient(
     model: TreeEnsembleModel,
     point: LabeledPoint): Double = {
-    model.predict(point.features) - point.label
+    2.0 * (model.predict(point.features) - point.label)
   }
 
   /**
-   * Method to calculate error of the base learner for the gradient boosting calculation.
+   * Method to calculate loss of the base learner for the gradient boosting calculation.
    * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
    * purposes.
-   * @param model Model of the weak learner.
+   * @param model Ensemble model
    * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
-   * @return
+   * @return  Mean squared error of model on data
    */
   override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = {
     data.map { y =>
@@ -62,5 +61,4 @@ object SquaredError extends Loss {
       err * err
     }.mean()
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c251fd74/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index f3f8eff..d4d54cf 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -35,32 +35,39 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
   test("Regression with continuous features: SquaredError") {
     GradientBoostedTreesSuite.testCombinations.foreach {
       case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
-        val rdd = sc.parallelize(arr, 2)
-
-        val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
-          categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
-        val boostingStrategy =
-          new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
-
-        val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
-
-        assert(gbt.trees.size === numIterations)
-        EnsembleTestHelper.validateRegressor(gbt, arr, 0.03)
-
-        val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
-        val dt = DecisionTree.train(remappedInput, treeStrategy)
-
-        // Make sure trees are the same.
-        assert(gbt.trees.head.toString == dt.toString)
+        GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed =>
+          val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+          val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+            categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+          val boostingStrategy =
+            new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
+
+          val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+          assert(gbt.trees.size === numIterations)
+          try {
+            EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
+          } catch {
+            case e: java.lang.AssertionError =>
+              println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+                s" subsamplingRate=$subsamplingRate")
+              throw e
+          }
+
+          val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+          val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+          // Make sure trees are the same.
+          assert(gbt.trees.head.toString == dt.toString)
+        }
     }
   }
 
   test("Regression with continuous features: Absolute Error") {
     GradientBoostedTreesSuite.testCombinations.foreach {
       case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
-        val rdd = sc.parallelize(arr, 2)
+        val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
 
         val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
           categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
@@ -70,7 +77,14 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
         val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
 
         assert(gbt.trees.size === numIterations)
-        EnsembleTestHelper.validateRegressor(gbt, arr, 0.85, "mae")
+        try {
+          EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae")
+        } catch {
+          case e: java.lang.AssertionError =>
+            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+              s" subsamplingRate=$subsamplingRate")
+            throw e
+        }
 
         val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
         val dt = DecisionTree.train(remappedInput, treeStrategy)
@@ -83,8 +97,7 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
   test("Binary classification with continuous features: Log Loss") {
     GradientBoostedTreesSuite.testCombinations.foreach {
       case (numIterations, learningRate, subsamplingRate) =>
-        val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
-        val rdd = sc.parallelize(arr, 2)
+        val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
 
         val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2,
           numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty,
@@ -95,7 +108,14 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
         val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
 
         assert(gbt.trees.size === numIterations)
-        EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+        try {
+          EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9)
+        } catch {
+          case e: java.lang.AssertionError =>
+            println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+              s" subsamplingRate=$subsamplingRate")
+            throw e
+        }
 
         val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
         val ensembleStrategy = treeStrategy.copy
@@ -113,5 +133,9 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
 object GradientBoostedTreesSuite {
 
   // Combinations for estimators, learning rates and subsamplingRate
-  val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+  val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
+
+  val randomSeeds = Array(681283, 4398)
+
+  val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org