You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2012/11/15 09:19:27 UTC

svn commit: r1409685 - in /mahout/trunk/core/src/test: java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java resources/iris.csv

Author: tdunning
Date: Thu Nov 15 08:19:26 2012
New Revision: 1409685

URL: http://svn.apache.org/viewvc?rev=1409685&view=rev
Log:
MAHOUT-1113 - Added an Iris classification test for OnlineLogisticRegression and the iris.csv test data

Added:
    mahout/trunk/core/src/test/resources/iris.csv
Modified:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1409685&r1=1409684&r2=1409685&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Thu Nov 15 08:19:26 2012
@@ -17,28 +17,37 @@
 
 package org.apache.mahout.classifier.sgd;
 
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
 import org.junit.Test;
 
 import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
 import java.util.Random;
 
 public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
 
   /**
    * The CrossFoldLearner is probably the best learner to use for new applications.
-    * @throws IOException If test resources aren't readable.
+   *
+   * @throws IOException If test resources aren't readable.
    */
   @Test
   public void crossValidation() throws IOException {
     Vector target = readStandardData();
 
     CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
-            .lambda(1 * 1.0e-3)
-            .learningRate(50);
+      .lambda(1 * 1.0e-3)
+      .learningRate(50);
 
 
     train(getInput(), target, lr);
@@ -55,10 +64,10 @@ public final class OnlineLogisticRegress
 
     Matrix data = readCsv("cancer.csv");
     CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
-            .stepOffset(10)
-            .decayExponent(0.7)
-            .lambda(1 * 1.0e-3)
-            .learningRate(5);
+      .stepOffset(10)
+      .decayExponent(0.7)
+      .lambda(1 * 1.0e-3)
+      .learningRate(5);
     int k = 0;
     int[] ordering = permute(gen, data.numRows());
     for (int epoch = 0; epoch < 100; epoch++) {
@@ -132,6 +141,77 @@ public final class OnlineLogisticRegress
  }
 
   @Test
+  public void iris() throws IOException {
+    // Trains a 3-class OnlineLogisticRegression on the Iris data set (iris.csv), holding out
+    // every row after the first 100 shuffled rows for testing, and requires the held-out accuracy,
+    // averaged over every run and every training pass, to be at least 90%.
+    //
+    // Shuffling uses RandomUtils.getRandom() rather than Collections.shuffle's unseeded default,
+    // so the split and training order are reproducible whenever RandomUtils uses a fixed test seed.
+    Random random = RandomUtils.getRandom();
+    Splitter onComma = Splitter.on(",");
+
+    // read the data; the first line is the CSV header
+    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);
+
+    // holds features: a constant 1 (intercept) followed by the four measurements
+    List<Vector> data = Lists.newArrayList();
+
+    // holds target variable (species name mapped to an integer id by dict)
+    List<Integer> target = Lists.newArrayList();
+
+    // for decoding target values
+    Dictionary dict = new Dictionary();
+
+    // row indexes, permuted later to form the train/test split
+    List<Integer> order = Lists.newArrayList();
+    for (String line : raw.subList(1, raw.size())) {
+      order.add(order.size());
+
+      Vector v = new DenseVector(5);
+      v.set(0, 1);
+      int i = 1;
+      Iterable<String> values = onComma.split(line);
+      for (String value : Iterables.limit(values, 4)) {
+        v.set(i++, Double.parseDouble(value));
+      }
+      data.add(v);
+      target.add(dict.intern(Iterables.get(values, 4)));
+    }
+
+    // the file is grouped by species, so shuffle before splitting off the training rows
+    Collections.shuffle(order, random);
+    int trainSize = 100;
+    List<Integer> train = order.subList(0, trainSize);
+    List<Integer> test = order.subList(trainSize, order.size());
+
+    int total = 0;
+    int correct = 0;
+    for (int run = 0; run < 10; run++) {
+      OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
+      for (int pass = 0; pass < 20; pass++) {
+        // train is a view of order, so this permutes only the training rows
+        Collections.shuffle(train, random);
+        for (int k : train) {
+          lr.train(target.get(k), data.get(k));
+        }
+
+        // score the held-out rows after every pass
+        for (int k : test) {
+          int predicted = lr.classifyFull(data.get(k)).maxValueIndex();
+          if (predicted == target.get(k)) {
+            correct++;
+          }
+          total++;
+        }
+      }
+    }
+
+    double accuracy = 100.0 * correct / total;
+    assertTrue(String.format("Accuracy should be >= 90%% but is %.1f%%", accuracy), accuracy >= 90);
+  }
+
+  @Test
   public void testTrain() throws Exception {
     Vector target = readStandardData();
 
@@ -142,8 +222,8 @@ public final class OnlineLogisticRegress
     // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
     //   --target y --categories 2 --predictors  V2 V3 V4 V5 V6 V7 --types n
     OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
-            .lambda(1 * 1.0e-3)
-            .learningRate(50);
+      .lambda(1 * 1.0e-3)
+      .learningRate(50);
 
     train(getInput(), target, lr);
     test(getInput(), target, lr, 0.05, 0.3);

Added: mahout/trunk/core/src/test/resources/iris.csv
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/resources/iris.csv?rev=1409685&view=auto
==============================================================================
--- mahout/trunk/core/src/test/resources/iris.csv (added)
+++ mahout/trunk/core/src/test/resources/iris.csv Thu Nov 15 08:19:26 2012
@@ -0,0 +1,151 @@
+Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
+5.1,3.5,1.4,0.2,setosa
+4.9,3.0,1.4,0.2,setosa
+4.7,3.2,1.3,0.2,setosa
+4.6,3.1,1.5,0.2,setosa
+5.0,3.6,1.4,0.2,setosa
+5.4,3.9,1.7,0.4,setosa
+4.6,3.4,1.4,0.3,setosa
+5.0,3.4,1.5,0.2,setosa
+4.4,2.9,1.4,0.2,setosa
+4.9,3.1,1.5,0.1,setosa
+5.4,3.7,1.5,0.2,setosa
+4.8,3.4,1.6,0.2,setosa
+4.8,3.0,1.4,0.1,setosa
+4.3,3.0,1.1,0.1,setosa
+5.8,4.0,1.2,0.2,setosa
+5.7,4.4,1.5,0.4,setosa
+5.4,3.9,1.3,0.4,setosa
+5.1,3.5,1.4,0.3,setosa
+5.7,3.8,1.7,0.3,setosa
+5.1,3.8,1.5,0.3,setosa
+5.4,3.4,1.7,0.2,setosa
+5.1,3.7,1.5,0.4,setosa
+4.6,3.6,1.0,0.2,setosa
+5.1,3.3,1.7,0.5,setosa
+4.8,3.4,1.9,0.2,setosa
+5.0,3.0,1.6,0.2,setosa
+5.0,3.4,1.6,0.4,setosa
+5.2,3.5,1.5,0.2,setosa
+5.2,3.4,1.4,0.2,setosa
+4.7,3.2,1.6,0.2,setosa
+4.8,3.1,1.6,0.2,setosa
+5.4,3.4,1.5,0.4,setosa
+5.2,4.1,1.5,0.1,setosa
+5.5,4.2,1.4,0.2,setosa
+4.9,3.1,1.5,0.2,setosa
+5.0,3.2,1.2,0.2,setosa
+5.5,3.5,1.3,0.2,setosa
+4.9,3.6,1.4,0.1,setosa
+4.4,3.0,1.3,0.2,setosa
+5.1,3.4,1.5,0.2,setosa
+5.0,3.5,1.3,0.3,setosa
+4.5,2.3,1.3,0.3,setosa
+4.4,3.2,1.3,0.2,setosa
+5.0,3.5,1.6,0.6,setosa
+5.1,3.8,1.9,0.4,setosa
+4.8,3.0,1.4,0.3,setosa
+5.1,3.8,1.6,0.2,setosa
+4.6,3.2,1.4,0.2,setosa
+5.3,3.7,1.5,0.2,setosa
+5.0,3.3,1.4,0.2,setosa
+7.0,3.2,4.7,1.4,versicolor
+6.4,3.2,4.5,1.5,versicolor
+6.9,3.1,4.9,1.5,versicolor
+5.5,2.3,4.0,1.3,versicolor
+6.5,2.8,4.6,1.5,versicolor
+5.7,2.8,4.5,1.3,versicolor
+6.3,3.3,4.7,1.6,versicolor
+4.9,2.4,3.3,1.0,versicolor
+6.6,2.9,4.6,1.3,versicolor
+5.2,2.7,3.9,1.4,versicolor
+5.0,2.0,3.5,1.0,versicolor
+5.9,3.0,4.2,1.5,versicolor
+6.0,2.2,4.0,1.0,versicolor
+6.1,2.9,4.7,1.4,versicolor
+5.6,2.9,3.6,1.3,versicolor
+6.7,3.1,4.4,1.4,versicolor
+5.6,3.0,4.5,1.5,versicolor
+5.8,2.7,4.1,1.0,versicolor
+6.2,2.2,4.5,1.5,versicolor
+5.6,2.5,3.9,1.1,versicolor
+5.9,3.2,4.8,1.8,versicolor
+6.1,2.8,4.0,1.3,versicolor
+6.3,2.5,4.9,1.5,versicolor
+6.1,2.8,4.7,1.2,versicolor
+6.4,2.9,4.3,1.3,versicolor
+6.6,3.0,4.4,1.4,versicolor
+6.8,2.8,4.8,1.4,versicolor
+6.7,3.0,5.0,1.7,versicolor
+6.0,2.9,4.5,1.5,versicolor
+5.7,2.6,3.5,1.0,versicolor
+5.5,2.4,3.8,1.1,versicolor
+5.5,2.4,3.7,1.0,versicolor
+5.8,2.7,3.9,1.2,versicolor
+6.0,2.7,5.1,1.6,versicolor
+5.4,3.0,4.5,1.5,versicolor
+6.0,3.4,4.5,1.6,versicolor
+6.7,3.1,4.7,1.5,versicolor
+6.3,2.3,4.4,1.3,versicolor
+5.6,3.0,4.1,1.3,versicolor
+5.5,2.5,4.0,1.3,versicolor
+5.5,2.6,4.4,1.2,versicolor
+6.1,3.0,4.6,1.4,versicolor
+5.8,2.6,4.0,1.2,versicolor
+5.0,2.3,3.3,1.0,versicolor
+5.6,2.7,4.2,1.3,versicolor
+5.7,3.0,4.2,1.2,versicolor
+5.7,2.9,4.2,1.3,versicolor
+6.2,2.9,4.3,1.3,versicolor
+5.1,2.5,3.0,1.1,versicolor
+5.7,2.8,4.1,1.3,versicolor
+6.3,3.3,6.0,2.5,virginica
+5.8,2.7,5.1,1.9,virginica
+7.1,3.0,5.9,2.1,virginica
+6.3,2.9,5.6,1.8,virginica
+6.5,3.0,5.8,2.2,virginica
+7.6,3.0,6.6,2.1,virginica
+4.9,2.5,4.5,1.7,virginica
+7.3,2.9,6.3,1.8,virginica
+6.7,2.5,5.8,1.8,virginica
+7.2,3.6,6.1,2.5,virginica
+6.5,3.2,5.1,2.0,virginica
+6.4,2.7,5.3,1.9,virginica
+6.8,3.0,5.5,2.1,virginica
+5.7,2.5,5.0,2.0,virginica
+5.8,2.8,5.1,2.4,virginica
+6.4,3.2,5.3,2.3,virginica
+6.5,3.0,5.5,1.8,virginica
+7.7,3.8,6.7,2.2,virginica
+7.7,2.6,6.9,2.3,virginica
+6.0,2.2,5.0,1.5,virginica
+6.9,3.2,5.7,2.3,virginica
+5.6,2.8,4.9,2.0,virginica
+7.7,2.8,6.7,2.0,virginica
+6.3,2.7,4.9,1.8,virginica
+6.7,3.3,5.7,2.1,virginica
+7.2,3.2,6.0,1.8,virginica
+6.2,2.8,4.8,1.8,virginica
+6.1,3.0,4.9,1.8,virginica
+6.4,2.8,5.6,2.1,virginica
+7.2,3.0,5.8,1.6,virginica
+7.4,2.8,6.1,1.9,virginica
+7.9,3.8,6.4,2.0,virginica
+6.4,2.8,5.6,2.2,virginica
+6.3,2.8,5.1,1.5,virginica
+6.1,2.6,5.6,1.4,virginica
+7.7,3.0,6.1,2.3,virginica
+6.3,3.4,5.6,2.4,virginica
+6.4,3.1,5.5,1.8,virginica
+6.0,3.0,4.8,1.8,virginica
+6.9,3.1,5.4,2.1,virginica
+6.7,3.1,5.6,2.4,virginica
+6.9,3.1,5.1,2.3,virginica
+5.8,2.7,5.1,1.9,virginica
+6.8,3.2,5.9,2.3,virginica
+6.7,3.3,5.7,2.5,virginica
+6.7,3.0,5.2,2.3,virginica
+6.3,2.5,5.0,1.9,virginica
+6.5,3.0,5.2,2.0,virginica
+6.2,3.4,5.4,2.3,virginica
+5.9,3.0,5.1,1.8,virginica
\ No newline at end of file