You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/05/09 18:09:24 UTC

spark git commit: [SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs

Repository: spark
Updated Branches:
  refs/heads/master 628c7b517 -> 7aaa148f5


[SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs

## What changes were proposed in this pull request?

Provide evaluateEachIteration method or equivalent for spark.ml GBTs.

## How was this patch tested?

UT (unit tests added for `evaluateEachIteration` in both GBTClassifierSuite and GBTRegressorSuite).

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: WeichenXu <we...@databricks.com>

Closes #21097 from WeichenXu123/GBTeval.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7aaa148f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7aaa148f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7aaa148f

Branch: refs/heads/master
Commit: 7aaa148f593470b2c32221b69097b8b54524eb74
Parents: 628c7b5
Author: WeichenXu <we...@databricks.com>
Authored: Wed May 9 11:09:19 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed May 9 11:09:19 2018 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/GBTClassifier.scala | 15 +++++++++
 .../spark/ml/regression/GBTRegressor.scala      | 17 ++++++++++-
 .../org/apache/spark/ml/tree/treeParams.scala   |  6 +++-
 .../ml/classification/GBTClassifierSuite.scala  | 29 +++++++++++++++++-
 .../spark/ml/regression/GBTRegressorSuite.scala | 32 ++++++++++++++++++--
 5 files changed, 94 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 0aa24f0..3fb6d1e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -334,6 +334,21 @@ class GBTClassificationModel private[ml](
   // hard coded loss, which is not meant to be changed in the model
   private val loss = getOldLossType
 
+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   *
+   * @param dataset Dataset for validation.
+   */
+  @Since("2.4.0")
+  def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
+    val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+    }
+    GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
+      OldAlgo.Classification
+    )
+  }
+
   @Since("2.0.0")
   override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 8598e80..d7e054b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 
 /**
@@ -269,6 +269,21 @@ class GBTRegressionModel private[ml](
     new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
   }
 
+  /**
+   * Method to compute error or loss for every iteration of gradient boosting.
+   *
+   * @param dataset Dataset for validation.
+   * @param loss The loss function used to compute error. Supported options: squared, absolute
+   */
+  @Since("2.4.0")
+  def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
+    val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+    }
+    GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
+      convertToOldLossType(loss), OldAlgo.Regression)
+  }
+
   @Since("2.0.0")
   override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 81b6222..ec8868b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -579,7 +579,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams
 
   /** (private[ml]) Convert new loss to old loss. */
   override private[ml] def getOldLossType: OldLoss = {
-    getLossType match {
+    convertToOldLossType(getLossType)
+  }
+
+  private[ml] def convertToOldLossType(loss: String): OldLoss = {
+    loss match {
       case "squared" => OldSquaredError
       case "absolute" => OldAbsoluteError
       case _ =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f0ee549..e20de19 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree.RegressionLeafNode
-import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -365,6 +365,33 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
     assert(mostImportantFeature !== mostIF)
   }
 
+  test("model evaluateEachIteration") {
+    val gbt = new GBTClassifier()
+      .setSeed(1L)
+      .setMaxDepth(2)
+      .setMaxIter(3)
+      .setLossType("logistic")
+    val model3 = gbt.fit(trainData.toDF)
+    val model1 = new GBTClassificationModel("gbt-cls-model-test1",
+      model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses)
+    val model2 = new GBTClassificationModel("gbt-cls-model-test2",
+      model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)
+
+    val evalArr = model3.evaluateEachIteration(validationData.toDF)
+    val remappedValidationData = validationData.map(
+      x => new LabeledPoint((x.label * 2) - 1, x.features))
+    val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData,
+      model1.trees, model1.treeWeights, model1.getOldLossType)
+    val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData,
+      model2.trees, model2.treeWeights, model2.getOldLossType)
+    val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData,
+      model3.trees, model3.treeWeights, model3.getOldLossType)
+
+    assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+    assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+    assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index fad11d0..773f6d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.ml.regression
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -201,7 +202,34 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
     assert(mostImportantFeature !== mostIF)
   }
 
-
+  test("model evaluateEachIteration") {
+    for (lossType <- GBTRegressor.supportedLossTypes) {
+      val gbt = new GBTRegressor()
+        .setSeed(1L)
+        .setMaxDepth(2)
+        .setMaxIter(3)
+        .setLossType(lossType)
+      val model3 = gbt.fit(trainData.toDF)
+      val model1 = new GBTRegressionModel("gbt-reg-model-test1",
+        model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures)
+      val model2 = new GBTRegressionModel("gbt-reg-model-test2",
+        model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures)
+
+      for (evalLossType <- GBTRegressor.supportedLossTypes) {
+        val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType)
+        val lossErr1 = GradientBoostedTrees.computeError(validationData,
+          model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType))
+        val lossErr2 = GradientBoostedTrees.computeError(validationData,
+          model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType))
+        val lossErr3 = GradientBoostedTrees.computeError(validationData,
+          model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType))
+
+        assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+        assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+        assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+      }
+    }
+  }
 
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org