You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2012/11/15 09:19:27 UTC
svn commit: r1409685 - in /mahout/trunk/core/src/test:
java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
resources/iris.csv
Author: tdunning
Date: Thu Nov 15 08:19:26 2012
New Revision: 1409685
URL: http://svn.apache.org/viewvc?rev=1409685&view=rev
Log:
MAHOUT-1113 - Added test case and data
Added:
mahout/trunk/core/src/test/resources/iris.csv
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1409685&r1=1409684&r2=1409685&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Thu Nov 15 08:19:26 2012
@@ -17,28 +17,37 @@
package org.apache.mahout.classifier.sgd;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
import org.junit.Test;
import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
import java.util.Random;
public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
/**
* The CrossFoldLearner is probably the best learner to use for new applications.
- * @throws IOException If test resources aren't readable.
+ *
+ * @throws IOException If test resources aren't readable.
*/
@Test
public void crossValidation() throws IOException {
Vector target = readStandardData();
CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
- .lambda(1 * 1.0e-3)
- .learningRate(50);
+ .lambda(1 * 1.0e-3)
+ .learningRate(50);
train(getInput(), target, lr);
@@ -55,10 +64,10 @@ public final class OnlineLogisticRegress
Matrix data = readCsv("cancer.csv");
CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
- .stepOffset(10)
- .decayExponent(0.7)
- .lambda(1 * 1.0e-3)
- .learningRate(5);
+ .stepOffset(10)
+ .decayExponent(0.7)
+ .lambda(1 * 1.0e-3)
+ .learningRate(5);
int k = 0;
int[] ordering = permute(gen, data.numRows());
for (int epoch = 0; epoch < 100; epoch++) {
@@ -132,6 +141,68 @@ public final class OnlineLogisticRegress
}
+  /**
+   * Trains a three-category {@link OnlineLogisticRegression} on the iris data and checks how well
+   * it classifies rows that were held out of training.
+   * <p>
+   * The rows of iris.csv are split into 100 training rows and 50 test rows.  Ten fresh models are
+   * each trained for 20 passes; after every pass the model is scored against the test rows and the
+   * results are pooled, so the final figure also includes the early, partly trained passes.
+   *
+   * @throws IOException If test resources aren't readable.
+   */
@Test
+  public void iris() throws IOException {
+    // Shuffle with the test seed so the train/test split and the training order are the same on
+    // every run; with an unseeded shuffle the outcome depends on the luck of the split.
+    // NOTE(review): confirm the 90% threshold below holds for the seeded split.
+    RandomUtils.useTestSeed();
+    Random rand = RandomUtils.getRandom();
+
+    Splitter onComma = Splitter.on(",");
+
+    // read the data
+    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);
+
+    // holds features
+    List<Vector> data = Lists.newArrayList();
+
+    // holds target variable
+    List<Integer> target = Lists.newArrayList();
+
+    // for decoding target values
+    Dictionary dict = new Dictionary();
+
+    // for permuting data later
+    List<Integer> order = Lists.newArrayList();
+
+    // the first line is the column header, so skip it
+    for (String line : raw.subList(1, raw.size())) {
+      order.add(order.size());
+
+      // slot 0 holds a constant 1 (presumably the intercept term); slots 1..4 hold the measurements
+      Vector v = new DenseVector(5);
+      v.set(0, 1);
+      int i = 1;
+      Iterable<String> values = onComma.split(line);
+      for (String value : Iterables.limit(values, 4)) {
+        v.set(i++, Double.parseDouble(value));
+      }
+      data.add(v);
+
+      // the fifth column is the species name; the dictionary turns it into an integer code
+      target.add(dict.intern(Iterables.get(values, 4)));
+    }
+
+    // the file lists each species in one contiguous run, so permute before splitting
+    Collections.shuffle(order, rand);
+    List<Integer> train = order.subList(0, 100);
+    List<Integer> test = order.subList(100, 150);
+
+    int total = 0;
+    int correct = 0;
+    for (int run = 0; run < 10; run++) {
+      OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
+      for (int pass = 0; pass < 20; pass++) {
+        // train is a view of order, so this reorders only the training rows
+        Collections.shuffle(train, rand);
+        for (int k : train) {
+          lr.train(target.get(k), data.get(k));
+        }
+
+        // score the held-out rows after every pass
+        for (int k : test) {
+          int r = lr.classifyFull(data.get(k)).maxValueIndex();
+          if (r == target.get(k)) {
+            correct++;
+          }
+          total++;
+        }
+      }
+    }
+
+    // report the percentage that is actually being tested, along with the raw counts
+    double accuracy = 100.0 * correct / total;
+    assertTrue("Accuracy should be >= 90% but is " + accuracy + "% (" + correct + " of " + total + ')',
+        accuracy >= 90);
+  }
+
+ @Test
public void testTrain() throws Exception {
Vector target = readStandardData();
@@ -142,8 +213,8 @@ public final class OnlineLogisticRegress
// --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
// --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n
OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
- .lambda(1 * 1.0e-3)
- .learningRate(50);
+ .lambda(1 * 1.0e-3)
+ .learningRate(50);
train(getInput(), target, lr);
test(getInput(), target, lr, 0.05, 0.3);
Added: mahout/trunk/core/src/test/resources/iris.csv
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/resources/iris.csv?rev=1409685&view=auto
==============================================================================
--- mahout/trunk/core/src/test/resources/iris.csv (added)
+++ mahout/trunk/core/src/test/resources/iris.csv Thu Nov 15 08:19:26 2012
@@ -0,0 +1,151 @@
+Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species
+5.1,3.5,1.4,0.2,setosa
+4.9,3.0,1.4,0.2,setosa
+4.7,3.2,1.3,0.2,setosa
+4.6,3.1,1.5,0.2,setosa
+5.0,3.6,1.4,0.2,setosa
+5.4,3.9,1.7,0.4,setosa
+4.6,3.4,1.4,0.3,setosa
+5.0,3.4,1.5,0.2,setosa
+4.4,2.9,1.4,0.2,setosa
+4.9,3.1,1.5,0.1,setosa
+5.4,3.7,1.5,0.2,setosa
+4.8,3.4,1.6,0.2,setosa
+4.8,3.0,1.4,0.1,setosa
+4.3,3.0,1.1,0.1,setosa
+5.8,4.0,1.2,0.2,setosa
+5.7,4.4,1.5,0.4,setosa
+5.4,3.9,1.3,0.4,setosa
+5.1,3.5,1.4,0.3,setosa
+5.7,3.8,1.7,0.3,setosa
+5.1,3.8,1.5,0.3,setosa
+5.4,3.4,1.7,0.2,setosa
+5.1,3.7,1.5,0.4,setosa
+4.6,3.6,1.0,0.2,setosa
+5.1,3.3,1.7,0.5,setosa
+4.8,3.4,1.9,0.2,setosa
+5.0,3.0,1.6,0.2,setosa
+5.0,3.4,1.6,0.4,setosa
+5.2,3.5,1.5,0.2,setosa
+5.2,3.4,1.4,0.2,setosa
+4.7,3.2,1.6,0.2,setosa
+4.8,3.1,1.6,0.2,setosa
+5.4,3.4,1.5,0.4,setosa
+5.2,4.1,1.5,0.1,setosa
+5.5,4.2,1.4,0.2,setosa
+4.9,3.1,1.5,0.2,setosa
+5.0,3.2,1.2,0.2,setosa
+5.5,3.5,1.3,0.2,setosa
+4.9,3.6,1.4,0.1,setosa
+4.4,3.0,1.3,0.2,setosa
+5.1,3.4,1.5,0.2,setosa
+5.0,3.5,1.3,0.3,setosa
+4.5,2.3,1.3,0.3,setosa
+4.4,3.2,1.3,0.2,setosa
+5.0,3.5,1.6,0.6,setosa
+5.1,3.8,1.9,0.4,setosa
+4.8,3.0,1.4,0.3,setosa
+5.1,3.8,1.6,0.2,setosa
+4.6,3.2,1.4,0.2,setosa
+5.3,3.7,1.5,0.2,setosa
+5.0,3.3,1.4,0.2,setosa
+7.0,3.2,4.7,1.4,versicolor
+6.4,3.2,4.5,1.5,versicolor
+6.9,3.1,4.9,1.5,versicolor
+5.5,2.3,4.0,1.3,versicolor
+6.5,2.8,4.6,1.5,versicolor
+5.7,2.8,4.5,1.3,versicolor
+6.3,3.3,4.7,1.6,versicolor
+4.9,2.4,3.3,1.0,versicolor
+6.6,2.9,4.6,1.3,versicolor
+5.2,2.7,3.9,1.4,versicolor
+5.0,2.0,3.5,1.0,versicolor
+5.9,3.0,4.2,1.5,versicolor
+6.0,2.2,4.0,1.0,versicolor
+6.1,2.9,4.7,1.4,versicolor
+5.6,2.9,3.6,1.3,versicolor
+6.7,3.1,4.4,1.4,versicolor
+5.6,3.0,4.5,1.5,versicolor
+5.8,2.7,4.1,1.0,versicolor
+6.2,2.2,4.5,1.5,versicolor
+5.6,2.5,3.9,1.1,versicolor
+5.9,3.2,4.8,1.8,versicolor
+6.1,2.8,4.0,1.3,versicolor
+6.3,2.5,4.9,1.5,versicolor
+6.1,2.8,4.7,1.2,versicolor
+6.4,2.9,4.3,1.3,versicolor
+6.6,3.0,4.4,1.4,versicolor
+6.8,2.8,4.8,1.4,versicolor
+6.7,3.0,5.0,1.7,versicolor
+6.0,2.9,4.5,1.5,versicolor
+5.7,2.6,3.5,1.0,versicolor
+5.5,2.4,3.8,1.1,versicolor
+5.5,2.4,3.7,1.0,versicolor
+5.8,2.7,3.9,1.2,versicolor
+6.0,2.7,5.1,1.6,versicolor
+5.4,3.0,4.5,1.5,versicolor
+6.0,3.4,4.5,1.6,versicolor
+6.7,3.1,4.7,1.5,versicolor
+6.3,2.3,4.4,1.3,versicolor
+5.6,3.0,4.1,1.3,versicolor
+5.5,2.5,4.0,1.3,versicolor
+5.5,2.6,4.4,1.2,versicolor
+6.1,3.0,4.6,1.4,versicolor
+5.8,2.6,4.0,1.2,versicolor
+5.0,2.3,3.3,1.0,versicolor
+5.6,2.7,4.2,1.3,versicolor
+5.7,3.0,4.2,1.2,versicolor
+5.7,2.9,4.2,1.3,versicolor
+6.2,2.9,4.3,1.3,versicolor
+5.1,2.5,3.0,1.1,versicolor
+5.7,2.8,4.1,1.3,versicolor
+6.3,3.3,6.0,2.5,virginica
+5.8,2.7,5.1,1.9,virginica
+7.1,3.0,5.9,2.1,virginica
+6.3,2.9,5.6,1.8,virginica
+6.5,3.0,5.8,2.2,virginica
+7.6,3.0,6.6,2.1,virginica
+4.9,2.5,4.5,1.7,virginica
+7.3,2.9,6.3,1.8,virginica
+6.7,2.5,5.8,1.8,virginica
+7.2,3.6,6.1,2.5,virginica
+6.5,3.2,5.1,2.0,virginica
+6.4,2.7,5.3,1.9,virginica
+6.8,3.0,5.5,2.1,virginica
+5.7,2.5,5.0,2.0,virginica
+5.8,2.8,5.1,2.4,virginica
+6.4,3.2,5.3,2.3,virginica
+6.5,3.0,5.5,1.8,virginica
+7.7,3.8,6.7,2.2,virginica
+7.7,2.6,6.9,2.3,virginica
+6.0,2.2,5.0,1.5,virginica
+6.9,3.2,5.7,2.3,virginica
+5.6,2.8,4.9,2.0,virginica
+7.7,2.8,6.7,2.0,virginica
+6.3,2.7,4.9,1.8,virginica
+6.7,3.3,5.7,2.1,virginica
+7.2,3.2,6.0,1.8,virginica
+6.2,2.8,4.8,1.8,virginica
+6.1,3.0,4.9,1.8,virginica
+6.4,2.8,5.6,2.1,virginica
+7.2,3.0,5.8,1.6,virginica
+7.4,2.8,6.1,1.9,virginica
+7.9,3.8,6.4,2.0,virginica
+6.4,2.8,5.6,2.2,virginica
+6.3,2.8,5.1,1.5,virginica
+6.1,2.6,5.6,1.4,virginica
+7.7,3.0,6.1,2.3,virginica
+6.3,3.4,5.6,2.4,virginica
+6.4,3.1,5.5,1.8,virginica
+6.0,3.0,4.8,1.8,virginica
+6.9,3.1,5.4,2.1,virginica
+6.7,3.1,5.6,2.4,virginica
+6.9,3.1,5.1,2.3,virginica
+5.8,2.7,5.1,1.9,virginica
+6.8,3.2,5.9,2.3,virginica
+6.7,3.3,5.7,2.5,virginica
+6.7,3.0,5.2,2.3,virginica
+6.3,2.5,5.0,1.9,virginica
+6.5,3.0,5.2,2.0,virginica
+6.2,3.4,5.4,2.3,virginica
+5.9,3.0,5.1,1.8,virginica
\ No newline at end of file