You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2016/04/05 05:12:15 UTC

spark git commit: [SPARK-14386][ML] Changed spark.ml ensemble trees methods to return concrete types

Repository: spark
Updated Branches:
  refs/heads/master ba24d1ee9 -> 8f50574ab


[SPARK-14386][ML] Changed spark.ml ensemble trees methods to return concrete types

## What changes were proposed in this pull request?

In spark.ml, the GBT and RandomForest models expose the trait DecisionTreeModel in the return type of their `trees` methods, but they should not, since it is a private trait (and not ready to be made public). It will also be more useful to users if we return the concrete types.

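This PR makes TreeEnsembleModel generic in the type of its tree models, so that each ensemble's `trees` method returns an array of the concrete type (DecisionTreeClassificationModel or DecisionTreeRegressionModel) instead of Array[DecisionTreeModel].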

The MIMA checks appear to be OK with this change.

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <jo...@databricks.com>

Closes #12158 from jkbradley/hide-dtm.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f50574a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f50574a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f50574a

Branch: refs/heads/master
Commit: 8f50574ab4021b9984b0017cd47ba012a894c19a
Parents: ba24d1e
Author: Joseph K. Bradley <jo...@databricks.com>
Authored: Mon Apr 4 20:12:09 2016 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Apr 4 20:12:09 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/classification/GBTClassifier.scala       |  7 +++----
 .../ml/classification/RandomForestClassifier.scala    |  6 +++---
 .../org/apache/spark/ml/regression/GBTRegressor.scala |  7 +++----
 .../spark/ml/regression/RandomForestRegressor.scala   |  5 +++--
 .../scala/org/apache/spark/ml/tree/treeModels.scala   | 14 +++++++++-----
 .../org/apache/spark/ml/tree/impl/TreeTests.scala     |  2 +-
 6 files changed, 22 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8f50574a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index bfefaf1..bee90fb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -24,8 +24,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams,
-  TreeEnsembleModel}
+import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
@@ -190,7 +189,7 @@ final class GBTClassificationModel private[ml](
     private val _treeWeights: Array[Double],
     @Since("1.6.0") override val numFeatures: Int)
   extends PredictionModel[Vector, GBTClassificationModel]
-  with TreeEnsembleModel with Serializable {
+  with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
 
   require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.")
   require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
@@ -206,7 +205,7 @@ final class GBTClassificationModel private[ml](
     this(uid, _trees, _treeWeights, -1)
 
   @Since("1.4.0")
-  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+  override def trees: Array[DecisionTreeRegressionModel] = _trees
 
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

http://git-wip-us.apache.org/repos/asf/spark/blob/8f50574a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 2ad893f..cb42532 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -155,8 +155,8 @@ final class RandomForestClassificationModel private[ml] (
     @Since("1.6.0") override val numFeatures: Int,
     @Since("1.5.0") override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
-  with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable
-  with Serializable {
+  with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel]
+  with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.")
 
@@ -172,7 +172,7 @@ final class RandomForestClassificationModel private[ml] (
     this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses)
 
   @Since("1.4.0")
-  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+  override def trees: Array[DecisionTreeClassificationModel] = _trees
 
   // Note: We may add support for weights (based on tree performance) later on.
   private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)

http://git-wip-us.apache.org/repos/asf/spark/blob/8f50574a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 02e124a..cef7c64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -23,8 +23,7 @@ import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel,
-  TreeRegressorParams}
+import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams}
 import org.apache.spark.ml.tree.impl.GradientBoostedTrees
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
 import org.apache.spark.mllib.linalg.Vector
@@ -177,7 +176,7 @@ final class GBTRegressionModel private[ml](
     private val _treeWeights: Array[Double],
     override val numFeatures: Int)
   extends PredictionModel[Vector, GBTRegressionModel]
-  with TreeEnsembleModel with Serializable {
+  with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable {
 
   require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.")
   require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
@@ -193,7 +192,7 @@ final class GBTRegressionModel private[ml](
     this(uid, _trees, _treeWeights, -1)
 
   @Since("1.4.0")
-  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+  override def trees: Array[DecisionTreeRegressionModel] = _trees
 
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights

http://git-wip-us.apache.org/repos/asf/spark/blob/8f50574a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ba56b5c..736cd9f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -142,7 +142,8 @@ final class RandomForestRegressionModel private[ml] (
     private val _trees: Array[DecisionTreeRegressionModel],
     override val numFeatures: Int)
   extends PredictionModel[Vector, RandomForestRegressionModel]
-  with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable {
+  with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel]
+  with MLWritable with Serializable {
 
   require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.")
 
@@ -155,7 +156,7 @@ final class RandomForestRegressionModel private[ml] (
     this(Identifiable.randomUID("rfr"), trees, numFeatures)
 
   @Since("1.4.0")
-  override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
+  override def trees: Array[DecisionTreeRegressionModel] = _trees
 
   // Note: We may add support for weights (based on tree performance) later on.
   private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0)

http://git-wip-us.apache.org/repos/asf/spark/blob/8f50574a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index 48b8fd1..db0ff28 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.tree
 
+import scala.reflect.ClassTag
+
 import org.apache.hadoop.fs.Path
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
@@ -82,14 +84,16 @@ private[spark] trait DecisionTreeModel {
  * Abstraction for models which are ensembles of decision trees
  *
  * TODO: Add support for predicting probabilities and raw predictions  SPARK-3727
+ *
+ * @tparam M  Type of tree model in this ensemble
  */
-private[ml] trait TreeEnsembleModel {
+private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] {
 
   // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of
   //       DecisionTreeModel.
 
   /** Trees in this ensemble. Warning: These have null parent Estimators. */
-  def trees: Array[DecisionTreeModel]
+  def trees: Array[M]
 
   /**
    * Number of trees in ensemble
@@ -148,7 +152,7 @@ private[ml] object TreeEnsembleModel {
    *                     If -1, then numFeatures is set based on the max feature index in all trees.
    * @return  Feature importance values, of length numFeatures.
    */
-  def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+  def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = {
     val totalImportances = new OpenHashMap[Int, Double]()
     trees.foreach { tree =>
       // Aggregate feature importance vector for this tree
@@ -199,7 +203,7 @@ private[ml] object TreeEnsembleModel {
    *                     If -1, then numFeatures is set based on the max feature index in all trees.
    * @return  Feature importance values, of length numFeatures.
    */
-  def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+  def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = {
     featureImportances(Array(tree), numFeatures)
   }
 
@@ -386,7 +390,7 @@ private[ml] object EnsembleModelReadWrite {
    * @param path  Path to which to save the ensemble model.
    * @param extraMetadata  Metadata such as numFeatures, numClasses, numTrees.
    */
-  def saveImpl[M <: Params with TreeEnsembleModel](
+  def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]](
       instance: M,
       path: String,
       sql: SQLContext,

http://git-wip-us.apache.org/repos/asf/spark/blob/8f50574a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index bd5bd17..b650a9f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -131,7 +131,7 @@ private[ml] object TreeTests extends SparkFunSuite {
    * Check if the two models are exactly the same.
    * If the models are not equal, this throws an exception.
    */
-  def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
+  def checkEqual[M <: DecisionTreeModel](a: TreeEnsembleModel[M], b: TreeEnsembleModel[M]): Unit = {
     try {
       a.trees.zip(b.trees).foreach { case (treeA, treeB) =>
         TreeTests.checkEqual(treeA, treeB)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org