You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2009/12/22 10:08:44 UTC

svn commit: r893116 - /lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java

Author: adeneche
Date: Tue Dec 22 09:08:44 2009
New Revision: 893116

URL: http://svn.apache.org/viewvc?rev=893116&view=rev
Log:
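Fix the Breiman example: size the out-of-bag ForestPredictions collectors by
the training set (train.size()) and the test-set error collectors by the test
set (test.size()) instead of the full data size, build the MeanTreeCollector
on the test set rather than the training set, and remove the now-unused
dataSize local variable.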

Modified:
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=893116&r1=893115&r2=893116&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Tue Dec 22 09:08:44 2009
@@ -81,7 +81,6 @@
    */
   protected static void runIteration(Data data, int m, int nbtrees) {
 
-    int dataSize = data.size();
     int nblabels = data.getDataset().nblabels();
 
     Random rng = RandomUtils.getRandom();
@@ -91,13 +90,13 @@
     
     int[] trainLabels = train.extractLabels();
     int[] testLabels = test.extractLabels();
-    
+
     DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
     
     SequentialBuilder forestBuilder = new SequentialBuilder(rng, treeBuilder, train);
 
     // grow a forest with m = log2(M)+1
-    ForestPredictions errorM = new ForestPredictions(dataSize, nblabels); // oob error when using m = log2(M)+1
+    ForestPredictions errorM = new ForestPredictions(train.size(), nblabels); // oob error when using m = log2(M)+1
     treeBuilder.setM(m);
 
     long time = System.currentTimeMillis();
@@ -108,7 +107,7 @@
     double oobM = ErrorEstimate.errorRate(trainLabels, errorM.computePredictions(rng)); // oob error estimate when m = log2(M)+1
 
     // grow a forest with m=1
-    ForestPredictions errorOne = new ForestPredictions(dataSize, nblabels); // oob error when using m = 1
+    ForestPredictions errorOne = new ForestPredictions(train.size(), nblabels); // oob error when using m = 1
     treeBuilder.setM(1);
 
     time = System.currentTimeMillis();
@@ -120,11 +119,11 @@
 
     // compute the test set error (Selection Error), and mean tree error (One Tree Error),
     // using the lowest oob error forest
-    ForestPredictions testError = new ForestPredictions(dataSize, nblabels); // test set error
-    MeanTreeCollector treeError = new MeanTreeCollector(train, nbtrees); // mean tree error
+    ForestPredictions testError = new ForestPredictions(test.size(), nblabels); // test set error
+    MeanTreeCollector treeError = new MeanTreeCollector(test, nbtrees); // mean tree error
 
     // compute the test set error using m=1 (Single Input Error)
-    errorOne = new ForestPredictions(dataSize, nblabels);
+    errorOne = new ForestPredictions(test.size(), nblabels);
 
     if (oobM < oobOne) {
       forestM.classify(test, new MultiCallback(testError, treeError));