You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:07:37 UTC

[06/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy into mahout-hdfs and mahout-mr, closes apache/mahout#86

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
new file mode 100644
index 0000000..3104cb1
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+public class HMMEvaluatorTest extends HMMTestBase {
+
+  /**
+   * Test to make sure the computed model likelihood ist valid. Included tests
+   * are: a) forwad == backward likelihood b) model likelihood for test seqeunce
+   * is the expected one from R reference
+   */
+  @Test
+  public void testModelLikelihood() {
+    // compute alpha and beta values
+    Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false);
+    Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false);
+    // now test whether forward == backward likelihood
+    double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
+    double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
+        beta, false);
+    assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
+    // also make sure that the likelihood matches the expected one
+    assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
+  }
+
+  /**
+   * Test to make sure the computed model likelihood ist valid. Included tests
+   * are: a) forwad == backward likelihood b) model likelihood for test seqeunce
+   * is the expected one from R reference
+   */
+  @Test
+  public void testScaledModelLikelihood() {
+    // compute alpha and beta values
+    Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true);
+    Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true);
+    // now test whether forward == backward likelihood
+    double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
+    double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
+        beta, true);
+    assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
+    // also make sure that the likelihood matches the expected one
+    assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
new file mode 100644
index 0000000..3260f51
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.junit.Test;
+
+public class HMMModelTest extends HMMTestBase {
+
+  @Test
+  public void testRandomModelGeneration() {
+    // make sure we generate a valid random model
+    HmmModel model = new HmmModel(10, 20);
+    // check whether the model is valid
+    HmmUtils.validate(model);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
new file mode 100644
index 0000000..90f1cd8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+
+public class HMMTestBase extends MahoutTestCase {
+
+  private HmmModel model;
+  private final int[] sequence = {1, 0, 2, 2, 0, 0, 1};
+
+  /**
+   * We initialize a new HMM model using the following parameters # hidden
+   * states: 4 ("H0","H1","H2","H3") # output states: 3 ("O0","O1","O2") #
+   * transition matrix to: H0 H1 H2 H3 from: H0 0.5 0.1 0.1 0.3 H1 0.4 0.4 0.1
+   * 0.1 H2 0.1 0.0 0.8 0.1 H3 0.1 0.1 0.1 0.7 # output matrix to: O0 O1 O2
+   * from: H0 0.8 0.1 0.1 H1 0.6 0.1 0.3 H2 0.1 0.8 0.1 H3 0.0 0.1 0.9 # initial
+   * probabilities H0 0.2
+   * <p/>
+   * H1 0.1 H2 0.4 H3 0.3
+   * <p/>
+   * We also intialize an observation sequence: "O1" "O0" "O2" "O2" "O0" "O0"
+   * "O1"
+   */
+
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    // intialize the hidden/output state names
+    String[] hiddenNames = {"H0", "H1", "H2", "H3"};
+    String[] outputNames = {"O0", "O1", "O2"};
+    // initialize the transition matrix
+    double[][] transitionP = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
+        {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
+    // initialize the emission matrix
+    double[][] emissionP = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3},
+        {0.1, 0.8, 0.1}, {0.0, 0.1, 0.9}};
+    // initialize the initial probability vector
+    double[] initialP = {0.2, 0.1, 0.4, 0.3};
+    // now generate the model
+    model = new HmmModel(new DenseMatrix(transitionP), new DenseMatrix(
+        emissionP), new DenseVector(initialP));
+    model.registerHiddenStateNames(hiddenNames);
+    model.registerOutputStateNames(outputNames);
+    // make sure the model is valid :)
+    HmmUtils.validate(model);
+  }
+
+  protected HmmModel getModel() {
+    return model;
+  }
+
+  protected int[] getSequence() {
+    return sequence;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
new file mode 100644
index 0000000..b8f3186
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
@@ -0,0 +1,163 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMTrainerTest extends HMMTestBase {
+
+  @Test
+  public void testViterbiTraining() {
+    // initialize the expected model parameters (from R)
+    // expected transition matrix
+    double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
+        {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+        {0.5, 0.1, 0.1, 0.3}};
+    // initialize the emission matrix
+    double[][] emissionE = {{0.882353, 0.058824, 0.058824},
+        {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+        {0.111111, 0.111111, 0.777778}};
+
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, false);
+
+    // now check whether the model matches our expectations
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+        assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], EPSILON);
+      }
+
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+        assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], EPSILON);
+      }
+    }
+
+  }
+
+  @Test
+  public void testScaledViterbiTraining() {
+    // initialize the expected model parameters (from R)
+    // expected transition matrix
+    double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
+        {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+        {0.5, 0.1, 0.1, 0.3}};
+    // initialize the emission matrix
+    double[][] emissionE = {{0.882353, 0.058824, 0.058824},
+        {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+        {0.111111, 0.111111, 0.777778}};
+
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10,
+        true);
+
+    // now check whether the model matches our expectations
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+        assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
+            EPSILON);
+      }
+
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+        assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
+            EPSILON);
+      }
+    }
+
+  }
+
+  @Test
+  public void testBaumWelchTraining() {
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    // expected values from Matlab HMM package / R HMM package
+    double[] initialExpected = {0, 0, 1.0, 0};
+    double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+        {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+        {0.0024, 0.6657, 0, 0.3319}};
+    double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+        {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+    HmmModel trained = HmmTrainer.trainBaumWelch(getModel(), observed, 0.1, 10,
+        false);
+
+    Vector initialProbabilities = trained.getInitialProbabilities();
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      assertEquals(initialProbabilities.get(i), initialExpected[i],
+          0.0001);
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+        assertEquals(transitionMatrix.getQuick(i, j),
+            transitionExpected[i][j], 0.0001);
+      }
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+        assertEquals(emissionMatrix.getQuick(i, j),
+            emissionExpected[i][j], 0.0001);
+      }
+    }
+  }
+
+  @Test
+  public void testScaledBaumWelchTraining() {
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    // expected values from Matlab HMM package / R HMM package
+    double[] initialExpected = {0, 0, 1.0, 0};
+    double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+        {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+        {0.0024, 0.6657, 0, 0.3319}};
+    double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+        {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+    HmmModel trained = HmmTrainer
+        .trainBaumWelch(getModel(), observed, 0.1, 10, true);
+
+    Vector initialProbabilities = trained.getInitialProbabilities();
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      assertEquals(initialProbabilities.get(i), initialExpected[i],
+          0.0001);
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+        assertEquals(transitionMatrix.getQuick(i, j),
+            transitionExpected[i][j], 0.0001);
+      }
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+        assertEquals(emissionMatrix.getQuick(i, j),
+            emissionExpected[i][j], 0.0001);
+      }
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
new file mode 100644
index 0000000..6c34718
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMUtilsTest extends HMMTestBase {
+
+  private Matrix legal22;
+  private Matrix legal23;
+  private Matrix legal33;
+  private Vector legal2;
+  private Matrix illegal22;
+
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    legal22 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
+    legal23 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
+        {0.3, 0.3, 0.4}});
+    legal33 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
+        {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
+    legal2 = new DenseVector(new double[]{0.4, 0.6});
+    illegal22 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
+  }
+
+  @Test
+  public void testValidatorLegal() {
+    HmmUtils.validate(new HmmModel(legal22, legal23, legal2));
+  }
+
+  @Test
+  public void testValidatorDimensionError() {
+    try {
+      HmmUtils.validate(new HmmModel(legal33, legal23, legal2));
+    } catch (IllegalArgumentException e) {
+      // success
+      return;
+    }
+    fail();
+  }
+
+  @Test
+  public void testValidatorIllegelMatrixError() {
+    try {
+      HmmUtils.validate(new HmmModel(illegal22, legal23, legal2));
+    } catch (IllegalArgumentException e) {
+      // success
+      return;
+    }
+    fail();
+  }
+
+  @Test
+  public void testEncodeStateSequence() {
+    String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
+    String[] outputSequence = {"O1", "O2", "O4", "O0"};
+    // test encoding the hidden Sequence
+    int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
+        .asList(hiddenSequence), false, -1);
+    int[] outputSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
+        .asList(outputSequence), true, -1);
+    // expected state sequences
+    int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
+    int[] outputSequenceExp = {1, 2, -1, 0};
+    // compare
+    for (int i = 0; i < hiddenSequenceEnc.length; ++i) {
+      assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
+    }
+    for (int i = 0; i < outputSequenceEnc.length; ++i) {
+      assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+    }
+  }
+
+  @Test
+  public void testDecodeStateSequence() {
+    int[] hiddenSequence = {1, 2, 0, 3, 10};
+    int[] outputSequence = {1, 2, 10, 0};
+    // test encoding the hidden Sequence
+    List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
+        getModel(), hiddenSequence, false, "unknown");
+    List<String> outputSequenceDec = HmmUtils.decodeStateSequence(
+        getModel(), outputSequence, true, "unknown");
+    // expected state sequences
+    String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
+    String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
+    // compare
+    for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+      assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
+    }
+    for (int i = 0; i < outputSequenceExp.length; ++i) {
+      assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+    }
+  }
+
+  @Test
+  public void testNormalizeModel() {
+    DenseVector ip = new DenseVector(new double[]{10, 20});
+    DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
+    DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
+    HmmModel model = new HmmModel(tr, em, ip);
+    HmmUtils.normalizeModel(model);
+    // the model should be valid now
+    HmmUtils.validate(model);
+  }
+
+  @Test
+  public void testTruncateModel() {
+    DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
+    DenseMatrix tr = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    DenseMatrix em = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    HmmModel model = new HmmModel(tr, em, ip);
+    // now truncate the model
+    HmmModel sparseModel = HmmUtils.truncateModel(model, 0.01);
+    // first make sure this is a valid model
+    HmmUtils.validate(sparseModel);
+    // now check whether the values are as expected
+    Vector sparse_ip = sparseModel.getInitialProbabilities();
+    Matrix sparse_tr = sparseModel.getTransitionMatrix();
+    Matrix sparse_em = sparseModel.getEmissionMatrix();
+    for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
+      assertEquals(i == 2 ? 1.0 : 0.0, sparse_ip.getQuick(i), EPSILON);
+      for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
+        if (i == j) {
+          assertEquals(1.0, sparse_tr.getQuick(i, j), EPSILON);
+          assertEquals(1.0, sparse_em.getQuick(i, j), EPSILON);
+        } else {
+          assertEquals(0.0, sparse_tr.getQuick(i, j), EPSILON);
+          assertEquals(0.0, sparse_em.getQuick(i, j), EPSILON);
+        }
+      }
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
new file mode 100644
index 0000000..7ea8cb2
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.jet.random.Exponential;
+import org.junit.Test;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+
+import java.util.Random;
+
+public final class AdaptiveLogisticRegressionTest extends MahoutTestCase {
+
+  @ThreadLeakLingering(linger=1000)
+  @Test
+  public void testTrain() {
+
+    Random gen = RandomUtils.getRandom();
+    Exponential exp = new Exponential(0.5, gen);
+    Vector beta = new DenseVector(200);
+    for (Vector.Element element : beta.all()) {
+      int sign = 1;
+      if (gen.nextDouble() < 0.5) {
+        sign = -1;
+      }
+      element.set(sign * exp.nextDouble());
+    }
+
+    AdaptiveLogisticRegression.Wrapper cl = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+    cl.update(new double[]{1.0e-5, 1});
+
+    for (int i = 0; i < 10000; i++) {
+      AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+      cl.train(r);
+      if (i % 1000 == 0) {
+        System.out.printf("%10d %10.3f\n", i, cl.getLearner().auc());
+      }
+    }
+    assertEquals(1, cl.getLearner().auc(), 0.1);
+
+    AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(2, 200, new L1());
+    adaptiveLogisticRegression.setInterval(1000);
+
+    for (int i = 0; i < 20000; i++) {
+      AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+      adaptiveLogisticRegression.train(r.getKey(), r.getActual(), r.getInstance());
+      if (i % 1000 == 0 && adaptiveLogisticRegression.getBest() != null) {
+        System.out.printf("%10d %10.4f %10.8f %.3f\n",
+                          i, adaptiveLogisticRegression.auc(),
+                          Math.log10(adaptiveLogisticRegression.getBest().getMappedParams()[0]), adaptiveLogisticRegression.getBest().getMappedParams()[1]);
+      }
+    }
+    assertEquals(1, adaptiveLogisticRegression.auc(), 0.1);
+    adaptiveLogisticRegression.close();
+  }
+
+  private static AdaptiveLogisticRegression.TrainingExample getExample(int i, Random gen, Vector beta) {
+    Vector data = new DenseVector(200);
+
+    for (Vector.Element element : data.all()) {
+      element.set(gen.nextDouble() < 0.3 ? 1 : 0);
+    }
+
+    double p = 1 / (1 + Math.exp(1.5 - data.dot(beta)));
+    int target = 0;
+    if (gen.nextDouble() < p) {
+      target = 1;
+    }
+    return new AdaptiveLogisticRegression.TrainingExample(i, null, target, data);
+  }
+
+  @Test
+  public void copyLearnsAsExpected() {
+    Random gen = RandomUtils.getRandom();
+    Exponential exp = new Exponential(0.5, gen);
+    Vector beta = new DenseVector(200);
+    for (Vector.Element element : beta.all()) {
+        int sign = 1;
+        if (gen.nextDouble() < 0.5) {
+          sign = -1;
+        }
+      element.set(sign * exp.nextDouble());
+    }
+
+    // train one copy of a wrapped learner
+    AdaptiveLogisticRegression.Wrapper w = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+    for (int i = 0; i < 3000; i++) {
+      AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+      w.train(r);
+      if (i % 1000 == 0) {
+        System.out.printf("%10d %.3f\n", i, w.getLearner().auc());
+      }
+    }
+    System.out.printf("%10d %.3f\n", 3000, w.getLearner().auc());
+    double auc1 = w.getLearner().auc();
+
+    // then switch to a copy of that learner ... progress should continue
+    AdaptiveLogisticRegression.Wrapper w2 = w.copy();
+
+    for (int i = 0; i < 5000; i++) {
+      if (i % 1000 == 0) {
+        if (i == 0) {
+          assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001);
+        }
+        if (i == 1000) {
+          double auc2 = w2.getLearner().auc();
+          assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1);
+          assertTrue("AUC should improve quickly on copy", auc1 < auc2);
+        }
+        System.out.printf("%10d %.3f\n", i, w2.getLearner().auc());
+      }
+      AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+      w2.train(r);
+    }
+    assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5);
+
+    // this improvement is really quite lenient
+    assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05);
+
+    // make sure that the copy didn't lose anything
+    assertEquals(auc1, w.getLearner().auc(), 0);
+  }
+
+  @Test
+  public void stepSize() {
+    assertEquals(500, AdaptiveLogisticRegression.stepSize(15000, 2));
+    assertEquals(2000, AdaptiveLogisticRegression.stepSize(15000, 2.6));
+    assertEquals(5000, AdaptiveLogisticRegression.stepSize(24000, 2.6));
+    assertEquals(10000, AdaptiveLogisticRegression.stepSize(15000, 3));
+  }
+
+  @Test
+  @ThreadLeakLingering(linger = 1000)
+  public void constantStep() {
+    AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+    lr.setInterval(5000);
+    assertEquals(20000, lr.nextStep(15000));
+    assertEquals(20000, lr.nextStep(15001));
+    assertEquals(20000, lr.nextStep(16500));
+    assertEquals(20000, lr.nextStep(19999));
+    lr.close(); 
+  }
+    
+
+  @Test
+  @ThreadLeakLingering(linger = 1000)
+  public void growingStep() {
+    AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+    lr.setInterval(2000, 10000);
+
+    // start with minimum step size
+    for (int i = 2000; i < 20000; i+=2000) {
+      assertEquals(i + 2000, lr.nextStep(i));
+    }
+
+    // then level up a bit
+    for (int i = 20000; i < 50000; i += 5000) {
+      assertEquals(i + 5000, lr.nextStep(i));
+    }
+
+    // and more, but we top out with this step size
+    for (int i = 50000; i < 500000; i += 10000) {
+      assertEquals(i + 10000, lr.nextStep(i));
+    }
+    lr.close();
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
new file mode 100644
index 0000000..6ee0ddf
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Test;
+
+public final class CsvRecordFactoryTest extends MahoutTestCase {
+
+  @Test
+  public void testAddToVector() {
+    RecordFactory csv = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t"));
+    csv.firstLine("z,x1,y,x2,x3,q");
+    csv.maxTargetValue(2);
+
+    Vector v = new DenseVector(2000);
+    int t = csv.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", v);
+    assertEquals(0, t);
+    // should have 9 values set
+    assertEquals(9.0, v.norm(0), 0);
+    // all should be = 1 except for the 3.1
+    assertEquals(3.1, v.maxValue(), 0);
+    v.set(v.maxValueIndex(), 0);
+    assertEquals(8.0, v.norm(0), 0);
+    assertEquals(8.0, v.norm(1), 0);
+    assertEquals(1.0, v.maxValue(), 0);
+
+    v.assign(0);
+    t = csv.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", v);
+    assertEquals(1, t);
+
+    // should have 9 values set
+    assertEquals(9.0, v.norm(0), 0);
+    // all should be = 1 except for the 3.1
+    assertEquals(5.3, v.maxValue(), 0);
+    v.set(v.maxValueIndex(), 0);
+    assertEquals(8.0, v.norm(0), 0);
+    assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+    assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+
+    v.assign(0);
+    t = csv.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", v);
+    assertEquals(1, t);
+
+    // should have 9 values set
+    assertEquals(9.0, v.norm(0), 0);
+    // all should be = 1 except for the 3.1
+    assertEquals(5.3, v.maxValue(), 0);
+    v.set(v.maxValueIndex(), 0);
+    assertEquals(8.0, v.norm(0), 0);
+    assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+    assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+  }
+
+  @Test
+  public void testDictionaryOrder() {
+    Dictionary dict = new Dictionary();
+
+    dict.intern("a");
+    dict.intern("d");
+    dict.intern("c");
+    dict.intern("b");
+    dict.intern("qrz");
+
+    assertEquals("[a, d, c, b, qrz]", dict.values().toString());
+
+    Dictionary dict2 = Dictionary.fromList(dict.values());
+    assertEquals("[a, d, c, b, qrz]", dict2.values().toString());
+
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
new file mode 100644
index 0000000..06a876e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Random;
+
+public final class GradientMachineTest extends OnlineBaseTest {
+
+  @Test
+  public void testGradientmachine() throws IOException {
+    Vector target = readStandardData();
+    GradientMachine grad = new GradientMachine(8,4,2).learningRate(0.1).regularization(0.01);
+    Random gen = RandomUtils.getRandom();
+    grad.initWeights(gen);
+    train(getInput(), target, grad);
+    // TODO not sure why the RNG change made this fail. Value is 0.5-1.0 no matter what seed is chosen?
+    test(getInput(), target, grad, 1.0, 1);
+    //test(getInput(), target, grad, 0.05, 1);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
new file mode 100644
index 0000000..2373b9d
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.junit.Test;
+
+public final class ModelSerializerTest extends MahoutTestCase {
+
+  private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException {
+    ByteArrayOutputStream buf = new ByteArrayOutputStream(1000);
+    DataOutputStream dos = new DataOutputStream(buf);
+    try {
+      PolymorphicWritable.write(dos, m);
+    } finally {
+      Closeables.close(dos, false);
+    }
+    return PolymorphicWritable.read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), clazz);
+  }
+
+  @Test
+  public void onlineAucRoundtrip() throws IOException {
+    RandomUtils.useTestSeed();
+    OnlineAuc auc1 = new GlobalOnlineAuc();
+    Random gen = RandomUtils.getRandom();
+    for (int i = 0; i < 10000; i++) {
+      auc1.addSample(0, gen.nextGaussian());
+      auc1.addSample(1, gen.nextGaussian() + 1);
+    }
+    assertEquals(0.76, auc1.auc(), 0.01);
+
+    OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class);
+
+    assertEquals(auc1.auc(), auc3.auc(), 0);
+
+    for (int i = 0; i < 1000; i++) {
+      auc1.addSample(0, gen.nextGaussian());
+      auc1.addSample(1, gen.nextGaussian() + 1);
+
+      auc3.addSample(0, gen.nextGaussian());
+      auc3.addSample(1, gen.nextGaussian() + 1);
+    }
+
+    assertEquals(auc1.auc(), auc3.auc(), 0.01);
+  }
+
+  @Test
+  public void onlineLogisticRegressionRoundTrip() throws IOException {
+    OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
+    train(olr, 100);
+    OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
+    assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+
+    train(olr, 100);
+    train(olr3, 100);
+
+    assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+    olr.close();
+    olr3.close();
+  }
+
+  @Test
+  public void crossFoldLearnerRoundTrip() throws IOException {
+    CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
+    train(learner, 100);
+    CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class);
+    double auc1 = learner.auc();
+    assertTrue(auc1 > 0.85);
+    assertEquals(auc1, learner.auc(), 1.0e-6);
+    assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+    train(learner, 100);
+    train(learner, 100);
+    train(olr3, 100);
+
+    assertEquals(learner.auc(), learner.auc(), 0.02);
+    assertEquals(learner.auc(), olr3.auc(), 0.02);
+    double auc2 = learner.auc();
+    assertTrue(auc2 > auc1);
+    learner.close();
+    olr3.close();
+  }
+
+  @ThreadLeakLingering(linger = 1000)
+  @Test
+  public void adaptiveLogisticRegressionRoundTrip() throws IOException {
+    AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
+    learner.setInterval(200);
+    train(learner, 400);
+    AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class);
+    double auc1 = learner.auc();
+    assertTrue(auc1 > 0.85);
+    assertEquals(auc1, learner.auc(), 1.0e-6);
+    assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+    train(learner, 1000);
+    train(learner, 1000);
+    train(olr3, 1000);
+
+    assertEquals(learner.auc(), learner.auc(), 0.005);
+    assertEquals(learner.auc(), olr3.auc(), 0.005);
+    double auc2 = learner.auc();
+    assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1);
+    learner.close();
+    olr3.close();
+  }
+
+  private static void train(OnlineLearner olr, int n) {
+    Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5});
+    Random gen = RandomUtils.getRandom();
+    for (int i = 0; i < n; i++) {
+      Vector x = randomVector(gen, 5);
+
+      int target = gen.nextDouble() < beta.dot(x) ? 1 : 0;
+      olr.train(target, x);
+    }
+  }
+
+  private static Vector randomVector(final Random gen, int n) {
+    Vector x = new DenseVector(n);
+    x.assign(new DoubleFunction() {
+      @Override
+      public double apply(double v) {
+        return gen.nextGaussian();
+      }
+    });
+    return x;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
new file mode 100644
index 0000000..e0a252c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+public abstract class OnlineBaseTest extends MahoutTestCase {
+
+  private Matrix input;
+
+  Matrix getInput() {
+    return input;
+  }
+
+  Vector readStandardData() throws IOException {
+    // 60 test samples.  First column is constant.  Second and third are normally distributed from
+    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59).  The first 30 rows have a
+    // target variable of 0, the last 30 a target of 1.  The remaining columns are are random noise.
+    input = readCsv("sgd.csv");
+
+    // regenerate the target variable
+    Vector target = new DenseVector(60);
+    target.assign(0);
+    target.viewPart(30, 30).assign(1);
+    return target;
+  }
+
+  static void train(Matrix input, Vector target, OnlineLearner lr) {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    // train on samples in random order (but only one pass)
+    for (int row : permute(gen, 60)) {
+      lr.train((int) target.get(row), input.viewRow(row));
+    }
+    lr.close();
+  }
+
+  static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
+                   double expected_mean_error, double expected_absolute_error) {
+    // now test the accuracy
+    Matrix tmp = lr.classify(input);
+    // mean(abs(tmp - target))
+    double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;
+
+    // max(abs(tmp - target)
+    double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
+
+    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+    assertEquals(0, meanAbsoluteError , expected_mean_error);
+    assertEquals(0, maxAbsoluteError, expected_absolute_error);
+
+    // convenience methods should give the same results
+    Vector v = lr.classifyScalar(input);
+    assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5);
+    v = lr.classifyFull(input).viewColumn(1);
+    assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4);
+  }
+
+  /**
+   * Permute the integers from 0 ... max-1
+   *
+   * @param gen The random number generator to use.
+   * @param max The number of integers to permute
+   * @return An array of jumbled integer values
+   */
+  static int[] permute(Random gen, int max) {
+    int[] permutation = new int[max];
+    permutation[0] = 0;
+    for (int i = 1; i < max; i++) {
+      int n = gen.nextInt(i + 1);
+      if (n == i) {
+        permutation[i] = i;
+      } else {
+        permutation[i] = permutation[n];
+        permutation[n] = i;
+      }
+    }
+    return permutation;
+  }
+
+
+  /**
+   * Reads a file containing CSV data.  This isn't implemented quite the way you might like for a
+   * real program, but does the job for reading test data.  Most notably, it will only read numbers,
+   * not quoted strings.
+   *
+   * @param resourceName Where to get the data.
+   * @return A matrix of the results.
+   * @throws IOException If there is an error reading the data
+   */
+  static Matrix readCsv(String resourceName) throws IOException {
+    Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));
+
+    Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
+    List<String> data = CharStreams.readLines(isr);
+    String first = data.get(0);
+    data = data.subList(1, data.size());
+
+    List<String> values = Lists.newArrayList(onCommas.split(first));
+    Matrix r = new DenseMatrix(data.size(), values.size());
+
+    int column = 0;
+    Map<String, Integer> labels = Maps.newHashMap();
+    for (String value : values) {
+      labels.put(value, column);
+      column++;
+    }
+    r.setColumnLabelBindings(labels);
+
+    int row = 0;
+    for (String line : data) {
+      column = 0;
+      values = Lists.newArrayList(onCommas.split(line));
+      for (String value : values) {
+        r.set(row, column, Double.parseDouble(value));
+        column++;
+      }
+      row++;
+    }
+
+    return r;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
new file mode 100644
index 0000000..44b7525
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Assert;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+
+public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
+
+  private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
+
+  /**
+   * The CrossFoldLearner is probably the best learner to use for new applications.
+   *
+   * @throws IOException If test resources aren't readable.
+   */
+  @Test
+  public void crossValidation() throws IOException {
+    Vector target = readStandardData();
+
+    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
+      .lambda(1 * 1.0e-3)
+      .learningRate(50);
+
+
+    train(getInput(), target, lr);
+
+    System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
+    test(getInput(), target, lr, 0.05, 0.3);
+
+  }
+
+  @Test
+  public void crossValidatedAuc() throws IOException {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    Matrix data = readCsv("cancer.csv");
+    CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
+      .stepOffset(10)
+      .decayExponent(0.7)
+      .lambda(1 * 1.0e-3)
+      .learningRate(5);
+    int k = 0;
+    int[] ordering = permute(gen, data.numRows());
+    for (int epoch = 0; epoch < 100; epoch++) {
+      for (int row : ordering) {
+        lr.train(row, (int) data.get(row, 9), data.viewRow(row));
+        System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
+      }
+      assertEquals(1, lr.auc(), 0.2);
+    }
+    assertEquals(1, lr.auc(), 0.1);
+  }
+
+  /**
+   * Verifies that a classifier with known coefficients does the right thing.
+   */
+  @Test
+  public void testClassify() {
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
+    // set up some internal coefficients as if we had learned them
+    lr.setBeta(0, 0, -1);
+    lr.setBeta(1, 0, -2);
+
+    // zero vector gives no information.  All classes are equal.
+    Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
+    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+    assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+
+    v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
+    assertEquals(1.0, v.zSum(), 1.0e-8);
+    assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+    assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+    assertEquals(1 / 3.0, v.get(2), 1.0e-8);
+
+    // weights for second vector component are still zero so all classifications are equally likely
+    v = lr.classify(new DenseVector(new double[]{0, 1}));
+    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+    assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+
+    v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
+    assertEquals(1.0, v.zSum(), 1.0e-8);
+    assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+    assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+    assertEquals(1 / 3.0, v.get(2), 1.0e-3);
+
+    // but the weights on the first component are non-zero
+    v = lr.classify(new DenseVector(new double[]{1, 0}));
+    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
+    assertEquals(1.0, v.zSum(), 1.0e-8);
+    assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+    assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);
+
+    lr.setBeta(0, 1, 1);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+    assertEquals(1.0, v.zSum(), 1.0e-8);
+    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
+    assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
+    assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);
+
+    lr.setBeta(1, 1, 3);
+
+    v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+    assertEquals(1.0, v.zSum(), 1.0e-8);
+    assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
+    assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
+    assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
+  }
+
+  @Test
+  public void iris() throws IOException {
+    // this test trains a 3-way classifier on the famous Iris dataset.
+    // a similar exercise can be accomplished in R using this code:
+    //    library(nnet)
+    //    correct = rep(0,100)
+    //    for (j in 1:100) {
+    //      i = order(runif(150))
+    //      train = iris[i[1:100],]
+    //      test = iris[i[101:150],]
+    //      m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
+    //      correct[j] = mean(predict(m, newdata=test) == test$Species)
+    //    }
+    //    hist(correct)
+    //
+    // Note that depending on the training/test split, performance can be better or worse.
+    // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
+    // of 100%
+    //
+    // This test uses a deterministic split that is neither outstandingly good nor bad
+
+
+    RandomUtils.useTestSeed();
+    Splitter onComma = Splitter.on(",");
+
+    // read the data
+    List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);
+
+    // holds features
+    List<Vector> data = Lists.newArrayList();
+
+    // holds target variable
+    List<Integer> target = Lists.newArrayList();
+
+    // for decoding target values
+    Dictionary dict = new Dictionary();
+
+    // for permuting data later
+    List<Integer> order = Lists.newArrayList();
+
+    for (String line : raw.subList(1, raw.size())) {
+      // order gets a list of indexes
+      order.add(order.size());
+
+      // parse the predictor variables
+      Vector v = new DenseVector(5);
+      v.set(0, 1);
+      int i = 1;
+      Iterable<String> values = onComma.split(line);
+      for (String value : Iterables.limit(values, 4)) {
+        v.set(i++, Double.parseDouble(value));
+      }
+      data.add(v);
+
+      // and the target
+      target.add(dict.intern(Iterables.get(values, 4)));
+    }
+
+    // randomize the order ... original data has each species all together
+    // note that this randomization is deterministic
+    Random random = RandomUtils.getRandom();
+    Collections.shuffle(order, random);
+
+    // select training and test data
+    List<Integer> train = order.subList(0, 100);
+    List<Integer> test = order.subList(100, 150);
+    logger.warn("Training set = {}", train);
+    logger.warn("Test set = {}", test);
+
+    // now train many times and collect information on accuracy each time
+    int[] correct = new int[test.size() + 1];
+    for (int run = 0; run < 200; run++) {
+      OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
+      // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
+      for (int pass = 0; pass < 30; pass++) {
+        Collections.shuffle(train, random);
+        for (int k : train) {
+          lr.train(target.get(k), data.get(k));
+        }
+      }
+
+      // check the accuracy on held out data
+      int x = 0;
+      int[] count = new int[3];
+      for (Integer k : test) {
+        int r = lr.classifyFull(data.get(k)).maxValueIndex();
+        count[r]++;
+        x += r == target.get(k) ? 1 : 0;
+      }
+      correct[x]++;
+    }
+
+    // verify we never saw worse than 95% correct,
+    for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
+      assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]);
+    }
+    // nor perfect
+    assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size()]);
+  }
+
+  @Test
+  public void testTrain() throws Exception {
+    Vector target = readStandardData();
+
+
+    // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
+    // larger than usual because the data are dense.  The learning rate doesn't matter too much
+    // for this example, but should generally be < 1
+    // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
+    //   --target y --categories 2 --predictors  V2 V3 V4 V5 V6 V7 --types n
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+      .lambda(1 * 1.0e-3)
+      .learningRate(50);
+
+    train(getInput(), target, lr);
+    test(getInput(), target, lr, 0.05, 0.3);
+  }
+
+  /**
+   * Test for Serialization/DeSerialization
+   *
+   */
+  @Test
+  public void testSerializationAndDeSerialization() throws Exception {
+    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+      .lambda(1 * 1.0e-3)
+      .stepOffset(11)
+      .alpha(0.01)
+      .learningRate(50)
+      .decayExponent(-0.02);
+
+    lr.close();
+
+    byte[] output;
+
+    try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+         DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) {
+      PolymorphicWritable.write(dataOutputStream, lr);
+      output = byteArrayOutputStream.toByteArray();
+    }
+
+    OnlineLogisticRegression read;
+
+    try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
+         DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) {
+      read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
+    }
+
+    //lambda
+    Assert.assertEquals((1.0e-3), read.getLambda(), 1.0e-7);
+
+    // Reflection to get private variables
+    //stepOffset
+    Field stepOffset = lr.getClass().getDeclaredField("stepOffset");
+    stepOffset.setAccessible(true);
+    int stepOffsetVal = (Integer) stepOffset.get(lr);
+    Assert.assertEquals(11, stepOffsetVal);
+
+    //decayFactor (alpha)
+    Field decayFactor = lr.getClass().getDeclaredField("decayFactor");
+    decayFactor.setAccessible(true);
+    double decayFactorVal = (Double) decayFactor.get(lr);
+    Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);
+
+    //learning rate (mu0)
+    Field mu0 = lr.getClass().getDeclaredField("mu0");
+    mu0.setAccessible(true);
+    double mu0Val = (Double) mu0.get(lr);
+    Assert.assertEquals(50, mu0Val, 1.0e-7);
+
+    //forgettingExponent (decayExponent)
+    Field forgettingExponent = lr.getClass().getDeclaredField("forgettingExponent");
+    forgettingExponent.setAccessible(true);
+    double forgettingExponentVal = (Double) forgettingExponent.get(lr);
+    Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
new file mode 100644
index 0000000..df97d38
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public final class PassiveAggressiveTest extends OnlineBaseTest {
+
+  @Test
+  public void testPassiveAggressive() throws IOException {
+    Vector target = readStandardData();
+    PassiveAggressive pa = new PassiveAggressive(2,8).learningRate(0.1);
+    train(getInput(), target, pa);
+    test(getInput(), target, pa, 0.11, 0.31);
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
new file mode 100644
index 0000000..62e10c6
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
@@ -0,0 +1,152 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.IOException;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.stats.Sampler;
+
+public final class ClusteringTestUtils {
+
+  private ClusteringTestUtils() {
+  }
+
+  public static void writePointsToFile(Iterable<VectorWritable> points,
+                                       Path path,
+                                       FileSystem fs,
+                                       Configuration conf) throws IOException {
+    writePointsToFile(points, false, path, fs, conf);
+  }
+
+  public static void writePointsToFile(Iterable<VectorWritable> points,
+                                       boolean intWritable,
+                                       Path path,
+                                       FileSystem fs,
+                                       Configuration conf) throws IOException {
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs,
+                                                         conf,
+                                                         path,
+                                                         intWritable ? IntWritable.class : LongWritable.class,
+                                                         VectorWritable.class);
+    try {
+      int recNum = 0;
+      for (VectorWritable point : points) {
+        writer.append(intWritable ? new IntWritable(recNum++) : new LongWritable(recNum++), point);
+      }
+    } finally {
+      Closeables.close(writer, false);
+    }
+  }
+
+  public static Matrix sampledCorpus(Matrix matrix, Random random,
+      int numDocs, int numSamples, int numTopicsPerDoc) {
+    Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols());
+    LDASampler modelSampler = new LDASampler(matrix, random);
+    Vector topicVector = new DenseVector(matrix.numRows());
+    for (int i = 0; i < numTopicsPerDoc; i++) {
+      int topic = random.nextInt(topicVector.size());
+      topicVector.set(topic, topicVector.get(topic) + 1);
+    }
+    for (int docId = 0; docId < numDocs; docId++) {
+      for (int sample : modelSampler.sample(topicVector, numSamples)) {
+        corpus.set(docId, sample, corpus.get(docId, sample) + 1);
+      }
+    }
+    return corpus;
+  }
+
+  public static Matrix randomStructuredModel(int numTopics, int numTerms) {
+    return randomStructuredModel(numTopics, numTerms, new DoubleFunction() {
+      @Override public double apply(double d) {
+        return 1.0 / (1 + Math.abs(d));
+      }
+    });
+  }
+
+  public static Matrix randomStructuredModel(int numTopics, int numTerms, DoubleFunction decay) {
+    Matrix model = new DenseMatrix(numTopics, numTerms);
+    int width = numTerms / numTopics;
+    for (int topic = 0; topic < numTopics; topic++) {
+      int topicCentroid = width * (1+topic);
+      for (int i = 0; i < numTerms; i++) {
+        int distance = Math.abs(topicCentroid - i);
+        if (distance > numTerms / 2) {
+          distance = numTerms - distance;
+        }
+        double v = decay.apply(distance);
+        model.set(topic, i, v);
+      }
+    }
+    return model;
+  }
+
+  /**
+   * Takes in a {@link Matrix} of topic distributions (such as generated by {@link org.apache.mahout.clustering.lda.cvb.CVB0Driver} or
+   * {@link org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0}, and constructs
+   * a set of samplers over this distribution, which may be sampled from by providing a distribution
+   * over topics, and a number of samples desired
+   */
+  static class LDASampler {
+      private final Random random;
+      private final Sampler[] samplers;
+
+      LDASampler(Matrix model, Random random) {
+          this.random = random;
+          samplers = new Sampler[model.numRows()];
+          for (int i = 0; i < samplers.length; i++) {
+              samplers[i] = new Sampler(random, model.viewRow(i));
+          }
+      }
+
+      /**
+       *
+       * @param topicDistribution vector of p(topicId) for all topicId < model.numTopics()
+       * @param numSamples the number of times to sample (with replacement) from the model
+       * @return array of length numSamples, with each entry being a sample from the model.  There
+       * may be repeats
+       */
+      public int[] sample(Vector topicDistribution, int numSamples) {
+          Preconditions.checkNotNull(topicDistribution);
+          Preconditions.checkArgument(numSamples > 0, "numSamples must be positive");
+          Preconditions.checkArgument(topicDistribution.size() == samplers.length,
+                  "topicDistribution must have same cardinality as the sampling model");
+          int[] samples = new int[numSamples];
+          Sampler topicSampler = new Sampler(random, topicDistribution);
+          for (int i = 0; i < numSamples; i++) {
+              samples[i] = samplers[topicSampler.sample()].sample();
+          }
+          return samples;
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
new file mode 100644
index 0000000..1cbfb02
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public final class TestClusterInterface extends MahoutTestCase {
+
+  private static final DistanceMeasure measure = new ManhattanDistanceMeasure();
+
+  @Test
+  public void testClusterAsFormatString() {
+    double[] d = { 1.1, 2.2, 3.3 };
+    Vector m = new DenseVector(d);
+    Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+    String formatString = cluster.asFormatString(null);
+    assertTrue(formatString.contains("\"r\":[]"));
+    assertTrue(formatString.contains("\"c\":[1.1,2.2,3.3]"));
+    assertTrue(formatString.contains("\"n\":0"));
+    assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+  }
+
+  @Test
+  public void testClusterAsFormatStringSparse() {
+    double[] d = { 1.1, 0.0, 3.3 };
+    Vector m = new SequentialAccessSparseVector(3);
+    m.assign(d);
+    Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+    String formatString = cluster.asFormatString(null);
+    assertTrue(formatString.contains("\"r\":[]"));
+    assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+    assertTrue(formatString.contains("\"n\":0"));
+    assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+  }
+
+  @Test
+  public void testClusterAsFormatStringWithBindings() {
+    double[] d = { 1.1, 2.2, 3.3 };
+    Vector m = new DenseVector(d);
+    Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+    String[] bindings = { "fee", null, "foo" };
+    String formatString = cluster.asFormatString(bindings);
+    assertTrue(formatString.contains("\"r\":[]"));
+    assertTrue(formatString.contains("\"c\":[{\"fee\":1.1},{\"1\":2.2},{\"foo\":3.3}]"));
+    assertTrue(formatString.contains("\"n\":0"));
+    assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+  }
+
+  @Test
+  public void testClusterAsFormatStringSparseWithBindings() {
+    double[] d = { 1.1, 0.0, 3.3 };
+    Vector m = new SequentialAccessSparseVector(3);
+    m.assign(d);
+    Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+    String formatString = cluster.asFormatString(null);
+    assertTrue(formatString.contains("\"r\":[]"));
+    assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+    assertTrue(formatString.contains("\"n\":0"));
+    assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
new file mode 100644
index 0000000..43417fc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.util.Collection;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class TestGaussianAccumulators extends MahoutTestCase {
+
+  private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);
+
+  private Collection<VectorWritable> sampleData = Lists.newArrayList();
+  private int sampleN;
+  private Vector sampleMean;
+  private Vector sampleStd;
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    sampleData = Lists.newArrayList();
+    generateSamples();
+    sampleN = 0;
+    Vector sum = new DenseVector(2);
+    for (VectorWritable v : sampleData) {
+      sum.assign(v.get(), Functions.PLUS);
+      sampleN++;
+    }
+    sampleMean = sum.divide(sampleN);
+
+    Vector sampleVar = new DenseVector(2);
+    for (VectorWritable v : sampleData) {
+      Vector delta = v.get().minus(sampleMean);
+      sampleVar.assign(delta.times(delta), Functions.PLUS);
+    }
+    sampleVar = sampleVar.divide(sampleN - 1);
+    sampleStd = sampleVar.clone();
+    sampleStd.assign(new SquareRootFunction());
+    log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
+             sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1));
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData
+   *
+   * @param num
+   *          int number of samples to generate
+   * @param mx
+   *          double x-value of the sample mean
+   * @param my
+   *          double y-value of the sample mean
+   * @param sdx
+   *          double x-value standard deviation of the samples
+   * @param sdy
+   *          double y-value standard deviation of the samples
+   */
+  private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
+    log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
+    for (int i = 0; i < num; i++) {
+      sampleData.add(new VectorWritable(new DenseVector(new double[] { UncommonDistributions.rNorm(mx, sdx),
+          UncommonDistributions.rNorm(my, sdy) })));
+    }
+  }
+
+  private void generateSamples() {
+    generate2dSamples(50000, 1, 2, 3, 4);
+  }
+
+  @Test
+  public void testAccumulatorNoSamples() {
+    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+    accumulator0.compute();
+    accumulator1.compute();
+    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+    assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+    assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+  }
+
+  @Test
+  public void testAccumulatorOneSample() {
+    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+    Vector sample = new DenseVector(2);
+    accumulator0.observe(sample, 1.0);
+    accumulator1.observe(sample, 1.0);
+    accumulator0.compute();
+    accumulator1.compute();
+    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+    assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+    assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+  }
+
+  @Test
+  public void testOLAccumulatorResults() {
+    GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
+    for (VectorWritable vw : sampleData) {
+      accumulator.observe(vw.get(), 1.0);
+    }
+    accumulator.compute();
+    log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]",
+             accumulator.getN(),
+             accumulator.getMean().get(0),
+             accumulator.getMean().get(1),
+             accumulator.getStd().get(0),
+             accumulator.getStd().get(1));
+    assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+    assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+    assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), EPSILON);
+  }
+
+  @Test
+  public void testRSAccumulatorResults() {
+    GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator();
+    for (VectorWritable vw : sampleData) {
+      accumulator.observe(vw.get(), 1.0);
+    }
+    accumulator.compute();
+    log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]",
+             (int) accumulator.getN(),
+             accumulator.getMean().get(0),
+             accumulator.getMean().get(1),
+             accumulator.getStd().get(0),
+             accumulator.getStd().get(1));
+    assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+    assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+    assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), 0.0001);
+  }
+
+  @Test
+  public void testAccumulatorWeightedResults() {
+    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+    for (VectorWritable vw : sampleData) {
+      accumulator0.observe(vw.get(), 0.5);
+      accumulator1.observe(vw.get(), 0.5);
+    }
+    accumulator0.compute();
+    accumulator1.compute();
+    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+    assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+    assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+    assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+  }
+
+  @Test
+  public void testAccumulatorWeightedResults2() {
+    GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+    GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+    for (VectorWritable vw : sampleData) {
+      accumulator0.observe(vw.get(), 1.5);
+      accumulator1.observe(vw.get(), 1.5);
+    }
+    accumulator0.compute();
+    accumulator1.compute();
+    assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+    assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+    assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+    assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+  }
+}