You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/08/16 18:56:47 UTC
svn commit: r986045 [2/2] - in /mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/
core/src/main/java/org/apache/mahout/classifier/sgd/
core/src/test/java/org/apache/mahout/classifier/
core/src/test/java/org/apache/mahout/classifier/sgd/ ma...
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Checks how {@link ContinuousValueEncoder} adds numeric strings to a vector. */
+public class ContinuousValueEncoderTest {
+  @Test
+  public void testAddToVector() {
+    FeatureVectorEncoder enc = new ContinuousValueEncoder("foo");
+    Vector v1 = new DenseVector(20);
+    enc.addToVector("-123", v1);
+    assertEquals(-123, v1.minValue(), 0);
+    assertEquals(0, v1.maxValue(), 0);
+    assertEquals(123, v1.norm(1), 0);
+    // a positive value is expected to land in a single slot as well
+    v1 = new DenseVector(20);
+    enc.addToVector("123", v1);
+    assertEquals(123, v1.maxValue(), 0);
+    assertEquals(0, v1.minValue(), 0);
+    assertEquals(123, v1.norm(1), 0);
+    // with two probes the value is expected in two slots
+    Vector v2 = new DenseVector(20);
+    enc.setProbes(2);
+    enc.addToVector("123", v2);
+    assertEquals(123, v2.maxValue(), 0);
+    assertEquals(2 * 123, v2.norm(1), 0);
+    // the single-probe slot should be one of the two, leaving one entry in the difference
+    v1 = v2.minus(v1);
+    assertEquals(123, v1.maxValue(), 0);
+    assertEquals(123, v1.norm(1), 0);
+    // a different value is expected to use the same two slots
+    Vector v3 = new DenseVector(20);
+    enc.setProbes(2);
+    enc.addToVector("100", v3);
+    v1 = v2.minus(v3);
+    assertEquals(23, v1.maxValue(), 0);
+    assertEquals(2 * 23, v1.norm(1), 0);
+    // encoding into a used vector adds to what is already there
+    enc.addToVector("7", v1);
+    assertEquals(30, v1.maxValue(), 0);
+    assertEquals(2 * 30, v1.norm(1), 0);
+    assertEquals(30, v1.get(10), 0);
+    assertEquals(30, v1.get(18), 0);
+    // non-numeric input must be rejected
+    try {
+      enc.addToVector("foobar", v1);
+      fail("Should have noticed bad numeric format");
+    } catch (NumberFormatException e) {
+      assertEquals("For input string: \"foobar\"", e.getMessage());
+    }
+  }
+
+  @Test
+  public void testAsString() {
+    ContinuousValueEncoder enc = new ContinuousValueEncoder("foo");
+    assertEquals("foo:123", enc.asString("123"));
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link CsvRecordFactory}, plus an ordering check for {@link Dictionary}.
+ */
+public class CsvRecordFactoryTest {
+  /** Encodes three CSV lines and checks the returned target and the shape of each vector. */
+  @Test
+  public void testAddToVector() {
+    CsvRecordFactory factory = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t"));
+    // z and q appear in the header but not in the type map; their cells read "ignore"
+    factory.firstLine("z,x1,y,x2,x3,q");
+    factory.maxTargetValue(2);
+
+    Vector encoded = new DenseVector(2000);
+    int target = factory.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", encoded);
+    assertEquals(0, target);
+    // apart from the 3.1, every entry that is set equals 1
+    checkEncoding(encoded, 3.1, 8.0, 1.0);
+
+    encoded.assign(0);
+    target = factory.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", encoded);
+    assertEquals(1, target);
+    // apart from the 5.3, eight entries with an L1 norm of 12 and a largest value of 2
+    checkEncoding(encoded, 5.3, 12.0, 2);
+
+    encoded.assign(0);
+    target = factory.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", encoded);
+    assertEquals(1, target);
+    // same predictors as before with a different target cell: identical expectations
+    checkEncoding(encoded, 5.3, 12.0, 2);
+  }
+
+  /**
+   * Asserts that nine entries are set and that the largest is {@code largest}; then zeroes that
+   * entry and asserts that the remaining eight have the given L1 norm and maximum.
+   *
+   * @param encoded     vector to inspect; its largest entry is cleared as a side effect
+   * @param largest     expected maximum before clearing
+   * @param remainingL1 expected L1 norm after clearing
+   * @param nextLargest expected maximum after clearing
+   */
+  private static void checkEncoding(Vector encoded, double largest, double remainingL1, double nextLargest) {
+    assertEquals(9.0, encoded.norm(0), 0);
+    assertEquals(largest, encoded.maxValue(), 0);
+    encoded.set(encoded.maxValueIndex(), 0);
+    assertEquals(8.0, encoded.norm(0), 0);
+    assertEquals(remainingL1, encoded.norm(1), 0);
+    assertEquals(nextLargest, encoded.maxValue(), 0);
+  }
+
+  /** Interned values come back in insertion order, also after a round trip through fromList. */
+  @Test
+  public void testDictionaryOrder() {
+    Dictionary original = new Dictionary();
+    for (String word : new String[] {"a", "d", "c", "b", "qrz"}) {
+      original.intern(word);
+    }
+    // insertion order is expected, not alphabetical order
+    assertEquals("[a, d, c, b, qrz]", original.values().toString());
+
+    Dictionary copy = Dictionary.fromList(original.values());
+    assertEquals("[a, d, c, b, qrz]", copy.values().toString());
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests for the SGD logistic regression learners, driven by small CSV fixtures. */
+public class OnlineLogisticRegressionTest {
+  private Matrix input;
+
+  /**
+   * The CrossFoldLearner is probably the best learner to use for new applications.
+   * @throws IOException If test resources aren't readable.
+   */
+  @Test
+  public void crossValidation() throws IOException {
+    Vector target = readStandardData();
+
+    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
+        .lambda(1 * 1e-3)
+        .learningRate(50);
+
+    train(input, target, lr);
+
+    System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
+    test(input, target, lr);
+  }
+
+  /** Smoke test: prints the running AUC while training on cancer.csv; nothing is asserted. */
+  @Test
+  public void crossValidatedAuc() throws IOException {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    Matrix data = readCsv("cancer.csv");
+    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
+        .stepOffset(10)
+        .decayExponent(0.7)
+        .lambda(1 * 1e-3)
+        .learningRate(5);
+    int k = 0;
+    int[] ordering = permute(gen, data.numRows());
+    for (int epoch = 0; epoch < 100; epoch++) {
+      for (int row : ordering) {
+        lr.train(row, (int) data.get(row, 9), data.viewRow(row));
+        System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
+      }
+    }
+  }
+
+  /**
+   * Verifies that a classifier with known coefficients does the right thing.
+   */
+  @Test
+  public void testClassify() {
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
+    // set up some internal coefficients as if we had learned them
+    lr.setBeta(0, 0, -1);
+    lr.setBeta(1, 0, -2);
+
+    // zero vector gives no information. All classes are equal.
+    Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
+    assertEquals(1 / 3.0, v.get(0), 1e-8);
+    assertEquals(1 / 3.0, v.get(1), 1e-8);
+
+    v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(1 / 3.0, v.get(0), 1e-8);
+    assertEquals(1 / 3.0, v.get(1), 1e-8);
+    assertEquals(1 / 3.0, v.get(2), 1e-8);
+
+    // weights for second vector component are still zero so all classifications are equally likely
+    v = lr.classify(new DenseVector(new double[]{0, 1}));
+    assertEquals(1 / 3.0, v.get(0), 1e-3);
+    assertEquals(1 / 3.0, v.get(1), 1e-3);
+
+    v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(1 / 3.0, v.get(0), 1e-3);
+    assertEquals(1 / 3.0, v.get(1), 1e-3);
+    assertEquals(1 / 3.0, v.get(2), 1e-3);
+
+    // but the weights on the first component are non-zero
+    v = lr.classify(new DenseVector(new double[]{1, 0}));
+    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1e-8);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1e-8);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1e-8);
+    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1e-8);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1e-8);
+
+    lr.setBeta(0, 1, 1);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1e-3);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1e-3);
+    assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1e-3);
+
+    lr.setBeta(1, 1, 3);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1e-8);
+    assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1e-8);
+    assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1e-8);
+  }
+
+  @Test
+  public void testTrain() throws IOException {
+    Vector target = readStandardData();
+
+    // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
+    // larger than usual because the data are dense. The learning rate doesn't matter too much
+    // for this example, but should generally be < 1
+    // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+        .lambda(1 * 1e-3)
+        .learningRate(50);
+
+    train(input, target, lr);
+    test(input, target, lr);
+  }
+
+  /** Loads sgd.csv into {@link #input} and returns the matching target vector. */
+  private Vector readStandardData() throws IOException {
+    // 60 test samples. First column is constant. Second and third are normally distributed from
+    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a
+    // target variable of 0, the last 30 a target of 1. The remaining columns are random noise.
+    input = readCsv("sgd.csv");
+
+    // regenerate the target variable
+    Vector target = new DenseVector(60);
+    target.assign(0);
+    target.viewPart(30, 30).assign(1);
+    return target;
+  }
+
+  /** Trains once on each target/row pair, in a permuted order, then closes the learner. */
+  private void train(Matrix input, Vector target, OnlineLearner lr) {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    // train on samples in random order (but only one pass)
+    for (int row : permute(gen, target.size())) {
+      lr.train((int) target.get(row), input.getRow(row));
+    }
+    lr.close();
+  }
+
+  /** Asserts low error against the targets and agreement between the classify variants. */
+  private void test(Matrix input, Vector target, AbstractVectorClassifier lr) {
+    // now test the accuracy
+    Matrix tmp = lr.classify(input);
+    // mean(abs(tmp - target))
+    double meanAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.plus, Functions.abs) / target.size();
+
+    // max(abs(tmp - target)
+    double maxAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.max, Functions.abs);
+
+    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+    assertEquals(0, meanAbsoluteError , 0.05);
+    assertEquals(0, maxAbsoluteError, 0.3);
+
+    // convenience methods should give the same results
+    Vector v = lr.classifyScalar(input);
+    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1e-5);
+    v = lr.classifyFull(input).getColumn(1);
+    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1e-4);
+  }
+
+  /**
+   * Permute the integers from 0 ... max-1
+   *
+   * @param gen The random number generator to use.
+   * @param max The number of integers to permute
+   * @return An array of jumbled integer values
+   */
+  private int[] permute(Random gen, int max) {
+    // inside-out shuffle: value i goes to a random slot n <= i and the previous occupant of n
+    // moves to slot i. Slot 0 already holds 0 by default, so max == 0 needs no special case.
+    int[] permutation = new int[max];
+    for (int i = 1; i < max; i++) {
+      int n = gen.nextInt(i + 1);
+      permutation[i] = permutation[n];
+      permutation[n] = i;
+    }
+    return permutation;
+  }
+
+  /**
+   * Reads a file containing CSV data. This isn't implemented quite the way you might like for a
+   * real program, but does the job for reading test data. Most notably, it will only read numbers,
+   * not quoted strings.
+   *
+   * @param resourceName Where to get the data.
+   * @return A matrix of the results.
+   * @throws java.io.IOException If there is an error reading the data
+   */
+  private Matrix readCsv(String resourceName) throws IOException {
+    Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));
+
+    // name the charset instead of relying on the platform default, and always release the stream
+    InputStreamReader isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), "UTF-8");
+    List<String> data;
+    try {
+      data = CharStreams.readLines(isr);
+    } finally {
+      isr.close();
+    }
+    String first = data.get(0);
+    data = data.subList(1, data.size());
+
+    List<String> values = Lists.newArrayList(onCommas.split(first));
+    Matrix r = new DenseMatrix(data.size(), values.size());
+
+    // the first line names the columns
+    Map<String, Integer> labels = Maps.newHashMap();
+    for (int column = 0; column < values.size(); column++) {
+      labels.put(values.get(column), column);
+    }
+    r.setColumnLabelBindings(labels);
+
+    // every remaining line is one row of numbers
+    for (int row = 0; row < data.size(); row++) {
+      int column = 0;
+      for (String value : onCommas.split(data.get(row))) {
+        r.set(row, column++, Double.parseDouble(value));
+      }
+    }
+
+    return r;
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class TextValueEncoderTest {
+ @Test
+ public void testAddToVector() {
+ TextValueEncoder enc = new TextValueEncoder("text");
+ Vector v1 = new DenseVector(200);
+ enc.addToVector("test1 and more", v1);
+ // should set 6 distinct locations to 1
+ assertEquals(6.0, v1.norm(1), 0);
+ assertEquals(1.0, v1.maxValue(), 0);
+
+ // now some fancy weighting
+ StaticWordValueEncoder w = new StaticWordValueEncoder("text");
+ w.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5));
+ enc.setWordEncoder(w);
+
+ // should set 6 locations to something
+ Vector v2 = new DenseVector(200);
+ enc.addToVector("test1 and more", v2);
+
+ // this should set the same 6 locations to the same values
+ Vector v3 = new DenseVector(200);
+ w.addToVector("test1", v3);
+ w.addToVector("and", v3);
+ w.addToVector("more", v3);
+
+ assertEquals(0, v3.minus(v2).norm(1), 0);
+
+ // moreover, the locations set in the unweighted case should be the same as in the weighted case
+ assertEquals(v3.zSum(), v3.dot(v1), 0);
+ }
+
+ @Test
+ public void testAsString() {
+ TextValueEncoder enc = new TextValueEncoder("text");
+ assertEquals("[text:test1:1.0000, text:and:1.0000, text:more:1.0000]", enc.asString("test1 and more"));
+ }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.util.Iterator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+public class WordLikeValueEncoderTest {
+ @Test
+ public void testAddToVector() {
+ FeatureVectorEncoder enc = new StaticWordValueEncoder("word");
+ Vector v = new DenseVector(200);
+ enc.addToVector("word1", v);
+ enc.addToVector("word2", v);
+ Iterator<Vector.Element> i = v.iterateNonZero();
+ Iterator<Integer> j = ImmutableList.of(7, 118, 119, 199).iterator();
+ while (i.hasNext()) {
+ Vector.Element element = i.next();
+ assertEquals(j.next().intValue(), element.index());
+ assertEquals(1, element.get(), 0);
+ }
+ assertFalse(j.hasNext());
+ }
+
+ @Test
+ public void testAsString() {
+ FeatureVectorEncoder enc = new StaticWordValueEncoder("word");
+ assertEquals("word:w1:1.0000", enc.asString("w1"));
+ }
+
+ @Test
+ public void testStaticWeights() {
+ StaticWordValueEncoder enc = new StaticWordValueEncoder("word");
+ enc.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5));
+ Vector v = new DenseVector(200);
+ enc.addToVector("word1", v);
+ enc.addToVector("word2", v);
+ enc.addToVector("word3", v);
+ Iterator<Vector.Element> i = v.iterateNonZero();
+ Iterator<Integer> j = ImmutableList.of(7, 101, 118, 119, 152, 199).iterator();
+ Iterator<Double> k = ImmutableList.of(3.0, 0.75, 1.5, 1.5, 0.75, 3.0).iterator();
+ while (i.hasNext()) {
+ Vector.Element element = i.next();
+ assertEquals(j.next().intValue(), element.index());
+ assertEquals(k.next(), element.get(), 0);
+ }
+ assertFalse(j.hasNext());
+ }
+
+ @Test
+ public void testDynamicWeights() {
+ FeatureVectorEncoder enc = new AdaptiveWordValueEncoder("word");
+ Vector v = new DenseVector(200);
+ enc.addToVector("word1", v); // weight is log(2/1.5)
+ enc.addToVector("word2", v); // weight is log(3.5 / 1.5)
+ enc.addToVector("word1", v); // weight is log(4.5 / 2.5) (but overlays on first value)
+ enc.addToVector("word3", v); // weight is log(6 / 1.5)
+ Iterator<Vector.Element> i = v.iterateNonZero();
+ Iterator<Integer> j = ImmutableList.of(7, 101, 118, 119, 152, 199).iterator();
+ Iterator<Double> k = ImmutableList.of(Math.log(2 / 1.5) + Math.log(4.5 / 2.5), Math.log(6 / 1.5), Math.log(3.5 / 1.5), Math.log(3.5 / 1.5), Math.log(6 / 1.5), Math.log(2 / 1.5) + Math.log(4.5 / 2.5)).iterator();
+ while (i.hasNext()) {
+ Vector.Element element = i.next();
+ assertEquals(j.next().intValue(), element.index());
+ assertEquals(k.next(), element.get(), 1e-6);
+ }
+ assertFalse(j.hasNext());
+ }
+}
Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,100 @@
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import java.util.Random;
+
+/**
+ * Computes a running estimate of AUC (see http://en.wikipedia.org/wiki/Receiver_operating_characteristic).
+ * <p/>
+ * Since AUC is normally a global property of labeled scores, it is almost always computed in a
+ * batch fashion. The probabilistic definition (the probability that a random element of one set
+ * has a higher score than a random element of another set) gives us a way to estimate this
+ * on-line.
+ */
+public class OnlineAuc {
+  /** How a new score displaces a retained one once a category's history is full. */
+  public enum ReplacementPolicy { FIFO, FAIR, RANDOM }
+
+  /** Number of scores retained per category. */
+  public static final int HISTORY = 10;
+
+  private final Random random = new Random();
+  private ReplacementPolicy policy = ReplacementPolicy.FAIR;
+  private final Matrix scores;    // retained scores, one row per category; NaN = never filled
+  private final Vector averages;  // per category: running mean of wins over the other history
+  private final Vector samples;   // per category: number of samples seen
+
+  public OnlineAuc() {
+    int numCategories = 2;
+    scores = new DenseMatrix(numCategories, HISTORY);
+    scores.assign(Double.NaN);
+    averages = new DenseVector(numCategories);
+    averages.assign(0.5);
+    samples = new DenseVector(numCategories);
+  }
+
+  /**
+   * Records one labeled score and updates the running estimate.
+   *
+   * @param category label of the sample, 0 or 1
+   * @param score    score assigned to the sample
+   * @return the updated estimate, as returned by {@link #auc()}
+   */
+  public double addSample(int category, final double score) {
+    int n = (int) samples.get(category);
+    if (n < HISTORY) {
+      scores.set(category, n, score);
+    } else {
+      switch (policy) {
+        case FIFO:
+          // overwrite in arrival order, wrapping around
+          scores.set(category, n % HISTORY, score);
+          break;
+        case FAIR:
+          // keep the new score with probability HISTORY / (n + 1)
+          int j = random.nextInt(n + 1);
+          if (j < HISTORY) {
+            scores.set(category, j, score);
+          }
+          break;
+        case RANDOM:
+          // always overwrite a randomly chosen slot
+          j = random.nextInt(HISTORY);
+          scores.set(category, j, score);
+          break;
+      }
+    }
+    samples.set(category, n + 1);
+
+    if (samples.minValue() >= 1) {
+      // compare to the retained scores of the other category; ties count as one half
+      double m = 0;
+      int count = 0;
+      for (Vector.Element element : scores.viewRow(1 - category)) {
+        double v = element.get();
+        if (Double.isNaN(v)) {
+          break;  // slots fill from the left, so only empty slots follow
+        }
+        count++;
+        double z = score > v ? 1 : score < v ? 0 : 0.5;
+        m += (z - m) / count;
+      }
+      averages.set(category, averages.get(category) + (m - averages.get(category)) / samples.get(category));
+    }
+    return auc();
+  }
+
+  /** @return 0.5 - averages[0] / 2 + averages[1] / 2, the current estimate */
+  public double auc() {
+    return 0.5 - averages.get(0) / 2 + averages.get(1) / 2;
+  }
+
+  /** @param policy replacement policy applied by later calls to {@link #addSample} */
+  public void setPolicy(ReplacementPolicy policy) {
+    this.policy = policy;
+  }
+}
Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,64 @@
+package org.apache.mahout.math.stats;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Random;
+
+/**
+ * Feeds {@link OnlineAuc} scores drawn from two unit-variance normal distributions whose means
+ * differ by one, and checks the estimate under each replacement policy, for both assignments
+ * of the distributions to the categories.
+ */
+public class OnlineAucTest {
+  /** Reference value computed using R: mean(rnorm(1000000) < rnorm(1000000,1)) */
+  private static final double REFERENCE_AUC = 0.76;
+
+  /** Both assignments of the distributions to the categories are checked. */
+  @Test
+  public void binaryCase() {
+    // category 1 draws the higher scores, so the estimate should be near the reference value
+    for (OnlineAuc auc : feed(0, 1)) {
+      Assert.assertEquals(REFERENCE_AUC, auc.auc(), 0.05);
+    }
+
+    // with the categories swapped the estimate should be mirrored around 0.5
+    for (OnlineAuc auc : feed(1, 0)) {
+      Assert.assertEquals(1 - REFERENCE_AUC, auc.auc(), 0.05);
+    }
+  }
+
+  /**
+   * Builds one estimator per replacement policy and feeds each of them the same 10000 pairs
+   * of scores, taken from a generator seeded with 1.
+   *
+   * @param lowCategory  category that receives the N(0,1) scores
+   * @param highCategory category that receives the N(1,1) scores
+   * @return the estimators, in the order FAIR, FIFO, RANDOM
+   */
+  private static OnlineAuc[] feed(int lowCategory, int highCategory) {
+    OnlineAuc.ReplacementPolicy[] policies = {
+      OnlineAuc.ReplacementPolicy.FAIR,
+      OnlineAuc.ReplacementPolicy.FIFO,
+      OnlineAuc.ReplacementPolicy.RANDOM
+    };
+    OnlineAuc[] estimators = new OnlineAuc[policies.length];
+    for (int i = 0; i < policies.length; i++) {
+      estimators[i] = new OnlineAuc();
+      estimators[i].setPolicy(policies[i]);
+    }
+
+    Random gen = new Random(1);
+    for (int i = 0; i < 10000; i++) {
+      double low = gen.nextGaussian();
+      for (OnlineAuc auc : estimators) {
+        auc.addSample(lowCategory, low);
+      }
+      double high = gen.nextGaussian() + 1;
+      for (OnlineAuc auc : estimators) {
+        auc.addSample(highCategory, high);
+      }
+    }
+    return estimators;
+  }
+}