You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2015/07/30 03:18:43 UTC

spark git commit: [SPARK-9016] [ML] make random forest classifiers implement classification trait

Repository: spark
Updated Branches:
  refs/heads/master 103d8cce7 -> 37c2d1927


[SPARK-9016] [ML] make random forest classifiers implement classification trait

Implement the classification trait for RandomForestClassifiers. The plan is to use this in the future to provide thresholding for RandomForestClassifiers (as well as other classifiers that implement that trait).

Author: Holden Karau <ho...@pigscanfly.ca>

Closes #7432 from holdenk/SPARK-9016-make-random-forest-classifiers-implement-classification-trait and squashes the following commits:

bf22fa6 [Holden Karau] Add missing imports for testing suite
e948f0d [Holden Karau] Check the prediction generation from rawPrediction
25320c3 [Holden Karau] Don't supply numClasses when not needed, assert model classes are as expected
1a67e04 [Holden Karau] Use old decision tree stuff instead
673e0c3 [Holden Karau] Merge branch 'master' into SPARK-9016-make-random-forest-classifiers-implement-classification-trait
0d15b96 [Holden Karau] FIx typo
5eafad4 [Holden Karau] add a constructor for rootnode + num classes
fc6156f [Holden Karau] scala style fix
2597915 [Holden Karau] take num classes in constructor
3ccfe4a [Holden Karau] Merge in master, make pass numClasses through randomforest for training
222a10b [Holden Karau] Increase numtrees to 3 in the python test since before the two were equal and the argmax was selecting the last one
16aea1c [Holden Karau] Make tests match the new models
b454a02 [Holden Karau] Make the Tree classifiers extends the Classifier base class
77b4114 [Holden Karau] Import vectors lib


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/37c2d192
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/37c2d192
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/37c2d192

Branch: refs/heads/master
Commit: 37c2d1927cebdd19a14c054f670cb0fb9a263586
Parents: 103d8cc
Author: Holden Karau <ho...@pigscanfly.ca>
Authored: Wed Jul 29 18:18:29 2015 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Wed Jul 29 18:18:29 2015 -0700

----------------------------------------------------------------------
 .../classification/RandomForestClassifier.scala | 30 +++++++++++---------
 .../RandomForestClassifierSuite.scala           | 18 +++++++++---
 python/pyspark/ml/classification.py             |  4 +--
 3 files changed, 32 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/37c2d192/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index fc0693f..bc19bd6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -25,7 +25,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
 import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
@@ -43,7 +43,7 @@ import org.apache.spark.sql.types.DoubleType
  */
 @Experimental
 final class RandomForestClassifier(override val uid: String)
-  extends Predictor[Vector, RandomForestClassifier, RandomForestClassificationModel]
+  extends Classifier[Vector, RandomForestClassifier, RandomForestClassificationModel]
   with RandomForestParams with TreeClassifierParams {
 
   def this() = this(Identifiable.randomUID("rfc"))
@@ -98,7 +98,7 @@ final class RandomForestClassifier(override val uid: String)
     val trees =
       RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
         .map(_.asInstanceOf[DecisionTreeClassificationModel])
-    new RandomForestClassificationModel(trees)
+    new RandomForestClassificationModel(trees, numClasses)
   }
 
   override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
@@ -125,8 +125,9 @@ object RandomForestClassifier {
 @Experimental
 final class RandomForestClassificationModel private[ml] (
     override val uid: String,
-    private val _trees: Array[DecisionTreeClassificationModel])
-  extends PredictionModel[Vector, RandomForestClassificationModel]
+    private val _trees: Array[DecisionTreeClassificationModel],
+    override val numClasses: Int)
+  extends ClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
@@ -135,8 +136,8 @@ final class RandomForestClassificationModel private[ml] (
    * Construct a random forest classification model, with all trees weighted equally.
    * @param trees  Component trees
    */
-  def this(trees: Array[DecisionTreeClassificationModel]) =
-    this(Identifiable.randomUID("rfc"), trees)
+  def this(trees: Array[DecisionTreeClassificationModel], numClasses: Int) =
+    this(Identifiable.randomUID("rfc"), trees, numClasses)
 
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
@@ -153,20 +154,20 @@ final class RandomForestClassificationModel private[ml] (
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 
-  override protected def predict(features: Vector): Double = {
+  override protected def predictRaw(features: Vector): Vector = {
     // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128
     // Classifies using majority votes.
     // Ignore the weights since all are 1.0 for now.
-    val votes = mutable.Map.empty[Int, Double]
+    val votes = new Array[Double](numClasses)
     _trees.view.foreach { tree =>
       val prediction = tree.rootNode.predict(features).toInt
-      votes(prediction) = votes.getOrElse(prediction, 0.0) + 1.0 // 1.0 = weight
+      votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
     }
-    votes.maxBy(_._2)._1
+    Vectors.dense(votes)
   }
 
   override def copy(extra: ParamMap): RandomForestClassificationModel = {
-    copyValues(new RandomForestClassificationModel(uid, _trees), extra)
+    copyValues(new RandomForestClassificationModel(uid, _trees, numClasses), extra)
   }
 
   override def toString: String = {
@@ -185,7 +186,8 @@ private[ml] object RandomForestClassificationModel {
   def fromOld(
       oldModel: OldRandomForestModel,
       parent: RandomForestClassifier,
-      categoricalFeatures: Map[Int, Int]): RandomForestClassificationModel = {
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int): RandomForestClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
       s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -193,6 +195,6 @@ private[ml] object RandomForestClassificationModel {
       DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
-    new RandomForestClassificationModel(uid, newTrees)
+    new RandomForestClassificationModel(uid, newTrees, numClasses)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/37c2d192/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 1b6b69c..ab711c8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -21,13 +21,13 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.impl.TreeTests
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * Test suite for [[RandomForestClassifier]].
@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))))
+      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -167,9 +167,19 @@ private object RandomForestClassifierSuite {
     val newModel = rf.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = RandomForestClassificationModel.fromOld(
-      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
+      oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures,
+      numClasses)
     TreeTests.checkEqual(oldModelAsNew, newModel)
     assert(newModel.hasParent)
     assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
+    assert(newModel.numClasses == numClasses)
+    val results = newModel.transform(newData)
+    results.select("rawPrediction", "prediction").collect().foreach {
+      case Row(raw: Vector, prediction: Double) => {
+        assert(raw.size == numClasses)
+        val predFromRaw = raw.toArray.zipWithIndex.maxBy(_._1)._2
+        assert(predFromRaw == prediction)
+      }
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/37c2d192/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 89117e4..5a82bc2 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
     >>> si_model = stringIndexer.fit(df)
     >>> td = si_model.transform(df)
-    >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
+    >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
     >>> model = rf.fit(td)
-    >>> allclose(model.treeWeights, [1.0, 1.0])
+    >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> model.transform(test0).head().prediction


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org