You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink appears not to have been preserved in this plain-text copy).
Posted to commits@mahout.apache.org by ad...@apache.org on 2009/12/22 10:08:44 UTC
svn commit: r893116 - /lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
Author: adeneche
Date: Tue Dec 22 09:08:44 2009
New Revision: 893116
URL: http://svn.apache.org/viewvc?rev=893116&view=rev
Log:
fixing the Breiman example
Modified:
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=893116&r1=893115&r2=893116&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Tue Dec 22 09:08:44 2009
@@ -81,7 +81,6 @@
*/
protected static void runIteration(Data data, int m, int nbtrees) {
- int dataSize = data.size();
int nblabels = data.getDataset().nblabels();
Random rng = RandomUtils.getRandom();
@@ -91,13 +90,13 @@
int[] trainLabels = train.extractLabels();
int[] testLabels = test.extractLabels();
-
+
DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
SequentialBuilder forestBuilder = new SequentialBuilder(rng, treeBuilder, train);
// grow a forest with m = log2(M)+1
- ForestPredictions errorM = new ForestPredictions(dataSize, nblabels); // oob error when using m = log2(M)+1
+ ForestPredictions errorM = new ForestPredictions(train.size(), nblabels); // oob error when using m = log2(M)+1
treeBuilder.setM(m);
long time = System.currentTimeMillis();
@@ -108,7 +107,7 @@
double oobM = ErrorEstimate.errorRate(trainLabels, errorM.computePredictions(rng)); // oob error estimate when m = log2(M)+1
// grow a forest with m=1
- ForestPredictions errorOne = new ForestPredictions(dataSize, nblabels); // oob error when using m = 1
+ ForestPredictions errorOne = new ForestPredictions(train.size(), nblabels); // oob error when using m = 1
treeBuilder.setM(1);
time = System.currentTimeMillis();
@@ -120,11 +119,11 @@
// compute the test set error (Selection Error), and mean tree error (One Tree Error),
// using the lowest oob error forest
- ForestPredictions testError = new ForestPredictions(dataSize, nblabels); // test set error
- MeanTreeCollector treeError = new MeanTreeCollector(train, nbtrees); // mean tree error
+ ForestPredictions testError = new ForestPredictions(test.size(), nblabels); // test set error
+ MeanTreeCollector treeError = new MeanTreeCollector(test, nbtrees); // mean tree error
// compute the test set error using m=1 (Single Input Error)
- errorOne = new ForestPredictions(dataSize, nblabels);
+ errorOne = new ForestPredictions(test.size(), nblabels);
if (oobM < oobOne) {
forestM.classify(test, new MultiCallback(testError, treeError));