You are viewing a plain-text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@mahout.apache.org by td...@apache.org on 2012/12/15 01:35:14 UTC
svn commit: r1422165 - /mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
Author: tdunning
Date: Sat Dec 15 00:35:13 2012
New Revision: 1422165
URL: http://svn.apache.org/viewvc?rev=1422165&view=rev
Log:
MAHOUT-1127 - Fixed iris test to be a better test (and to be deterministic, too)
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1422165&r1=1422164&r2=1422165&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Sat Dec 15 00:35:13 2012
@@ -28,6 +28,8 @@ import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Collections;
@@ -35,6 +37,7 @@ import java.util.List;
import java.util.Random;
public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
+ Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
/**
* The CrossFoldLearner is probably the best learner to use for new applications.
@@ -142,6 +145,27 @@ public final class OnlineLogisticRegress
@Test
public void iris() throws IOException {
+ // this test trains a 3-way classifier on the famous Iris dataset.
+ // a similar exercise can be accomplished in R using this code:
+ // library(nnet)
+ // correct = rep(0,100)
+ // for (j in 1:100) {
+ // i = order(runif(150))
+ // train = iris[i[1:100],]
+ // test = iris[i[101:150],]
+ // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
+ // correct[j] = mean(predict(m, newdata=test) == test$Species)
+ // }
+ // hist(correct)
+ //
+ // Note that depending on the training/test split, performance can be better or worse.
+ // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
+ // of 100%
+ //
+ // This test uses a deterministic split that is neither outstandingly good nor bad
+
+
+ RandomUtils.useTestSeed();
Splitter onComma = Splitter.on(",");
// read the data
@@ -158,9 +182,12 @@ public final class OnlineLogisticRegress
// for permuting data later
List<Integer> order = Lists.newArrayList();
+
for (String line : raw.subList(1,raw.size())) {
+ // order gets a list of indexes
order.add(order.size());
+ // parse the predictor variables
Vector v = new DenseVector(5);
v.set(0, 1);
int i = 1;
@@ -169,37 +196,51 @@ public final class OnlineLogisticRegress
v.set(i++, Double.parseDouble(value));
}
data.add(v);
+
+ // and the target
target.add(dict.intern(Iterables.get(values, 4)));
}
- Collections.shuffle(order);
+ // randomize the order ... original data has each species all together
+ // note that this randomization is deterministic
+ Random random = RandomUtils.getRandom();
+ Collections.shuffle(order, random);
+
+ // select training and test data
List<Integer> train = order.subList(0, 100);
List<Integer> test = order.subList(100, 150);
+ logger.warn("Training set = " + train);
+ logger.warn("Test set = " + test);
- int total = 0;
- int correct = 0;
- for (int run = 0; run < 10; run++) {
+ // now train many times and collect information on accuracy each time
+ int[] correct = new int[test.size()];
+ for (int run = 0; run < 200; run++) {
OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
- for (int pass = 0; pass < 20; pass++) {
- Collections.shuffle(train);
+ // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
+ for (int pass = 0; pass < 30; pass++) {
+ Collections.shuffle(train, random);
for (int k : train) {
lr.train(target.get(k), data.get(k));
}
+ }
- int x = 0;
- int[] count = new int[3];
- for (Integer k : test) {
- int r = lr.classifyFull(data.get(k)).maxValueIndex();
- count[r]++;
- x += r == target.get(k) ? 1 : 0;
- total++;
- }
-
-// System.out.printf("%d\t%.0f\t%d\t%d\t%d\n", pass, 2.0 * x, count[0], count[1], count[2]);
- correct += x;
+ // check the accuracy on held out data
+ int x = 0;
+ int[] count = new int[3];
+ for (Integer k : test) {
+ int r = lr.classifyFull(data.get(k)).maxValueIndex();
+ count[r]++;
+ x += r == target.get(k) ? 1 : 0;
}
+ correct[x]++;
+ }
+
+ // verify we never saw worse than 95% correct,
+ for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
+ assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]);
}
- assertTrue("Accuracy should be >= 90% but is " + correct, (100.0 * correct / total) >= 90);
+ // nor perfect
+ assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size() - 1]);
}
@Test