You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2012/12/15 01:35:14 UTC

svn commit: r1422165 - /mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java

Author: tdunning
Date: Sat Dec 15 00:35:13 2012
New Revision: 1422165

URL: http://svn.apache.org/viewvc?rev=1422165&view=rev
Log:
MAHOUT-1127 - Fixed iris test to be better test (and be deterministic, too)

Modified:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1422165&r1=1422164&r2=1422165&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Sat Dec 15 00:35:13 2012
@@ -28,6 +28,8 @@ import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.vectorizer.encoders.Dictionary;
 import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.util.Collections;
@@ -35,6 +37,7 @@ import java.util.List;
 import java.util.Random;
 
 public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
+  Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
 
   /**
    * The CrossFoldLearner is probably the best learner to use for new applications.
@@ -142,6 +145,27 @@ public final class OnlineLogisticRegress
 
   @Test
   public void iris() throws IOException {
+    // this test trains a 3-way classifier on the famous Iris dataset.
+    // a similar exercise can be accomplished in R using this code:
+    //    library(nnet)
+    //    correct = rep(0,100)
+    //    for (j in 1:100) {
+    //      i = order(runif(150))
+    //      train = iris[i[1:100],]
+    //      test = iris[i[101:150],]
+    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
+    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
+    //    }
+    //    hist(correct)
+    //
+    // Note that depending on the training/test split, performance can be better or worse.
+    // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
+    // of 100%
+    //
+    // This test uses a deterministic split that is neither outstandingly good nor bad
+
+
+    RandomUtils.useTestSeed();
     Splitter onComma = Splitter.on(",");
 
     // read the data
@@ -158,9 +182,12 @@ public final class OnlineLogisticRegress
 
     // for permuting data later
     List<Integer> order = Lists.newArrayList();
+
     for (String line : raw.subList(1,raw.size())) {
+      // order gets a list of indexes
       order.add(order.size());
 
+      // parse the predictor variables
       Vector v = new DenseVector(5);
       v.set(0, 1);
       int i = 1;
@@ -169,37 +196,51 @@ public final class OnlineLogisticRegress
         v.set(i++, Double.parseDouble(value));
       }
       data.add(v);
+
+      // and the target
       target.add(dict.intern(Iterables.get(values, 4)));
     }
 
-    Collections.shuffle(order);
+    // randomize the order ... original data has each species all together
+    // note that this randomization is deterministic
+    Random random = RandomUtils.getRandom();
+    Collections.shuffle(order, random);
+
+    // select training and test data
     List<Integer> train = order.subList(0, 100);
     List<Integer> test = order.subList(100, 150);
+    logger.warn("Training set = " + train);
+    logger.warn("Test set = " + test);
 
-    int total = 0;
-    int correct = 0;
-    for (int run = 0; run < 10; run++) {
+    // now train many times and collect information on accuracy each time
+    int[] correct = new int[test.size()];
+    for (int run = 0; run < 200; run++) {
       OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
-      for (int pass = 0; pass < 20; pass++) {
-        Collections.shuffle(train);
+      // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
+      for (int pass = 0; pass < 30; pass++) {
+        Collections.shuffle(train, random);
         for (int k : train) {
           lr.train(target.get(k), data.get(k));
         }
+      }
 
-        int x = 0;
-        int[] count = new int[3];
-        for (Integer k : test) {
-          int r = lr.classifyFull(data.get(k)).maxValueIndex();
-          count[r]++;
-          x += r == target.get(k) ? 1 : 0;
-          total++;
-        }
-
-//        System.out.printf("%d\t%.0f\t%d\t%d\t%d\n", pass, 2.0 * x, count[0], count[1], count[2]);
-        correct += x;
+      // check the accuracy on held out data
+      int x = 0;
+      int[] count = new int[3];
+      for (Integer k : test) {
+        int r = lr.classifyFull(data.get(k)).maxValueIndex();
+        count[r]++;
+        x += r == target.get(k) ? 1 : 0;
       }
+      correct[x]++;
+    }
+
+    // verify we never saw worse than 95% correct,
+    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
+      assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]);
     }
-    assertTrue("Accuracy should be >= 90% but is " + correct, (100.0 * correct / total) >= 90);
+    // nor perfect
+    assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size() - 1]);
   }
 
   @Test