You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/04/14 00:36:39 UTC

spark git commit: [SPARK-5972] [MLlib] Cache residuals and gradient in GBT during training and validation

Repository: spark
Updated Branches:
  refs/heads/master 3a205bbd9 -> 2a55cb41b


[SPARK-5972] [MLlib] Cache residuals and gradient in GBT during training and validation

The previous PR https://github.com/apache/spark/pull/4906 made it possible to extract the learning curve, i.e. the error at each boosting iteration. This PR continues that work: it refactors the shared code and applies the same incremental prediction-and-error caching during training and validation.

Author: MechCoder <ma...@gmail.com>

Closes #5330 from MechCoder/spark-5972 and squashes the following commits:

0b5d659 [MechCoder] minor
32d409d [MechCoder] evaluateEachIteration and training cache should follow different paths
d542bb0 [MechCoder] Remove unused imports and docs
58f4932 [MechCoder] Remove unpersist
70d3b4c [MechCoder] Broadcast for each tree
5869533 [MechCoder] Access broadcasted values locally and other minor changes
923dbf6 [MechCoder] [SPARK-5972] Cache residuals and gradient in GBT during training and validation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a55cb41
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a55cb41
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a55cb41

Branch: refs/heads/master
Commit: 2a55cb41bf7da1786be2c76b8af398da8fedb44b
Parents: 3a205bb
Author: MechCoder <ma...@gmail.com>
Authored: Mon Apr 13 15:36:33 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Apr 13 15:36:33 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/tree/GradientBoostedTrees.scala | 42 +++++++----
 .../spark/mllib/tree/loss/AbsoluteError.scala   | 10 +--
 .../apache/spark/mllib/tree/loss/LogLoss.scala  | 11 +--
 .../org/apache/spark/mllib/tree/loss/Loss.scala |  8 +-
 .../spark/mllib/tree/loss/SquaredError.scala    | 10 +--
 .../mllib/tree/model/treeEnsembleModels.scala   | 77 ++++++++++++++++----
 6 files changed, 105 insertions(+), 53 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a55cb41/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index a9c93e1..c02c79f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -157,7 +157,6 @@ object GradientBoostedTrees extends Logging {
       validationInput: RDD[LabeledPoint],
       boostingStrategy: BoostingStrategy,
       validate: Boolean): GradientBoostedTreesModel = {
-
     val timer = new TimeTracker()
     timer.start("total")
     timer.start("init")
@@ -192,20 +191,29 @@ object GradientBoostedTrees extends Logging {
     // Initialize tree
     timer.start("building tree 0")
     val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+    val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
-    baseLearnerWeights(0) = 1.0
-    val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
-    logDebug("error of gbt = " + loss.computeError(startingModel, input))
+    baseLearnerWeights(0) = firstTreeWeight
+    val startingModel = new GradientBoostedTreesModel(
+      Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
+
+    var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    logDebug("error of gbt = " + predError.values.mean())
 
     // Note: A model of type regression is used since we require raw prediction
     timer.stop("building tree 0")
 
-    var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+    var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
+      computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+    var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
     var bestM = 1
 
-    // psuedo-residual for second iteration
-    data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
-      point.features))
+    // pseudo-residual for second iteration
+    data = predError.zip(input).map { case ((pred, _), point) =>
+      LabeledPoint(-loss.gradient(pred, point.label), point.features)
+    }
+
     var m = 1
     while (m < numIterations) {
       timer.start(s"building tree $m")
@@ -222,15 +230,22 @@ object GradientBoostedTrees extends Logging {
       baseLearnerWeights(m) = learningRate
       // Note: A model of type regression is used since we require raw prediction
       val partialModel = new GradientBoostedTreesModel(
-        Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
-      logDebug("error of gbt = " + loss.computeError(partialModel, input))
+        Regression, baseLearners.slice(0, m + 1),
+        baseLearnerWeights.slice(0, m + 1))
+
+      predError = GradientBoostedTreesModel.updatePredictionError(
+        input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+      logDebug("error of gbt = " + predError.values.mean())
 
       if (validate) {
         // Stop training early if
         // 1. Reduction in error is less than the validationTol or
         // 2. If the error increases, that is if the model is overfit.
         // We want the model returned corresponding to the best validation error.
-        val currentValidateError = loss.computeError(partialModel, validationInput)
+
+        validatePredError = GradientBoostedTreesModel.updatePredictionError(
+          validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+        val currentValidateError = validatePredError.values.mean()
         if (bestValidateError - currentValidateError < validationTol) {
           return new GradientBoostedTreesModel(
             boostingStrategy.treeStrategy.algo,
@@ -242,8 +257,9 @@ object GradientBoostedTrees extends Logging {
         }
       }
       // Update data with pseudo-residuals
-      data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
-        point.features))
+      data = predError.zip(input).map { case ((pred, _), point) =>
+        LabeledPoint(-loss.gradient(pred, point.label), point.features)
+      }
       m += 1
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2a55cb41/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 793dd66..6f570b4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -37,14 +37,12 @@ object AbsoluteError extends Loss {
    * Method to calculate the gradients for the gradient boosting calculation for least
    * absolute error calculation.
    * The gradient with respect to F(x) is: sign(F(x) - y)
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+  override def gradient(prediction: Double, label: Double): Double = {
+    if (label - prediction < 0) 1.0 else -1.0
   }
 
   override def computeError(prediction: Double, label: Double): Double = {

http://git-wip-us.apache.org/repos/asf/spark/blob/2a55cb41/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 51b1aed..24ee9f3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -39,15 +39,12 @@ object LogLoss extends Loss {
    * Method to calculate the loss gradients for the gradient boosting calculation for binary
    * classification
    * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x)))
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double = {
-    val prediction = model.predict(point.features)
-    - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction))
+  override def gradient(prediction: Double, label: Double): Double = {
+    - 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
   }
 
   override def computeError(prediction: Double, label: Double): Double = {

http://git-wip-us.apache.org/repos/asf/spark/blob/2a55cb41/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index 357869f..d3b82b7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -31,13 +31,11 @@ trait Loss extends Serializable {
 
   /**
    * Method to calculate the gradients for the gradient boosting calculation.
-   * @param model Model of the weak learner.
-   * @param point Instance of the training dataset.
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient.
    */
-  def gradient(
-      model: TreeEnsembleModel,
-      point: LabeledPoint): Double
+  def gradient(prediction: Double, label: Double): Double
 
   /**
    * Method to calculate error of the base learner for the gradient boosting calculation.

http://git-wip-us.apache.org/repos/asf/spark/blob/2a55cb41/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index b990707..58857ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -37,14 +37,12 @@ object SquaredError extends Loss {
    * Method to calculate the gradients for the gradient boosting calculation for least
    * squares error calculation.
    * The gradient with respect to F(x) is: - 2 (y - F(x))
-   * @param model Ensemble model
-   * @param point Instance of the training dataset
+   * @param prediction Predicted label.
+   * @param label True label.
    * @return Loss gradient
    */
-  override def gradient(
-    model: TreeEnsembleModel,
-    point: LabeledPoint): Double = {
-    2.0 * (model.predict(point.features) - point.label)
+  override def gradient(prediction: Double, label: Double): Double = {
+    2.0 * (prediction - label)
   }
 
   override def computeError(prediction: Double, label: Double): Double = {

http://git-wip-us.apache.org/repos/asf/spark/blob/2a55cb41/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 1950254..fef3d2a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -130,35 +130,28 @@ class GradientBoostedTreesModel(
 
     val numIterations = trees.length
     val evaluationArray = Array.fill(numIterations)(0.0)
+    val localTreeWeights = treeWeights
+
+    var predictionAndError = GradientBoostedTreesModel.computeInitialPredictionAndError(
+      remappedData, localTreeWeights(0), trees(0), loss)
 
-    var predictionAndError: RDD[(Double, Double)] = remappedData.map { i =>
-      val pred = treeWeights(0) * trees(0).predict(i.features)
-      val error = loss.computeError(pred, i.label)
-      (pred, error)
-    }
     evaluationArray(0) = predictionAndError.values.mean()
 
-    // Avoid the model being copied across numIterations.
     val broadcastTrees = sc.broadcast(trees)
-    val broadcastWeights = sc.broadcast(treeWeights)
-
     (1 until numIterations).map { nTree =>
       predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
         val currentTree = broadcastTrees.value(nTree)
-        val currentTreeWeight = broadcastWeights.value(nTree)
-        iter.map {
-          case (point, (pred, error)) => {
-            val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
-            val newError = loss.computeError(newPred, point.label)
-            (newPred, newError)
-          }
+        val currentTreeWeight = localTreeWeights(nTree)
+        iter.map { case (point, (pred, error)) =>
+          val newPred = pred + currentTree.predict(point.features) * currentTreeWeight
+          val newError = loss.computeError(newPred, point.label)
+          (newPred, newError)
         }
       }
       evaluationArray(nTree) = predictionAndError.values.mean()
     }
 
     broadcastTrees.unpersist()
-    broadcastWeights.unpersist()
     evaluationArray
   }
 
@@ -166,6 +159,58 @@ class GradientBoostedTreesModel(
 
 object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
 
+  /**
+   * Compute the initial predictions and errors for a dataset for the first
+   * iteration of gradient boosting.
+   * @param data training data.
+   * @param initTreeWeight learning rate assigned to the first tree.
+   * @param initTree first DecisionTreeModel.
+   * @param loss evaluation metric.
+   * @return an RDD with each element being a (prediction, error) pair
+   *         corresponding to every sample.
+   */
+  def computeInitialPredictionAndError(
+      data: RDD[LabeledPoint],
+      initTreeWeight: Double,
+      initTree: DecisionTreeModel,
+      loss: Loss): RDD[(Double, Double)] = {
+    data.map { lp =>
+      val pred = initTreeWeight * initTree.predict(lp.features)
+      val error = loss.computeError(pred, lp.label)
+      (pred, error)
+    }
+  }
+
+  /**
+   * Update a zipped predictionError RDD
+   * (as obtained with computeInitialPredictionAndError)
+   * @param data training data.
+   * @param predictionAndError predictionError RDD
+   * @param treeWeight Learning rate.
+   * @param tree Tree using which the prediction and error should be updated.
+   * @param loss evaluation metric.
+   * @return an RDD with each element being a (prediction, error) pair
+   *         corresponding to each sample.
+   */
+  def updatePredictionError(
+    data: RDD[LabeledPoint],
+    predictionAndError: RDD[(Double, Double)],
+    treeWeight: Double,
+    tree: DecisionTreeModel,
+    loss: Loss): RDD[(Double, Double)] = {
+
+    val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
+      iter.map {
+        case (lp, (pred, error)) => {
+          val newPred = pred + tree.predict(lp.features) * treeWeight
+          val newError = loss.computeError(newPred, lp.label)
+          (newPred, newError)
+        }
+      }
+    }
+    newPredError
+  }
+
   override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
     val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
     val classNameV1_0 = SaveLoadV1_0.thisClassName


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org