You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/05/09 18:09:24 UTC
spark git commit: [SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs
Repository: spark
Updated Branches:
refs/heads/master 628c7b517 -> 7aaa148f5
[SPARK-14682][ML] Provide evaluateEachIteration method or equivalent for spark.ml GBTs
## What changes were proposed in this pull request?
Provide evaluateEachIteration method or equivalent for spark.ml GBTs.
## How was this patch tested?
UT.
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: WeichenXu <we...@databricks.com>
Closes #21097 from WeichenXu123/GBTeval.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7aaa148f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7aaa148f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7aaa148f
Branch: refs/heads/master
Commit: 7aaa148f593470b2c32221b69097b8b54524eb74
Parents: 628c7b5
Author: WeichenXu <we...@databricks.com>
Authored: Wed May 9 11:09:19 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed May 9 11:09:19 2018 -0700
----------------------------------------------------------------------
.../spark/ml/classification/GBTClassifier.scala | 15 +++++++++
.../spark/ml/regression/GBTRegressor.scala | 17 ++++++++++-
.../org/apache/spark/ml/tree/treeParams.scala | 6 +++-
.../ml/classification/GBTClassifierSuite.scala | 29 +++++++++++++++++-
.../spark/ml/regression/GBTRegressorSuite.scala | 32 ++++++++++++++++++--
5 files changed, 94 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 0aa24f0..3fb6d1e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -334,6 +334,21 @@ class GBTClassificationModel private[ml](
// hard coded loss, which is not meant to be changed in the model
private val loss = getOldLossType
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_]): Array[Double] = {
+ val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+ }
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights, loss,
+ OldAlgo.Classification
+ )
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 8598e80..d7e054b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -34,7 +34,7 @@ import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
/**
@@ -269,6 +269,21 @@ class GBTRegressionModel private[ml](
new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
}
+ /**
+ * Method to compute error or loss for every iteration of gradient boosting.
+ *
+ * @param dataset Dataset for validation.
+ * @param loss The loss function used to compute error. Supported options: squared, absolute
+ */
+ @Since("2.4.0")
+ def evaluateEachIteration(dataset: Dataset[_], loss: String): Array[Double] = {
+ val data = dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
+ case Row(label: Double, features: Vector) => LabeledPoint(label, features)
+ }
+ GradientBoostedTrees.evaluateEachIteration(data, trees, treeWeights,
+ convertToOldLossType(loss), OldAlgo.Regression)
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this)
}
http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 81b6222..ec8868b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -579,7 +579,11 @@ private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams
/** (private[ml]) Convert new loss to old loss. */
override private[ml] def getOldLossType: OldLoss = {
- getLossType match {
+ convertToOldLossType(getLossType)
+ }
+
+ private[ml] def convertToOldLossType(loss: String): OldLoss = {
+ loss match {
case "squared" => OldSquaredError
case "absolute" => OldAbsoluteError
case _ =>
http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f0ee549..e20de19 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree.RegressionLeafNode
-import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -365,6 +365,33 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
assert(mostImportantFeature !== mostIF)
}
+ test("model evaluateEachIteration") {
+ val gbt = new GBTClassifier()
+ .setSeed(1L)
+ .setMaxDepth(2)
+ .setMaxIter(3)
+ .setLossType("logistic")
+ val model3 = gbt.fit(trainData.toDF)
+ val model1 = new GBTClassificationModel("gbt-cls-model-test1",
+ model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures, model3.numClasses)
+ val model2 = new GBTClassificationModel("gbt-cls-model-test2",
+ model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures, model3.numClasses)
+
+ val evalArr = model3.evaluateEachIteration(validationData.toDF)
+ val remappedValidationData = validationData.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val lossErr1 = GradientBoostedTrees.computeError(remappedValidationData,
+ model1.trees, model1.treeWeights, model1.getOldLossType)
+ val lossErr2 = GradientBoostedTrees.computeError(remappedValidationData,
+ model2.trees, model2.treeWeights, model2.getOldLossType)
+ val lossErr3 = GradientBoostedTrees.computeError(remappedValidationData,
+ model3.trees, model3.treeWeights, model3.getOldLossType)
+
+ assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+ assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+ assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
http://git-wip-us.apache.org/repos/asf/spark/blob/7aaa148f/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index fad11d0..773f6d2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -20,8 +20,9 @@ package org.apache.spark.ml.regression
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.ml.tree.impl.{GradientBoostedTrees, TreeTests}
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
@@ -201,7 +202,34 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
assert(mostImportantFeature !== mostIF)
}
-
+ test("model evaluateEachIteration") {
+ for (lossType <- GBTRegressor.supportedLossTypes) {
+ val gbt = new GBTRegressor()
+ .setSeed(1L)
+ .setMaxDepth(2)
+ .setMaxIter(3)
+ .setLossType(lossType)
+ val model3 = gbt.fit(trainData.toDF)
+ val model1 = new GBTRegressionModel("gbt-reg-model-test1",
+ model3.trees.take(1), model3.treeWeights.take(1), model3.numFeatures)
+ val model2 = new GBTRegressionModel("gbt-reg-model-test2",
+ model3.trees.take(2), model3.treeWeights.take(2), model3.numFeatures)
+
+ for (evalLossType <- GBTRegressor.supportedLossTypes) {
+ val evalArr = model3.evaluateEachIteration(validationData.toDF, evalLossType)
+ val lossErr1 = GradientBoostedTrees.computeError(validationData,
+ model1.trees, model1.treeWeights, model1.convertToOldLossType(evalLossType))
+ val lossErr2 = GradientBoostedTrees.computeError(validationData,
+ model2.trees, model2.treeWeights, model2.convertToOldLossType(evalLossType))
+ val lossErr3 = GradientBoostedTrees.computeError(validationData,
+ model3.trees, model3.treeWeights, model3.convertToOldLossType(evalLossType))
+
+ assert(evalArr(0) ~== lossErr1 relTol 1E-3)
+ assert(evalArr(1) ~== lossErr2 relTol 1E-3)
+ assert(evalArr(2) ~== lossErr3 relTol 1E-3)
+ }
+ }
+ }
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org