You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by do...@apache.org on 2020/01/24 10:02:56 UTC

[spark] branch branch-2.4 updated: [SPARK-30630][ML][2.4] Deprecate numTrees in GBT in 2.4.5

This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.4 by this push:
     new d7be535  [SPARK-30630][ML][2.4] Deprecate numTrees in GBT in 2.4.5
d7be535 is described below

commit d7be535e02b84e8ef00bab8036b0014d5a389fbb
Author: Huaxin Gao <hu...@us.ibm.com>
AuthorDate: Fri Jan 24 02:01:49 2020 -0800

    [SPARK-30630][ML][2.4] Deprecate numTrees in GBT in 2.4.5
    
    ### What changes were proposed in this pull request?
    Deprecate numTrees in GBT in 2.4.5 so it can be removed in 3.0.0
    
    ### Why are the changes needed?
    Currently, GBT has
    ```
      /**
       * Number of trees in ensemble
       */
      @Since("2.0.0")
      val getNumTrees: Int = trees.length
    ```
    and
    ```
      /** Number of trees in ensemble */
      val numTrees: Int = trees.length
    ```
    These two members are redundant, so one of them should be removed. This change deprecates numTrees in 2.4.5 so that it can be removed in 3.0.0.
    
    ### Does this PR introduce any user-facing change?
    Yes. numTrees on the GBT classification and regression models is deprecated as of 2.4.5; use getNumTrees instead.
    
    ### How was this patch tested?
    Existing tests
    
    Closes #27352 from huaxingao/spark-tree.
    
    Authored-by: Huaxin Gao <hu...@us.ibm.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../org/apache/spark/ml/classification/GBTClassifier.scala  | 11 ++++++++---
 .../scala/org/apache/spark/ml/regression/GBTRegressor.scala | 11 ++++++++---
 .../apache/spark/ml/classification/GBTClassifierSuite.scala | 13 +++++++------
 .../org/apache/spark/ml/regression/GBTRegressorSuite.scala  | 10 +++++-----
 4 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 62cfa39..c5cb03e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -319,7 +319,12 @@ class GBTClassificationModel private[ml](
     }
   }
 
-  /** Number of trees in ensemble */
+  /**
+   * Number of trees in ensemble
+   *
+   * @deprecated  Use [[getNumTrees]] instead. This method will be removed in 3.0.0.
+   */
+  @deprecated("Use getNumTrees instead. This method will be removed in 3.0.0.", "2.4.5")
   val numTrees: Int = trees.length
 
   @Since("1.4.0")
@@ -330,7 +335,7 @@ class GBTClassificationModel private[ml](
 
   @Since("1.4.0")
   override def toString: String = {
-    s"GBTClassificationModel (uid=$uid) with $numTrees trees"
+    s"GBTClassificationModel (uid=$uid) with $getNumTrees trees"
   }
 
   /**
@@ -349,7 +354,7 @@ class GBTClassificationModel private[ml](
   /** Raw prediction for the positive class. */
   private def margin(features: Vector): Double = {
     val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+    blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
   }
 
   /** (private[ml]) Convert to a model in the old API */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 07f88d8..a56b5c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -255,10 +255,15 @@ class GBTRegressionModel private[ml](
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
-    blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
+    blas.ddot(getNumTrees, treePredictions, 1, _treeWeights, 1)
   }
 
-  /** Number of trees in ensemble */
+  /**
+   * Number of trees in ensemble
+   *
+   * @deprecated  Use [[getNumTrees]] instead. This method will be removed in 3.0.0.
+   */
+  @deprecated("Use getNumTrees instead. This method will be removed in 3.0.0.", "2.4.5")
   val numTrees: Int = trees.length
 
   @Since("1.4.0")
@@ -269,7 +274,7 @@ class GBTRegressionModel private[ml](
 
   @Since("1.4.0")
   override def toString: String = {
-    s"GBTRegressionModel (uid=$uid) with $numTrees trees"
+    s"GBTRegressionModel (uid=$uid) with $getNumTrees trees"
   }
 
   /**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 3049776..02b6b5d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -178,7 +178,8 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
         assert(raw.size === 2)
         // check that raw prediction is tree predictions dot tree weights
         val treePredictions = gbtModel.trees.map(_.rootNode.predictImpl(features).prediction)
-        val prediction = blas.ddot(gbtModel.numTrees, treePredictions, 1, gbtModel.treeWeights, 1)
+        val prediction = blas.ddot(gbtModel.getNumTrees, treePredictions, 1,
+          gbtModel.treeWeights, 1)
         assert(raw ~== Vectors.dense(-prediction, prediction) relTol eps)
 
         // Compare rawPrediction with probability
@@ -410,9 +411,9 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
       gbt.setValidationIndicatorCol(validationIndicatorCol)
       val modelWithValidation = gbt.fit(trainDF.union(validationDF))
 
-      assert(modelWithoutValidation.numTrees === numIter)
+      assert(modelWithoutValidation.getNumTrees === numIter)
       // early stop
-      assert(modelWithValidation.numTrees < numIter)
+      assert(modelWithValidation.getNumTrees < numIter)
 
       val (errorWithoutValidation, errorWithValidation) = {
         val remappedRdd = validationData.map(x => new LabeledPoint(2 * x.label - 1, x.features))
@@ -428,10 +429,10 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
           modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
           OldAlgo.Classification)
       assert(evaluationArray.length === numIter)
-      assert(evaluationArray(modelWithValidation.numTrees) >
-        evaluationArray(modelWithValidation.numTrees - 1))
+      assert(evaluationArray(modelWithValidation.getNumTrees) >
+        evaluationArray(modelWithValidation.getNumTrees - 1))
       var i = 1
-      while (i < modelWithValidation.numTrees) {
+      while (i < modelWithValidation.getNumTrees) {
         assert(evaluationArray(i) <= evaluationArray(i - 1))
         i += 1
       }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index b145c7a..9342bc0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -249,9 +249,9 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
       gbt.setValidationIndicatorCol(validationIndicatorCol)
       val modelWithValidation = gbt.fit(trainDF.union(validationDF))
 
-      assert(modelWithoutValidation.numTrees === numIter)
+      assert(modelWithoutValidation.getNumTrees === numIter)
       // early stop
-      assert(modelWithValidation.numTrees < numIter)
+      assert(modelWithValidation.getNumTrees < numIter)
 
       val errorWithoutValidation = GradientBoostedTrees.computeError(validationData,
         modelWithoutValidation.trees, modelWithoutValidation.treeWeights,
@@ -267,10 +267,10 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
           modelWithoutValidation.treeWeights, modelWithoutValidation.getOldLossType,
           OldAlgo.Regression)
       assert(evaluationArray.length === numIter)
-      assert(evaluationArray(modelWithValidation.numTrees) >
-        evaluationArray(modelWithValidation.numTrees - 1))
+      assert(evaluationArray(modelWithValidation.getNumTrees) >
+        evaluationArray(modelWithValidation.getNumTrees - 1))
       var i = 1
-      while (i < modelWithValidation.numTrees) {
+      while (i < modelWithValidation.getNumTrees) {
         assert(evaluationArray(i) <= evaluationArray(i - 1))
         i += 1
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org