You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/01/24 10:02:56 UTC
[spark] branch branch-2.4 updated: [SPARK-30630][ML][2.4] Deprecate
numTrees in GBT in 2.4.5
This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new d7be535 [SPARK-30630][ML][2.4] Deprecate numTrees in GBT in 2.4.5
d7be535 is described below
commit d7be535e02b84e8ef00bab8036b0014d5a389fbb
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Fri Jan 24 02:01:49 2020 -0800
[SPARK-30630][ML][2.4] Deprecate numTrees in GBT in 2.4.5
### What changes were proposed in this pull request?
Deprecate numTrees in GBT in 2.4.5 so it can be removed in 3.0.0
### Why are the changes needed?
Currently, the GBT models expose two members that both return `trees.length`:
```
/**
* Number of trees in ensemble
*/
@Since("2.0.0")
val getNumTrees: Int = trees.length
```
and
```
/** Number of trees in ensemble */
val numTrees: Int = trees.length
```
These are redundant, so one of them should be removed. This change deprecates `numTrees` in 2.4.5 so that it can be removed in 3.0.0, leaving `getNumTrees` as the supported accessor.
### Does this PR introduce any user-facing change?
Yes. `numTrees` is deprecated as of 2.4.5; users should call `getNumTrees` instead.
### How was this patch tested?
Existing tests
Closes #27352 from huaxingao/spark-tree.
Authored-by: Huaxin Gao <hu...@us.ibm.com>
Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
.../org/apache/spark/ml/classification/GBTClassifier.scala | 11 ++++++++---
.../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 11 ++++++++---
.../apache/spark/ml/classification/GBTClassifierSuite.scala | 13 +++++++------
.../org/apache/spark/ml/regression/GBTRegressorSuite.scala | 10 +++++-----
4 files changed, 28 insertions(+), 17 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 62cfa39..c5cb03e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -319,7 +319,12 @@ class GBTClassificationModel private[ml](
}
}
- /** Number of trees in ensemble */
+ /**
+ * Number of trees in ensemble
+ *
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 3.0.0.
+ */
+ @deprecated("Use getNumTrees instead. This method will be removed in 3.0.0.", "2.4.5")
val numTrees: Int = trees.length
@Since("1.4.0")
@@ -330,7 +335,7 @@ class GBTClassificationModel private[ml](
@Since("1.4.0")
override def toString: String = {
- s"GBTClassificationModel (uid=$uid) with $numTrees trees"
+ s"GBTClassificationModel (uid=$uid) with $getNumTrees trees"
}
/**
@@ -349,7 +354,7 @@ class GBTClassificationModel private[ml](
/** Raw prediction for the positive class. */
private def margin(features: Vector): Double = {
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
- blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+ blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
}
/** (private[ml]) Convert to a model in the old API */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 07f88d8..a56b5c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -255,10 +255,15 @@ class GBTRegressionModel private[ml](
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
- blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+ blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
}
- /** Number of trees in ensemble */
+ /**
+ * Number of trees in ensemble
+ *
+ * @deprecated Use [[getNumTrees]] instead. This method will be removed in 3.0.0.
+ */
+ @deprecated("Use getNumTrees instead. This method will be removed in 3.0.0.", "2.4.5")
val numTrees: Int = trees.length
@Since("1.4.0")
@@ -269,7 +274,7 @@ class GBTRegressionModel private[ml](
@Since("1.4.0")
override def toString: String = {
- s"GBTRegressionModel (uid=$uid) with $numTrees trees"
+ s"GBTRegressionModel (uid=$uid) with $getNumTrees trees"
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 3049776..02b6b5d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -178,7 +178,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
assert(raw.size === 2)
// check that raw prediction is tree predictions dot tree weights
val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
- val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
+ val prediction = blas.ddot(gbtModel.getNumTrees, treePredictions, 1,
+ gbtModel.treeWeights, 1)
assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
// Compare rawPrediction with probability
@@ -410,9 +411,9 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
gbt.setValidationIndicatorCol(validationIndicatorCol)
val modelWithValidation = gbt.fit(trainDF.union(validationDF))
- assert(modelWithoutValidation.numTrees === numIter)
+ assert(modelWithoutValidation.getNumTrees === numIter)
// early stop
- assert(modelWithValidation.numTrees < numIter)
+ assert(modelWithValidation.getNumTrees < numIter)
val (errorWithoutValidation, errorWithValidation) = {
val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -428,10 +429,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
OldAlgo.Classification)
assert(evaluationArray.length === numIter)
- assert(evaluationArray(modelWithValidation.numTrees) >
- evaluationArray(modelWithValidation.numTrees - 1))
+ assert(evaluationArray(modelWithValidation.getNumTrees) >
+ evaluationArray(modelWithValidation.getNumTrees - 1))
var i = 1
- while (i < modelWithValidation.numTrees) {
+ while (i < modelWithValidation.getNumTrees) {
assert(evaluationArray(i) <= evaluationArray(i - 1))
i += 1
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index b145c7a..9342bc0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -249,9 +249,9 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
gbt.setValidationIndicatorCol(validationIndicatorCol)
val modelWithValidation = gbt.fit(trainDF.union(validationDF))
- assert(modelWithoutValidation.numTrees === numIter)
+ assert(modelWithoutValidation.getNumTrees === numIter)
// early stop
- assert(modelWithValidation.numTrees < numIter)
+ assert(modelWithValidation.getNumTrees < numIter)
val errorWithoutValidation = GradientBoostedTrees.computeError(validationData,
modelWithoutValidation.trees, modelWithoutValidation.treeWeights,
@@ -267,10 +267,10 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
OldAlgo.Regression)
assert(evaluationArray.length === numIter)
- assert(evaluationArray(modelWithValidation.numTrees) >
- evaluationArray(modelWithValidation.numTrees - 1))
+ assert(evaluationArray(modelWithValidation.getNumTrees) >
+ evaluationArray(modelWithValidation.getNumTrees - 1))
var i = 1
- while (i < modelWithValidation.numTrees) {
+ while (i < modelWithValidation.getNumTrees) {
assert(evaluationArray(i) <= evaluationArray(i - 1))
i += 1
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org