You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/08/16 18:56:47 UTC

svn commit: r986045 [2/2] - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/sgd/ core/src/test/java/org/apache/mahout/classifier/ core/src/test/java/org/apache/mahout/classifier/sgd/ ma...

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * Tests for {@link ContinuousValueEncoder}: numeric strings are parsed and the resulting value is
+ * written into one location per probe.
+ */
+public class ContinuousValueEncoderTest {
+  @Test
+  public void testAddToVector() {
+    FeatureVectorEncoder enc = new ContinuousValueEncoder("foo");
+
+    // with the default probe count, a negative value lands in exactly one location
+    Vector v1 = new DenseVector(20);
+    enc.addToVector("-123", v1);
+    assertEquals(-123, v1.minValue(), 0);
+    assertEquals(0, v1.maxValue(), 0);
+    assertEquals(123, v1.norm(1), 0);
+
+    // same for a positive value
+    v1 = new DenseVector(20);
+    enc.addToVector("123", v1);
+    assertEquals(123, v1.maxValue(), 0);
+    assertEquals(0, v1.minValue(), 0);
+    assertEquals(123, v1.norm(1), 0);
+
+    // with two probes the value is written twice (max stays 123, L1 norm doubles)
+    Vector v2 = new DenseVector(20);
+    enc.setProbes(2);
+    enc.addToVector("123", v2);
+    assertEquals(123, v2.maxValue(), 0);
+    assertEquals(2 * 123, v2.norm(1), 0);
+
+    // the single-probe location is one of the two-probe locations, so one copy remains
+    v1 = v2.minus(v1);
+    assertEquals(123, v1.maxValue(), 0);
+    assertEquals(123, v1.norm(1), 0);
+
+    // NOTE(review): probes are already 2 at this point; this call only restates the setting
+    Vector v3 = new DenseVector(20);
+    enc.setProbes(2);
+    enc.addToVector("100", v3);
+    v1 = v2.minus(v3);
+    assertEquals(23, v1.maxValue(), 0);
+    assertEquals(2 * 23, v1.norm(1), 0);
+
+    // adding accumulates into the same two locations (indexes 10 and 18 for this encoder name)
+    enc.addToVector("7", v1);
+    assertEquals(30, v1.maxValue(), 0);
+    assertEquals(2 * 30, v1.norm(1), 0);
+    assertEquals(30, v1.get(10), 0);
+    assertEquals(30, v1.get(18), 0);
+
+    // non-numeric input must be rejected
+    try {
+      enc.addToVector("foobar", v1);
+      fail("Should have noticed bad numeric format");
+    } catch (NumberFormatException e) {
+      assertEquals("For input string: \"foobar\"", e.getMessage());
+    }
+  }
+
+  @Test
+  public void testAsString() {
+    ContinuousValueEncoder enc = new ContinuousValueEncoder("foo");
+    assertEquals("foo:123", enc.asString("123"));
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link CsvRecordFactory} line encoding and for {@link Dictionary} ordering.
+ */
+public class CsvRecordFactoryTest {
+  @Test
+  public void testAddToVector() {
+    // y is the target variable; x1, x2 and x3 are predictors with type codes "n", "w" and "t".
+    // Columns z and q are not mentioned and so are ignored.
+    CsvRecordFactory csv = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t"));
+    csv.firstLine("z,x1,y,x2,x3,q");
+    csv.maxTargetValue(2);
+
+    // record 1: target "yes" encodes as 0
+    Vector v = new DenseVector(2000);
+    int t = csv.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", v);
+    assertEquals(0, t);
+    // should have 9 values set
+    assertEquals(9.0, v.norm(0), 0);
+    // all should be = 1 except for the 3.1 from x1
+    assertEquals(3.1, v.maxValue(), 0);
+    v.set(v.maxValueIndex(), 0);
+    assertEquals(8.0, v.norm(0), 0);
+    assertEquals(8.0, v.norm(1), 0);
+    assertEquals(1.0, v.maxValue(), 0);
+
+    // record 2: target "no" encodes as 1
+    v.assign(0);
+    t = csv.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", v);
+    assertEquals(1, t);
+
+    // should have 9 values set
+    assertEquals(9.0, v.norm(0), 0);
+    // the largest value is the 5.3 from x1; the other 8 entries are not all 1:
+    // they sum to 12 and peak at 2
+    assertEquals(5.3, v.maxValue(), 0);
+    v.set(v.maxValueIndex(), 0);
+    assertEquals(8.0, v.norm(0), 0);
+    assertEquals(12.0, v.norm(1), 0);
+    assertEquals(2, v.maxValue(), 0);
+
+    // record 3: an unseen target value also encodes as 1 here.
+    // NOTE(review): confirm this is the intended fallback given maxTargetValue(2).
+    v.assign(0);
+    t = csv.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", v);
+    assertEquals(1, t);
+
+    // should have 9 values set; predictors are identical to record 2
+    assertEquals(9.0, v.norm(0), 0);
+    // the largest value is the 5.3 from x1; the other 8 entries sum to 12 and peak at 2
+    assertEquals(5.3, v.maxValue(), 0);
+    v.set(v.maxValueIndex(), 0);
+    assertEquals(8.0, v.norm(0), 0);
+    assertEquals(12.0, v.norm(1), 0);
+    assertEquals(2, v.maxValue(), 0);
+  }
+
+  @Test
+  public void testDictionaryOrder() {
+    Dictionary dict = new Dictionary();
+
+    // values() must preserve insertion order, not sort
+    dict.intern("a");
+    dict.intern("d");
+    dict.intern("c");
+    dict.intern("b");
+    dict.intern("qrz");
+
+    Assert.assertEquals("[a, d, c, b, qrz]", dict.values().toString());
+
+    // rebuilding from the value list must keep the same order
+    Dictionary dict2 = Dictionary.fromList(dict.values());
+    Assert.assertEquals("[a, d, c, b, qrz]", dict2.values().toString());
+
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.charset.Charset;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+
+public class OnlineLogisticRegressionTest {
+  /** Feature matrix of the standard test data; assigned as a side effect of {@link #readStandardData()}. */
+  private Matrix input;
+
+  /**
+   * The CrossFoldLearner is probably the best learner to use for new applications.
+   *
+   * @throws IOException If test resources aren't readable.
+   */
+  @Test
+  public void crossValidation() throws IOException {
+    Vector target = readStandardData();
+
+    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
+            .lambda(1 * 1e-3)
+            .learningRate(50);
+
+    train(input, target, lr);
+
+    System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
+    test(input, target, lr);
+  }
+
+  /**
+   * Trains a cross-fold learner for 100 epochs over cancer.csv, using column 9 as the target, and
+   * logs the running AUC after every example.
+   * NOTE(review): this test asserts nothing; it only checks that training runs without error.
+   */
+  @Test
+  public void crossValidatedAuc() throws IOException {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    Matrix data = readCsv("cancer.csv");
+    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
+            .stepOffset(10)
+            .decayExponent(0.7)
+            .lambda(1 * 1e-3)
+            .learningRate(5);
+    int k = 0;
+    int[] ordering = permute(gen, data.numRows());
+    for (int epoch = 0; epoch < 100; epoch++) {
+      for (int row : ordering) {
+        lr.train(row, (int) data.get(row, 9), data.viewRow(row));
+        System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
+      }
+    }
+  }
+
+  /**
+   * Verifies that a classifier with known coefficients does the right thing.
+   */
+  @Test
+  public void testClassify() {
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
+    // set up some internal coefficients as if we had learned them
+    lr.setBeta(0, 0, -1);
+    lr.setBeta(1, 0, -2);
+
+    // zero vector gives no information.  All classes are equal.
+    Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
+    assertEquals(1 / 3.0, v.get(0), 1e-8);
+    assertEquals(1 / 3.0, v.get(1), 1e-8);
+
+    v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(1 / 3.0, v.get(0), 1e-8);
+    assertEquals(1 / 3.0, v.get(1), 1e-8);
+    assertEquals(1 / 3.0, v.get(2), 1e-8);
+
+    // weights for second vector component are still zero so all classifications are equally likely
+    v = lr.classify(new DenseVector(new double[]{0, 1}));
+    assertEquals(1 / 3.0, v.get(0), 1e-3);
+    assertEquals(1 / 3.0, v.get(1), 1e-3);
+
+    v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(1 / 3.0, v.get(0), 1e-3);
+    assertEquals(1 / 3.0, v.get(1), 1e-3);
+    assertEquals(1 / 3.0, v.get(2), 1e-3);
+
+    // but the weights on the first component are non-zero
+    v = lr.classify(new DenseVector(new double[]{1, 0}));
+    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1e-8);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1e-8);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1e-8);
+    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1e-8);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1e-8);
+
+    lr.setBeta(0, 1, 1);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1e-3);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1e-3);
+    assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1e-3);
+
+    lr.setBeta(1, 1, 3);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+    assertEquals(1.0, v.zSum(), 1e-8);
+    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1e-8);
+    assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1e-8);
+    assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1e-8);
+  }
+
+  @Test
+  public void testTrain() throws IOException {
+    Vector target = readStandardData();
+
+    // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
+    // larger than usual because the data are dense.  The learning rate doesn't matter too much
+    // for this example, but should generally be < 1
+    // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias --target y --categories 2 --predictors  V2 V3 V4 V5 V6 V7 --types n
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+            .lambda(1 * 1e-3)
+            .learningRate(50);
+
+    train(input, target, lr);
+    test(input, target, lr);
+  }
+
+  /**
+   * Loads sgd.csv into {@link #input} and builds the matching 0/1 target vector.
+   *
+   * @return A 60-element target vector: 0 for rows 0...29, 1 for rows 30...59.
+   * @throws IOException If the resource cannot be read.
+   */
+  private Vector readStandardData() throws IOException {
+    // 60 test samples.  First column is constant.  Second and third are normally distributed from
+    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59).  The first 30 rows have a
+    // target variable of 0, the last 30 a target of 1.  The remaining columns are random noise.
+    input = readCsv("sgd.csv");
+
+    // regenerate the target variable
+    Vector target = new DenseVector(60);
+    target.assign(0);
+    target.viewPart(30, 30).assign(1);
+    return target;
+  }
+
+  /**
+   * Trains {@code lr} with one pass over the rows of {@code input} in a seeded random order, then
+   * closes the learner.
+   *
+   * @param input  One example per row.
+   * @param target Target category of each row; its size is the number of rows used.
+   * @param lr     The learner to train.
+   */
+  private void train(Matrix input, Vector target, OnlineLearner lr) {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    // train on samples in random order (but only one pass)
+    for (int row : permute(gen, target.size())) {
+      lr.train((int) target.get(row), input.getRow(row));
+    }
+    lr.close();
+  }
+
+  /**
+   * Checks that the classifier's scores track the target closely, and that the convenience
+   * classification methods agree with {@code classify}.
+   *
+   * @param input  One example per row.
+   * @param target Expected 0/1 target of each row.
+   * @param lr     A trained classifier.
+   */
+  private void test(Matrix input, Vector target, AbstractVectorClassifier lr) {
+    int n = target.size();
+
+    // now test the accuracy
+    Matrix tmp = lr.classify(input);
+    Vector error = tmp.getColumn(0).minus(target);
+    // mean(abs(tmp - target))
+    double meanAbsoluteError = error.aggregate(Functions.plus, Functions.abs) / n;
+
+    // max(abs(tmp - target))
+    double maxAbsoluteError = error.aggregate(Functions.max, Functions.abs);
+
+    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+    assertEquals(0, meanAbsoluteError, 0.05);
+    assertEquals(0, maxAbsoluteError, 0.3);
+
+    // convenience methods should give the same results
+    Vector v = lr.classifyScalar(input);
+    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1e-5);
+    v = lr.classifyFull(input).getColumn(1);
+    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1e-4);
+  }
+
+  /**
+   * Permute the integers from 0 ... max-1 using an inside-out Fisher-Yates shuffle.
+   *
+   * @param gen The random number generator to use.
+   * @param max The number of integers to permute; 0 yields an empty array.
+   * @return An array of jumbled integer values
+   */
+  private int[] permute(Random gen, int max) {
+    int[] permutation = new int[max];
+    // slot 0 already holds 0 and needs no random draw; starting at 1 keeps the seeded sequence
+    for (int i = 1; i < max; i++) {
+      int n = gen.nextInt(i + 1);
+      // when n == i this copies the still-zero slot onto itself and then stores i
+      permutation[i] = permutation[n];
+      permutation[n] = i;
+    }
+    return permutation;
+  }
+
+
+  /**
+   * Reads a file containing CSV data.  This isn't implemented quite the way you might like for a
+   * real program, but does the job for reading test data.  Most notably, it will only read numbers,
+   * not quoted strings.  The first line supplies the column labels.
+   *
+   * @param resourceName Where to get the data (a classpath resource, read as UTF-8).
+   * @return A matrix of the results.
+   * @throws java.io.IOException If there is an error reading the data or the resource is empty.
+   */
+  private Matrix readCsv(String resourceName) throws IOException {
+    Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));
+
+    List<String> data;
+    InputStreamReader isr = new InputStreamReader(Resources.getResource(resourceName).openStream(),
+        Charset.forName("UTF-8"));
+    try {
+      data = CharStreams.readLines(isr);
+    } finally {
+      isr.close();
+    }
+    if (data.isEmpty()) {
+      throw new IOException("No header line in resource " + resourceName);
+    }
+    String first = data.get(0);
+    data = data.subList(1, data.size());
+
+    List<String> values = Lists.newArrayList(onCommas.split(first));
+    Matrix r = new DenseMatrix(data.size(), values.size());
+
+    int column = 0;
+    Map<String, Integer> labels = Maps.newHashMap();
+    for (String value : values) {
+      labels.put(value, column);
+      column++;
+    }
+    r.setColumnLabelBindings(labels);
+
+    int row = 0;
+    for (String line : data) {
+      column = 0;
+      values = Lists.newArrayList(onCommas.split(line));
+      for (String value : values) {
+        r.set(row, column, Double.parseDouble(value));
+        column++;
+      }
+      row++;
+    }
+
+    return r;
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link TextValueEncoder}: text is split into words, and each word is encoded by the
+ * configured word encoder.
+ */
+public class TextValueEncoderTest {
+  @Test
+  public void testAddToVector() {
+    TextValueEncoder encoder = new TextValueEncoder("text");
+
+    // default (unweighted) encoding: 6 distinct locations, each set to 1
+    Vector unweighted = new DenseVector(200);
+    encoder.addToVector("test1 and more", unweighted);
+    assertEquals(6.0, unweighted.norm(1), 0);
+    assertEquals(1.0, unweighted.maxValue(), 0);
+
+    // now plug in a word encoder that carries a static weight dictionary.
+    // NOTE(review): none of the words encoded below are dictionary keys, so this exercises the
+    // handling of out-of-dictionary words rather than the listed weights.
+    StaticWordValueEncoder wordEncoder = new StaticWordValueEncoder("text");
+    wordEncoder.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5));
+    encoder.setWordEncoder(wordEncoder);
+
+    Vector weighted = new DenseVector(200);
+    encoder.addToVector("test1 and more", weighted);
+
+    // encoding each word on its own must reproduce the text encoding exactly
+    Vector wordByWord = new DenseVector(200);
+    for (String word : new String[]{"test1", "and", "more"}) {
+      wordEncoder.addToVector(word, wordByWord);
+    }
+    assertEquals(0, wordByWord.minus(weighted).norm(1), 0);
+
+    // and the weighted locations coincide with the unweighted (all-ones) locations
+    assertEquals(wordByWord.zSum(), wordByWord.dot(unweighted), 0);
+  }
+
+  @Test
+  public void testAsString() {
+    TextValueEncoder encoder = new TextValueEncoder("text");
+    String expected = "[text:test1:1.0000, text:and:1.0000, text:more:1.0000]";
+    assertEquals(expected, encoder.asString("test1 and more"));
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.util.Iterator;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+public class WordLikeValueEncoderTest {
+  @Test
+  public void testAddToVector() {
+    FeatureVectorEncoder enc = new StaticWordValueEncoder("word");
+    Vector v = new DenseVector(200);
+    enc.addToVector("word1", v);
+    enc.addToVector("word2", v);
+    Iterator<Vector.Element> i = v.iterateNonZero();
+    Iterator<Integer> j = ImmutableList.of(7, 118, 119, 199).iterator();
+    while (i.hasNext()) {
+      Vector.Element element = i.next();
+      assertEquals(j.next().intValue(), element.index());
+      assertEquals(1, element.get(), 0);
+    }
+    assertFalse(j.hasNext());
+  }
+
+  @Test
+  public void testAsString() {
+    FeatureVectorEncoder enc = new StaticWordValueEncoder("word");
+    assertEquals("word:w1:1.0000", enc.asString("w1"));
+  }
+
+  @Test
+  public void testStaticWeights() {
+    StaticWordValueEncoder enc = new StaticWordValueEncoder("word");
+    enc.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5));
+    Vector v = new DenseVector(200);
+    enc.addToVector("word1", v);
+    enc.addToVector("word2", v);
+    enc.addToVector("word3", v);
+    Iterator<Vector.Element> i = v.iterateNonZero();
+    Iterator<Integer> j = ImmutableList.of(7, 101, 118, 119, 152, 199).iterator();
+    Iterator<Double> k = ImmutableList.of(3.0, 0.75, 1.5, 1.5, 0.75, 3.0).iterator();
+    while (i.hasNext()) {
+      Vector.Element element = i.next();
+      assertEquals(j.next().intValue(), element.index());
+      assertEquals(k.next(), element.get(), 0);
+    }
+    assertFalse(j.hasNext());
+  }
+
+  @Test
+  public void testDynamicWeights() {
+    FeatureVectorEncoder enc = new AdaptiveWordValueEncoder("word");
+    Vector v = new DenseVector(200);
+    enc.addToVector("word1", v);  // weight is log(2/1.5)
+    enc.addToVector("word2", v);  // weight is log(3.5 / 1.5)
+    enc.addToVector("word1", v);  // weight is log(4.5 / 2.5) (but overlays on first value)
+    enc.addToVector("word3", v);  // weight is log(6 / 1.5)
+    Iterator<Vector.Element> i = v.iterateNonZero();
+    Iterator<Integer> j = ImmutableList.of(7, 101, 118, 119, 152, 199).iterator();
+    Iterator<Double> k = ImmutableList.of(Math.log(2 / 1.5) + Math.log(4.5 / 2.5), Math.log(6 / 1.5), Math.log(3.5 / 1.5), Math.log(3.5 / 1.5), Math.log(6 / 1.5), Math.log(2 / 1.5) + Math.log(4.5 / 2.5)).iterator();
+    while (i.hasNext()) {
+      Vector.Element element = i.next();
+      assertEquals(j.next().intValue(), element.index());
+      assertEquals(k.next(), element.get(), 1e-6);
+    }
+    assertFalse(j.hasNext());
+  }
+}

Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,100 @@
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import java.util.Random;
+
+/**
+ * Computes a running estimate of AUC (see http://en.wikipedia.org/wiki/Receiver_operating_characteristic).
+ * <p/>
+ * Since AUC is normally a global property of labeled scores, it is almost always computed in a
+ * batch fashion.  The probabilistic definition (the probability that a random element of one set
+ * has a higher score than a random element of another set) gives us a way to estimate this
+ * on-line.
+ */
+public class OnlineAuc {
+  private Random random = new Random();
+
+  enum ReplacementPolicy {
+    FIFO, FAIR, RANDOM
+  }
+
+  public static final int HISTORY = 10;
+
+  private ReplacementPolicy policy = ReplacementPolicy.FAIR;
+
+  private Matrix scores;
+  private Vector averages;
+
+  private Vector samples;
+
+  public OnlineAuc() {
+    int numCategories = 2;
+    scores = new DenseMatrix(numCategories, HISTORY);
+    scores.assign(Double.NaN);
+    averages = new DenseVector(numCategories);
+    averages.assign(0.5);
+    samples = new DenseVector(numCategories);
+  }
+
+  public double addSample(int category, final double score) {
+    int n = (int) samples.get(category);
+    if (n < HISTORY) {
+      scores.set(category, n, score);
+    } else {
+      switch (policy) {
+        case FIFO:
+          scores.set(category, n % HISTORY, score);
+          break;
+        case FAIR:
+          int j = random.nextInt(n + 1);
+          if (j < HISTORY) {
+            scores.set(category, j, score);
+          }
+          break;
+        case RANDOM:
+          j = random.nextInt(HISTORY);
+          scores.set(category, j, score);
+          break;
+      }
+    }
+
+    samples.set(category, n + 1);
+
+    if (samples.minValue() >= 1) {
+      // compare to previous scores for other category
+      Vector row = scores.viewRow(1 - category);
+      double m = 0;
+      int count = 0;
+      for (Vector.Element element : row) {
+        double v = element.get();
+        if (!Double.isNaN(v)) {
+          count++;
+          double z = 0.5;
+          if (score > v) {
+            z = 1;
+          } else if (score < v) {
+            z = 0;
+          }
+          m += (z - m) / count;
+        } else {
+          break;
+        }
+      }
+      averages.set(category, averages.get(category) + (m - averages.get(category)) / samples.get(category));
+    }
+    return auc();
+  }
+
+  public double auc() {
+    // return an unweighted average of all averages.
+    return 0.5 - averages.get(0) / 2 + averages.get(1) / 2;
+  }
+
+  public void setPolicy(ReplacementPolicy policy) {
+    this.policy = policy;
+  }
+}

Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java?rev=986045&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java Mon Aug 16 16:56:46 2010
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Random;
+
+public class OnlineAucTest {
+  /** Reference value computed using R: mean(rnorm(1000000) < rnorm(1000000,1)). */
+  private static final double REFERENCE_AUC = 0.76;
+
+  @Test
+  public void binaryCase() {
+    // category 1 scores drawn from N(1, 1), category 0 from N(0, 1)
+    checkAllPolicies(0, 1, REFERENCE_AUC);
+    // swapping the labels mirrors the estimate around 0.5
+    checkAllPolicies(1, 0, 1 - REFERENCE_AUC);
+  }
+
+  /**
+   * Feeds the same interleaved stream of scores to one estimator per replacement policy and checks
+   * that every estimate lands within 0.05 of the expected AUC.
+   *
+   * @param lowCategory  Category whose scores are drawn from N(0, 1).
+   * @param highCategory Category whose scores are drawn from N(1, 1).
+   * @param expected     Expected AUC.
+   */
+  private static void checkAllPolicies(int lowCategory, int highCategory, double expected) {
+    OnlineAuc.ReplacementPolicy[] policies = OnlineAuc.ReplacementPolicy.values();
+    OnlineAuc[] estimators = new OnlineAuc[policies.length];
+    for (int p = 0; p < policies.length; p++) {
+      estimators[p] = new OnlineAuc();
+      estimators[p].setPolicy(policies[p]);
+    }
+
+    Random gen = new Random(1);
+    for (int i = 0; i < 10000; i++) {
+      double low = gen.nextGaussian();
+      double high = gen.nextGaussian() + 1;
+      for (OnlineAuc estimator : estimators) {
+        estimator.addSample(lowCategory, low);
+        estimator.addSample(highCategory, high);
+      }
+    }
+
+    for (int p = 0; p < policies.length; p++) {
+      Assert.assertEquals(policies[p].toString(), expected, estimators[p].auc(), 0.05);
+    }
+  }
+}