You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2015/01/27 04:46:19 UTC

spark git commit: [SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomForests

Repository: spark
Updated Branches:
  refs/heads/master f2ba5c6fc -> d6894b1c5


[SPARK-3726] [MLlib] Allow sampling_rate not equal to 1.0 in RandomForests

I've added support for a sampling_rate not equal to 1.0. I have two major questions.

1. A Scala style test is failing, since the number of parameters now exceeds 10.
2. I would like suggestions on how to test this.

Author: MechCoder <ma...@gmail.com>

Closes #4073 from MechCoder/spark-3726 and squashes the following commits:

8012fb2 [MechCoder] Add test in Strategy
e0e0d9c [MechCoder] TST: Add better test
d1df1b2 [MechCoder] Add test to verify subsampling behavior
a7bfc70 [MechCoder] [SPARK-3726] Allow sampling_rate not equal to 1.0


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d6894b1c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d6894b1c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d6894b1c

Branch: refs/heads/master
Commit: d6894b1c5314c751cfdaf78005b99b2104e6e4d1
Parents: f2ba5c6
Author: MechCoder <ma...@gmail.com>
Authored: Mon Jan 26 19:46:17 2015 -0800
Committer: Xiangrui Meng <me...@databricks.com>
Committed: Mon Jan 26 19:46:17 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/mllib/tree/RandomForest.scala  | 16 +++++-----------
 .../spark/mllib/tree/configuration/Strategy.scala   |  3 +++
 .../apache/spark/mllib/tree/RandomForestSuite.scala | 16 ++++++++++++++++
 3 files changed, 24 insertions(+), 11 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d6894b1c/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index e9304b5..482dd4b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -140,6 +140,7 @@ private class RandomForest (
     logDebug("maxBins = " + metadata.maxBins)
     logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
     logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+    logDebug("subsamplingRate = " + strategy.subsamplingRate)
 
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
@@ -155,19 +156,12 @@ private class RandomForest (
     // Cache input RDD for speedup during multiple passes.
     val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
 
-    val (subsample, withReplacement) = {
-      // TODO: Have a stricter check for RF in the strategy
-      val isRandomForest = numTrees > 1
-      if (isRandomForest) {
-        (1.0, true)
-      } else {
-        (strategy.subsamplingRate, false)
-      }
-    }
+    val withReplacement = if (numTrees > 1) true else false
 
     val baggedInput
-      = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
-        .persist(StorageLevel.MEMORY_AND_DISK)
+      = BaggedPoint.convertToBaggedRDD(treeInput,
+          strategy.subsamplingRate, numTrees,
+          withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
 
     // depth of the decision tree
     val maxDepth = strategy.maxDepth

http://git-wip-us.apache.org/repos/asf/spark/blob/d6894b1c/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 9729598..3308adb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -156,6 +156,9 @@ class Strategy (
       s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
     require(maxMemoryInMB <= 10240,
       s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
+    require(subsamplingRate > 0 && subsamplingRate <= 1,
+      s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
+      s"$subsamplingRate")
   }
 
   /** Returns a shallow copy of this instance. */

http://git-wip-us.apache.org/repos/asf/spark/blob/d6894b1c/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index f7f0f20..55e9639 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
       featureSubsetStrategy = "sqrt", seed = 12345)
     EnsembleTestHelper.validateClassifier(model, arr, 1.0)
   }
+
+  test("subsampling rate in RandomForest"){
+    val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20)
+    val rdd = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int],
+      useNodeIdCache = true)
+
+    val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+      featureSubsetStrategy = "auto", seed = 123)
+    strategy.subsamplingRate = 0.5
+    val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+      featureSubsetStrategy = "auto", seed = 123)
+    assert(rf1.toDebugString != rf2.toDebugString)
+  }
+
 }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org