Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:07:37 UTC
[06/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
new file mode 100644
index 0000000..3104cb1
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+public class HMMEvaluatorTest extends HMMTestBase {
+
+ /**
+ * Test to make sure the computed model likelihood is valid. Included tests
+ * are: a) forward likelihood == backward likelihood, b) the model likelihood
+ * for the test sequence matches the value expected from the R reference.
+ */
+ @Test
+ public void testModelLikelihood() {
+ // compute alpha and beta values
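+ // alpha holds the forward probabilities and beta the backward probabilities;
+ // both determine the same overall model likelihood P(observations | model)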
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false);
+ // now test whether forward == backward likelihood
+ double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
+ double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
+ beta, false);
+ assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
+ // also make sure that the likelihood matches the expected one
+ assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
+ }
+
+ /**
+ * Same checks as above, but using the scaled (numerically stable) variants of
+ * the forward and backward computations: a) forward likelihood == backward
+ * likelihood, b) the model likelihood matches the value expected from the R reference.
+ */
+ @Test
+ public void testScaledModelLikelihood() {
+ // compute alpha and beta values
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true);
+ // now test whether forward == backward likelihood
+ double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
+ double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
+ beta, true);
+ assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
+ // also make sure that the likelihood matches the expected one
+ assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
new file mode 100644
index 0000000..3260f51
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.junit.Test;
+
+public class HMMModelTest extends HMMTestBase {
+
+ @Test
+ public void testRandomModelGeneration() {
+ // make sure we generate a valid random model
+ HmmModel model = new HmmModel(10, 20);
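+ // assuming the two-argument constructor takes the number of hidden and
+ // output states, this builds a randomly initialized 10x20 model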
+ // check whether the model is valid
+ HmmUtils.validate(model);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
new file mode 100644
index 0000000..90f1cd8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+
+public class HMMTestBase extends MahoutTestCase {
+
+ private HmmModel model;
+ private final int[] sequence = {1, 0, 2, 2, 0, 0, 1};
+
+ /**
+ * We initialize a new HMM model with the following parameters:
+ *
+ * hidden states: 4 ("H0", "H1", "H2", "H3")
+ * output states: 3 ("O0", "O1", "O2")
+ *
+ * transition probabilities (rows = from, columns = to H0, H1, H2, H3):
+ *   H0: 0.5 0.1 0.1 0.3
+ *   H1: 0.4 0.4 0.1 0.1
+ *   H2: 0.1 0.0 0.8 0.1
+ *   H3: 0.1 0.1 0.1 0.7
+ *
+ * emission probabilities (rows = hidden state, columns = O0, O1, O2):
+ *   H0: 0.8 0.1 0.1
+ *   H1: 0.6 0.1 0.3
+ *   H2: 0.1 0.8 0.1
+ *   H3: 0.0 0.1 0.9
+ *
+ * initial probabilities: H0 0.2, H1 0.1, H2 0.4, H3 0.3
+ *
+ * We also initialize the observation sequence "O1" "O0" "O2" "O2" "O0" "O0" "O1".
+ */
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ // initialize the hidden/output state names
+ String[] hiddenNames = {"H0", "H1", "H2", "H3"};
+ String[] outputNames = {"O0", "O1", "O2"};
+ // initialize the transition matrix
+ double[][] transitionP = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
+ {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
+ // initialize the emission matrix
+ double[][] emissionP = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3},
+ {0.1, 0.8, 0.1}, {0.0, 0.1, 0.9}};
+ // initialize the initial probability vector
+ double[] initialP = {0.2, 0.1, 0.4, 0.3};
+ // now generate the model
+ model = new HmmModel(new DenseMatrix(transitionP), new DenseMatrix(
+ emissionP), new DenseVector(initialP));
+ model.registerHiddenStateNames(hiddenNames);
+ model.registerOutputStateNames(outputNames);
+ // make sure the model is valid :)
+ HmmUtils.validate(model);
+ }
+
+ protected HmmModel getModel() {
+ return model;
+ }
+
+ protected int[] getSequence() {
+ return sequence;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
new file mode 100644
index 0000000..b8f3186
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
@@ -0,0 +1,163 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMTrainerTest extends HMMTestBase {
+
+ @Test
+ public void testViterbiTraining() {
+ // initialize the expected model parameters (from R)
+ // expected transition matrix
+ double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
+ {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+ {0.5, 0.1, 0.1, 0.3}};
+ // initialize the emission matrix
+ double[][] emissionE = {{0.882353, 0.058824, 0.058824},
+ {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+ {0.111111, 0.111111, 0.777778}};
+
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, false);
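+ // the numeric arguments are presumably a pseudo-count of 0.5, a convergence
+ // threshold of 0.1 and a cap of 10 iterations; the final flag selects the
+ // unscaled computation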
+
+ // now check whether the model matches our expectations
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], EPSILON);
+ }
+
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], EPSILON);
+ }
+ }
+
+ }
+
+ @Test
+ public void testScaledViterbiTraining() {
+ // initialize the expected model parameters (from R)
+ // expected transition matrix
+ double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
+ {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+ {0.5, 0.1, 0.1, 0.3}};
+ // initialize the emission matrix
+ double[][] emissionE = {{0.882353, 0.058824, 0.058824},
+ {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+ {0.111111, 0.111111, 0.777778}};
+
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10,
+ true);
+
+ // now check whether the model matches our expectations
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
+ EPSILON);
+ }
+
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
+ EPSILON);
+ }
+ }
+
+ }
+
+ @Test
+ public void testBaumWelchTraining() {
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ // expected values from Matlab HMM package / R HMM package
+ double[] initialExpected = {0, 0, 1.0, 0};
+ double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+ {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+ {0.0024, 0.6657, 0, 0.3319}};
+ double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+ {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+ HmmModel trained = HmmTrainer.trainBaumWelch(getModel(), observed, 0.1, 10,
+ false);
+
+ Vector initialProbabilities = trained.getInitialProbabilities();
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ assertEquals(initialProbabilities.get(i), initialExpected[i],
+ 0.0001);
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j),
+ transitionExpected[i][j], 0.0001);
+ }
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j),
+ emissionExpected[i][j], 0.0001);
+ }
+ }
+ }
+
+ @Test
+ public void testScaledBaumWelchTraining() {
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ // expected values from Matlab HMM package / R HMM package
+ double[] initialExpected = {0, 0, 1.0, 0};
+ double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+ {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+ {0.0024, 0.6657, 0, 0.3319}};
+ double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+ {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+ HmmModel trained = HmmTrainer
+ .trainBaumWelch(getModel(), observed, 0.1, 10, true);
+
+ Vector initialProbabilities = trained.getInitialProbabilities();
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ assertEquals(initialProbabilities.get(i), initialExpected[i],
+ 0.0001);
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j),
+ transitionExpected[i][j], 0.0001);
+ }
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j),
+ emissionExpected[i][j], 0.0001);
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
new file mode 100644
index 0000000..6c34718
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMUtilsTest extends HMMTestBase {
+
+ private Matrix legal22;
+ private Matrix legal23;
+ private Matrix legal33;
+ private Vector legal2;
+ private Matrix illegal22;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ legal22 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
+ legal23 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
+ {0.3, 0.3, 0.4}});
+ legal33 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
+ {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
+ legal2 = new DenseVector(new double[]{0.4, 0.6});
+ illegal22 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
+ }
+
+ @Test
+ public void testValidatorLegal() {
+ HmmUtils.validate(new HmmModel(legal22, legal23, legal2));
+ }
+
+ @Test
+ public void testValidatorDimensionError() {
+ try {
+ HmmUtils.validate(new HmmModel(legal33, legal23, legal2));
+ } catch (IllegalArgumentException e) {
+ // success
+ return;
+ }
+ fail();
+ }
+
+ @Test
+ public void testValidatorIllegalMatrixError() {
+ try {
+ HmmUtils.validate(new HmmModel(illegal22, legal23, legal2));
+ } catch (IllegalArgumentException e) {
+ // success
+ return;
+ }
+ fail();
+ }
+
+ @Test
+ public void testEncodeStateSequence() {
+ String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
+ String[] outputSequence = {"O1", "O2", "O4", "O0"};
+ // test encoding the hidden Sequence
+ int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
+ .asList(hiddenSequence), false, -1);
+ int[] outputSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
+ .asList(outputSequence), true, -1);
+ // expected state sequences
+ int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
+ int[] outputSequenceExp = {1, 2, -1, 0};
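+ // "H4" and "O4" are not registered in the model, so they encode to the
+ // unknown-state value -1 passed to encodeStateSequence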
+ // compare
+ for (int i = 0; i < hiddenSequenceEnc.length; ++i) {
+ assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
+ }
+ for (int i = 0; i < outputSequenceEnc.length; ++i) {
+ assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+ }
+ }
+
+ @Test
+ public void testDecodeStateSequence() {
+ int[] hiddenSequence = {1, 2, 0, 3, 10};
+ int[] outputSequence = {1, 2, 10, 0};
+ // test encoding the hidden Sequence
+ List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
+ getModel(), hiddenSequence, false, "unknown");
+ List<String> outputSequenceDec = HmmUtils.decodeStateSequence(
+ getModel(), outputSequence, true, "unknown");
+ // expected state sequences
+ String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
+ String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
+ // compare
+ for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+ assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
+ }
+ for (int i = 0; i < outputSequenceExp.length; ++i) {
+ assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+ }
+ }
+
+ @Test
+ public void testNormalizeModel() {
+ DenseVector ip = new DenseVector(new double[]{10, 20});
+ DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
+ DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
+ HmmModel model = new HmmModel(tr, em, ip);
+ HmmUtils.normalizeModel(model);
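+ // normalizeModel should rescale the initial vector and each matrix row to sum to 1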
+ // the model should be valid now
+ HmmUtils.validate(model);
+ }
+
+ @Test
+ public void testTruncateModel() {
+ DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
+ DenseMatrix tr = new DenseMatrix(new double[][]{
+ {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+ {0.0001, 0.0001, 0.9998}});
+ DenseMatrix em = new DenseMatrix(new double[][]{
+ {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+ {0.0001, 0.0001, 0.9998}});
+ HmmModel model = new HmmModel(tr, em, ip);
+ // now truncate the model
+ HmmModel sparseModel = HmmUtils.truncateModel(model, 0.01);
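+ // entries below the 0.01 threshold should be dropped and the remaining mass
+ // renormalized, leaving purely diagonal transition and emission matrices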
+ // first make sure this is a valid model
+ HmmUtils.validate(sparseModel);
+ // now check whether the values are as expected
+ Vector sparse_ip = sparseModel.getInitialProbabilities();
+ Matrix sparse_tr = sparseModel.getTransitionMatrix();
+ Matrix sparse_em = sparseModel.getEmissionMatrix();
+ for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
+ assertEquals(i == 2 ? 1.0 : 0.0, sparse_ip.getQuick(i), EPSILON);
+ for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
+ if (i == j) {
+ assertEquals(1.0, sparse_tr.getQuick(i, j), EPSILON);
+ assertEquals(1.0, sparse_em.getQuick(i, j), EPSILON);
+ } else {
+ assertEquals(0.0, sparse_tr.getQuick(i, j), EPSILON);
+ assertEquals(0.0, sparse_em.getQuick(i, j), EPSILON);
+ }
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
new file mode 100644
index 0000000..7ea8cb2
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.jet.random.Exponential;
+import org.junit.Test;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+
+import java.util.Random;
+
+public final class AdaptiveLogisticRegressionTest extends MahoutTestCase {
+
+ @ThreadLeakLingering(linger=1000)
+ @Test
+ public void testTrain() {
+
+ Random gen = RandomUtils.getRandom();
+ Exponential exp = new Exponential(0.5, gen);
+ Vector beta = new DenseVector(200);
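+ // synthetic ground-truth coefficients: random sign, exponentially distributed magnitude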
+ for (Vector.Element element : beta.all()) {
+ int sign = 1;
+ if (gen.nextDouble() < 0.5) {
+ sign = -1;
+ }
+ element.set(sign * exp.nextDouble());
+ }
+
+ AdaptiveLogisticRegression.Wrapper cl = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+ cl.update(new double[]{1.0e-5, 1});
+
+ for (int i = 0; i < 10000; i++) {
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ cl.train(r);
+ if (i % 1000 == 0) {
+ System.out.printf("%10d %10.3f\n", i, cl.getLearner().auc());
+ }
+ }
+ assertEquals(1, cl.getLearner().auc(), 0.1);
+
+ AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(2, 200, new L1());
+ adaptiveLogisticRegression.setInterval(1000);
+
+ for (int i = 0; i < 20000; i++) {
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ adaptiveLogisticRegression.train(r.getKey(), r.getActual(), r.getInstance());
+ if (i % 1000 == 0 && adaptiveLogisticRegression.getBest() != null) {
+ System.out.printf("%10d %10.4f %10.8f %.3f\n",
+ i, adaptiveLogisticRegression.auc(),
+ Math.log10(adaptiveLogisticRegression.getBest().getMappedParams()[0]), adaptiveLogisticRegression.getBest().getMappedParams()[1]);
+ }
+ }
+ assertEquals(1, adaptiveLogisticRegression.auc(), 0.1);
+ adaptiveLogisticRegression.close();
+ }
+
+ private static AdaptiveLogisticRegression.TrainingExample getExample(int i, Random gen, Vector beta) {
+ Vector data = new DenseVector(200);
+
+ for (Vector.Element element : data.all()) {
+ element.set(gen.nextDouble() < 0.3 ? 1 : 0);
+ }
+
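+ // draw the target from a logistic model with offset 1.5 on the score data.dot(beta)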
+ double p = 1 / (1 + Math.exp(1.5 - data.dot(beta)));
+ int target = 0;
+ if (gen.nextDouble() < p) {
+ target = 1;
+ }
+ return new AdaptiveLogisticRegression.TrainingExample(i, null, target, data);
+ }
+
+ @Test
+ public void copyLearnsAsExpected() {
+ Random gen = RandomUtils.getRandom();
+ Exponential exp = new Exponential(0.5, gen);
+ Vector beta = new DenseVector(200);
+ for (Vector.Element element : beta.all()) {
+ int sign = 1;
+ if (gen.nextDouble() < 0.5) {
+ sign = -1;
+ }
+ element.set(sign * exp.nextDouble());
+ }
+
+ // train one copy of a wrapped learner
+ AdaptiveLogisticRegression.Wrapper w = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+ for (int i = 0; i < 3000; i++) {
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ w.train(r);
+ if (i % 1000 == 0) {
+ System.out.printf("%10d %.3f\n", i, w.getLearner().auc());
+ }
+ }
+ System.out.printf("%10d %.3f\n", 3000, w.getLearner().auc());
+ double auc1 = w.getLearner().auc();
+
+ // then switch to a copy of that learner ... progress should continue
+ AdaptiveLogisticRegression.Wrapper w2 = w.copy();
+
+ for (int i = 0; i < 5000; i++) {
+ if (i % 1000 == 0) {
+ if (i == 0) {
+ assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001);
+ }
+ if (i == 1000) {
+ double auc2 = w2.getLearner().auc();
+ assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1);
+ assertTrue("AUC should improve quickly on copy", auc1 < auc2);
+ }
+ System.out.printf("%10d %.3f\n", i, w2.getLearner().auc());
+ }
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ w2.train(r);
+ }
+ assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5);
+
+ // this improvement is really quite lenient
+ assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05);
+
+ // make sure that the copy didn't lose anything
+ assertEquals(auc1, w.getLearner().auc(), 0);
+ }
+
+ @Test
+ public void stepSize() {
+ assertEquals(500, AdaptiveLogisticRegression.stepSize(15000, 2));
+ assertEquals(2000, AdaptiveLogisticRegression.stepSize(15000, 2.6));
+ assertEquals(5000, AdaptiveLogisticRegression.stepSize(24000, 2.6));
+ assertEquals(10000, AdaptiveLogisticRegression.stepSize(15000, 3));
+ }
+
+ @Test
+ @ThreadLeakLingering(linger = 1000)
+ public void constantStep() {
+ AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+ lr.setInterval(5000);
+ assertEquals(20000, lr.nextStep(15000));
+ assertEquals(20000, lr.nextStep(15001));
+ assertEquals(20000, lr.nextStep(16500));
+ assertEquals(20000, lr.nextStep(19999));
+ lr.close();
+ }
+
+
+ @Test
+ @ThreadLeakLingering(linger = 1000)
+ public void growingStep() {
+ AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+ lr.setInterval(2000, 10000);
+
+ // start with minimum step size
+ for (int i = 2000; i < 20000; i+=2000) {
+ assertEquals(i + 2000, lr.nextStep(i));
+ }
+
+ // then level up a bit
+ for (int i = 20000; i < 50000; i += 5000) {
+ assertEquals(i + 5000, lr.nextStep(i));
+ }
+
+ // and more, but we top out with this step size
+ for (int i = 50000; i < 500000; i += 10000) {
+ assertEquals(i + 10000, lr.nextStep(i));
+ }
+ lr.close();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
new file mode 100644
index 0000000..6ee0ddf
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Test;
+
+public final class CsvRecordFactoryTest extends MahoutTestCase {
+
+ @Test
+ public void testAddToVector() {
+ RecordFactory csv = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t"));
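+ // predictor types for the CSV factory: "n" = numeric, "w" = word, "t" = text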
+ csv.firstLine("z,x1,y,x2,x3,q");
+ csv.maxTargetValue(2);
+
+ Vector v = new DenseVector(2000);
+ int t = csv.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", v);
+ assertEquals(0, t);
+ // should have 9 values set
+ assertEquals(9.0, v.norm(0), 0);
+ // all should be = 1 except for the 3.1
+ assertEquals(3.1, v.maxValue(), 0);
+ v.set(v.maxValueIndex(), 0);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(8.0, v.norm(1), 0);
+ assertEquals(1.0, v.maxValue(), 0);
+
+ v.assign(0);
+ t = csv.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", v);
+ assertEquals(1, t);
+
+ // should have 9 values set
+ assertEquals(9.0, v.norm(0), 0);
+ // the largest value should be the numeric feature 5.3
+ assertEquals(5.3, v.maxValue(), 0);
+ v.set(v.maxValueIndex(), 0);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+ assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+
+ v.assign(0);
+ t = csv.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", v);
+ assertEquals(1, t);
+
+ // should have 9 values set
+ assertEquals(9.0, v.norm(0), 0);
+ // the largest value should be the numeric feature 5.3
+ assertEquals(5.3, v.maxValue(), 0);
+ v.set(v.maxValueIndex(), 0);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+ assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+ }
+
+ @Test
+ public void testDictionaryOrder() {
+ Dictionary dict = new Dictionary();
+
+ dict.intern("a");
+ dict.intern("d");
+ dict.intern("c");
+ dict.intern("b");
+ dict.intern("qrz");
+
+ assertEquals("[a, d, c, b, qrz]", dict.values().toString());
+
+ Dictionary dict2 = Dictionary.fromList(dict.values());
+ assertEquals("[a, d, c, b, qrz]", dict2.values().toString());
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
new file mode 100644
index 0000000..06a876e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Random;
+
+public final class GradientMachineTest extends OnlineBaseTest {
+
+ @Test
+ public void testGradientmachine() throws IOException {
+ Vector target = readStandardData();
+ GradientMachine grad = new GradientMachine(8,4,2).learningRate(0.1).regularization(0.01);
+ Random gen = RandomUtils.getRandom();
+ grad.initWeights(gen);
+ train(getInput(), target, grad);
+ // TODO not sure why the RNG change made this fail. Value is 0.5-1.0 no matter what seed is chosen?
+ test(getInput(), target, grad, 1.0, 1);
+ //test(getInput(), target, grad, 0.05, 1);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
new file mode 100644
index 0000000..2373b9d
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.junit.Test;
+
+public final class ModelSerializerTest extends MahoutTestCase {
+
+ private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException {
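+ // serialize the model via PolymorphicWritable and read it back from the
+ // resulting bytes, so callers can check that state survives a round trip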
+ ByteArrayOutputStream buf = new ByteArrayOutputStream(1000);
+ DataOutputStream dos = new DataOutputStream(buf);
+ try {
+ PolymorphicWritable.write(dos, m);
+ } finally {
+ Closeables.close(dos, false);
+ }
+ return PolymorphicWritable.read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), clazz);
+ }
+
+ @Test
+ public void onlineAucRoundtrip() throws IOException {
+ RandomUtils.useTestSeed();
+ OnlineAuc auc1 = new GlobalOnlineAuc();
+ Random gen = RandomUtils.getRandom();
+ for (int i = 0; i < 10000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+ }
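+ // two unit-variance Gaussians with means 0 and 1 give a theoretical AUC of
+ // Phi(1 / sqrt(2)), roughly 0.76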
+ assertEquals(0.76, auc1.auc(), 0.01);
+
+ OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class);
+
+ assertEquals(auc1.auc(), auc3.auc(), 0);
+
+ for (int i = 0; i < 1000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+
+ auc3.addSample(0, gen.nextGaussian());
+ auc3.addSample(1, gen.nextGaussian() + 1);
+ }
+
+ assertEquals(auc1.auc(), auc3.auc(), 0.01);
+ }
+
+ @Test
+ public void onlineLogisticRegressionRoundTrip() throws IOException {
+ OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
+ train(olr, 100);
+ OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
+ assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+
+ train(olr, 100);
+ train(olr3, 100);
+
+ assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+ olr.close();
+ olr3.close();
+ }
+
+ @Test
+ public void crossFoldLearnerRoundTrip() throws IOException {
+ CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
+ train(learner, 100);
+ CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, learner.auc(), 1.0e-6);
+ assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+ train(learner, 100);
+ train(learner, 100);
+ train(olr3, 100);
+
+ assertEquals(learner.auc(), learner.auc(), 0.02);
+ assertEquals(learner.auc(), olr3.auc(), 0.02);
+ double auc2 = learner.auc();
+ assertTrue(auc2 > auc1);
+ learner.close();
+ olr3.close();
+ }
+
+ @ThreadLeakLingering(linger = 1000)
+ @Test
+ public void adaptiveLogisticRegressionRoundTrip() throws IOException {
+ AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
+ learner.setInterval(200);
+ train(learner, 400);
+ AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, learner.auc(), 1.0e-6);
+ assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+ train(learner, 1000);
+ train(learner, 1000);
+ train(olr3, 1000);
+
+ assertEquals(learner.auc(), learner.auc(), 0.005);
+ assertEquals(learner.auc(), olr3.auc(), 0.005);
+ double auc2 = learner.auc();
+ assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1);
+ learner.close();
+ olr3.close();
+ }
+
+ private static void train(OnlineLearner olr, int n) {
+ Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5});
+ Random gen = RandomUtils.getRandom();
+ for (int i = 0; i < n; i++) {
+ Vector x = randomVector(gen, 5);
+
+ int target = gen.nextDouble() < beta.dot(x) ? 1 : 0;
+ olr.train(target, x);
+ }
+ }
+
+ private static Vector randomVector(final Random gen, int n) {
+ Vector x = new DenseVector(n);
+ x.assign(new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return gen.nextGaussian();
+ }
+ });
+ return x;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
new file mode 100644
index 0000000..e0a252c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+public abstract class OnlineBaseTest extends MahoutTestCase {
+
+ private Matrix input;
+
+ Matrix getInput() {
+ return input;
+ }
+
+ Vector readStandardData() throws IOException {
+ // 60 test samples. First column is constant. Second and third are normally distributed from
+ // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a
+ // target variable of 0, the last 30 a target of 1. The remaining columns are random noise.
+ input = readCsv("sgd.csv");
+
+ // regenerate the target variable
+ Vector target = new DenseVector(60);
+ target.assign(0);
+ target.viewPart(30, 30).assign(1);
+ return target;
+ }
+
+ static void train(Matrix input, Vector target, OnlineLearner lr) {
+ RandomUtils.useTestSeed();
+ Random gen = RandomUtils.getRandom();
+
+ // train on samples in random order (but only one pass)
+ for (int row : permute(gen, 60)) {
+ lr.train((int) target.get(row), input.viewRow(row));
+ }
+ lr.close();
+ }
+
+ static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
+ double expected_mean_error, double expected_absolute_error) {
+ // now test the accuracy
+ Matrix tmp = lr.classify(input);
+ // mean(abs(tmp - target))
+ double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;
+
+ // max(abs(tmp - target))
+ double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
+
+ System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+ assertEquals(0, meanAbsoluteError, expected_mean_error);
+ assertEquals(0, maxAbsoluteError, expected_absolute_error);
+
+ // convenience methods should give the same results
+ Vector v = lr.classifyScalar(input);
+ assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5);
+ v = lr.classifyFull(input).viewColumn(1);
+ assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4);
+ }
+
+ /**
+ * Permute the integers from 0 ... max-1
+ *
+ * @param gen The random number generator to use.
+ * @param max The number of integers to permute
+ * @return An array of jumbled integer values
+ */
+ static int[] permute(Random gen, int max) {
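+ // inside-out Fisher-Yates shuffle: each new index i either stays in place or
+ // displaces a uniformly chosen earlier entry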
+ int[] permutation = new int[max];
+ permutation[0] = 0;
+ for (int i = 1; i < max; i++) {
+ int n = gen.nextInt(i + 1);
+ if (n == i) {
+ permutation[i] = i;
+ } else {
+ permutation[i] = permutation[n];
+ permutation[n] = i;
+ }
+ }
+ return permutation;
+ }
+
+
+ /**
+ * Reads a file containing CSV data. This isn't implemented quite the way you might like for a
+ * real program, but does the job for reading test data. Most notably, it will only read numbers,
+ * not quoted strings.
+ *
+ * @param resourceName Where to get the data.
+ * @return A matrix of the results.
+ * @throws IOException If there is an error reading the data
+ */
+ static Matrix readCsv(String resourceName) throws IOException {
+ Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));
+
+ Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
+ List<String> data = CharStreams.readLines(isr);
+ String first = data.get(0);
+ data = data.subList(1, data.size());
+
+ List<String> values = Lists.newArrayList(onCommas.split(first));
+ Matrix r = new DenseMatrix(data.size(), values.size());
+
+ int column = 0;
+ Map<String, Integer> labels = Maps.newHashMap();
+ for (String value : values) {
+ labels.put(value, column);
+ column++;
+ }
+ r.setColumnLabelBindings(labels);
+
+ int row = 0;
+ for (String line : data) {
+ column = 0;
+ values = Lists.newArrayList(onCommas.split(line));
+ for (String value : values) {
+ r.set(row, column, Double.parseDouble(value));
+ column++;
+ }
+ row++;
+ }
+
+ return r;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
new file mode 100644
index 0000000..44b7525
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Assert;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+
+public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
+
+ private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
+
+ /**
+ * The CrossFoldLearner is probably the best learner to use for new applications.
+ *
+ * @throws IOException If test resources aren't readable.
+ */
+ @Test
+ public void crossValidation() throws IOException {
+ Vector target = readStandardData();
+
+ CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
+ .lambda(1 * 1.0e-3)
+ .learningRate(50);
+
+
+ train(getInput(), target, lr);
+
+ System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
+ test(getInput(), target, lr, 0.05, 0.3);
+
+ }
+
+ @Test
+ public void crossValidatedAuc() throws IOException {
+ RandomUtils.useTestSeed();
+ Random gen = RandomUtils.getRandom();
+
+ Matrix data = readCsv("cancer.csv");
+ CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
+ .stepOffset(10)
+ .decayExponent(0.7)
+ .lambda(1 * 1.0e-3)
+ .learningRate(5);
+ int k = 0;
+ int[] ordering = permute(gen, data.numRows());
+ for (int epoch = 0; epoch < 100; epoch++) {
+ for (int row : ordering) {
+ lr.train(row, (int) data.get(row, 9), data.viewRow(row));
+ System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
+ }
+ assertEquals(1, lr.auc(), 0.2);
+ }
+ assertEquals(1, lr.auc(), 0.1);
+ }
+
+ /**
+ * Verifies that a classifier with known coefficients does the right thing.
+ */
+ @Test
+ public void testClassify() {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
+ // set up some internal coefficients as if we had learned them
+ lr.setBeta(0, 0, -1);
+ lr.setBeta(1, 0, -2);
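+ // with class 0 as reference, the classifier computes
+ // p_k = exp(x.dot(beta_k)) / (1 + sum_j exp(x.dot(beta_j))) for k = 1, 2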
+
+ // zero vector gives no information. All classes are equal.
+ Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+
+ v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-8);
+
+ // weights for second vector component are still zero so all classifications are equally likely
+ v = lr.classify(new DenseVector(new double[]{0, 1}));
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+
+ v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-3);
+
+ // but the weights on the first component are non-zero
+ v = lr.classify(new DenseVector(new double[]{1, 0}));
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);
+
+ lr.setBeta(0, 1, 1);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);
+
+ lr.setBeta(1, 1, 3);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
+ }
+
+ @Test
+ public void iris() throws IOException {
+ // this test trains a 3-way classifier on the famous Iris dataset.
+ // a similar exercise can be accomplished in R using this code:
+ // library(nnet)
+ // correct = rep(0,100)
+ // for (j in 1:100) {
+ // i = order(runif(150))
+ // train = iris[i[1:100],]
+ // test = iris[i[101:150],]
+ // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
+ // correct[j] = mean(predict(m, newdata=test) == test$Species)
+ // }
+ // hist(correct)
+ //
+ // Note that depending on the training/test split, performance can be better or worse.
+ // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
+ // of 100%
+ //
+ // This test uses a deterministic split that is neither outstandingly good nor bad
+
+
+ RandomUtils.useTestSeed();
+ Splitter onComma = Splitter.on(",");
+
+ // read the data
+ List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);
+
+ // holds features
+ List<Vector> data = Lists.newArrayList();
+
+ // holds target variable
+ List<Integer> target = Lists.newArrayList();
+
+ // for decoding target values
+ Dictionary dict = new Dictionary();
+
+ // for permuting data later
+ List<Integer> order = Lists.newArrayList();
+
+ for (String line : raw.subList(1, raw.size())) {
+ // order gets a list of indexes
+ order.add(order.size());
+
+ // parse the predictor variables
+ Vector v = new DenseVector(5);
+ v.set(0, 1);
+ int i = 1;
+ Iterable<String> values = onComma.split(line);
+ for (String value : Iterables.limit(values, 4)) {
+ v.set(i++, Double.parseDouble(value));
+ }
+ data.add(v);
+
+ // and the target
+ target.add(dict.intern(Iterables.get(values, 4)));
+ }
+
+ // randomize the order ... original data has each species all together
+ // note that this randomization is deterministic
+ Random random = RandomUtils.getRandom();
+ Collections.shuffle(order, random);
+
+ // select training and test data
+ List<Integer> train = order.subList(0, 100);
+ List<Integer> test = order.subList(100, 150);
+ logger.warn("Training set = {}", train);
+ logger.warn("Test set = {}", test);
+
+ // now train many times and collect information on accuracy each time
+ int[] correct = new int[test.size() + 1];
+ for (int run = 0; run < 200; run++) {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
+ // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
+ for (int pass = 0; pass < 30; pass++) {
+ Collections.shuffle(train, random);
+ for (int k : train) {
+ lr.train(target.get(k), data.get(k));
+ }
+ }
+
+ // check the accuracy on held out data
+ int x = 0;
+ int[] count = new int[3];
+ for (Integer k : test) {
+ int r = lr.classifyFull(data.get(k)).maxValueIndex();
+ count[r]++;
+ x += r == target.get(k) ? 1 : 0;
+ }
+ correct[x]++;
+ }
+
+ // verify we never saw worse than 95% correct,
+ for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
+ assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]);
+ }
+ // nor perfect
+ assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size()]);
+ }
+
+ @Test
+ public void testTrain() throws Exception {
+ Vector target = readStandardData();
+
+
+ // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
+ // larger than usual because the data are dense. The learning rate doesn't matter too much
+ // for this example, but should generally be < 1
+ // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
+ // --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+ .lambda(1 * 1.0e-3)
+ .learningRate(50);
+
+ train(getInput(), target, lr);
+ test(getInput(), target, lr, 0.05, 0.3);
+ }
+
+ /**
+ * Test for Serialization/DeSerialization
+ *
+ */
+ @Test
+ public void testSerializationAndDeSerialization() throws Exception {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+ .lambda(1 * 1.0e-3)
+ .stepOffset(11)
+ .alpha(0.01)
+ .learningRate(50)
+ .decayExponent(-0.02);
+
+ lr.close();
+
+ byte[] output;
+
+ try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+ DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) {
+ PolymorphicWritable.write(dataOutputStream, lr);
+ output = byteArrayOutputStream.toByteArray();
+ }
+
+ OnlineLogisticRegression read;
+
+ try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
+ DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) {
+ read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
+ }
+
+ //lambda
+ Assert.assertEquals((1.0e-3), read.getLambda(), 1.0e-7);
+
+ // Reflection to get private variables
+ //stepOffset
+ Field stepOffset = lr.getClass().getDeclaredField("stepOffset");
+ stepOffset.setAccessible(true);
+ int stepOffsetVal = (Integer) stepOffset.get(lr);
+ Assert.assertEquals(11, stepOffsetVal);
+
+ //decayFactor (alpha)
+ Field decayFactor = lr.getClass().getDeclaredField("decayFactor");
+ decayFactor.setAccessible(true);
+ double decayFactorVal = (Double) decayFactor.get(lr);
+ Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);
+
+ //learning rate (mu0)
+ Field mu0 = lr.getClass().getDeclaredField("mu0");
+ mu0.setAccessible(true);
+ double mu0Val = (Double) mu0.get(lr);
+ Assert.assertEquals(50, mu0Val, 1.0e-7);
+
+ //forgettingExponent (decayExponent)
+ Field forgettingExponent = lr.getClass().getDeclaredField("forgettingExponent");
+ forgettingExponent.setAccessible(true);
+ double forgettingExponentVal = (Double) forgettingExponent.get(lr);
+ Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
+ }
+}
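
The serialization test above round-trips a configured learner through PolymorphicWritable and then checks the restored hyperparameters via reflection. Below is a minimal, self-contained sketch of that same round trip outside JUnit, using only calls that appear in the test; the class name LrRoundTripSketch, the main method, and the printed output are invented for illustration, and the hyperparameter values are not canonical.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;

public class LrRoundTripSketch {
  public static void main(String[] args) throws IOException {
    // configure a small 2-category, 8-feature learner as in the test above;
    // the concrete values are illustrative only
    OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
        .lambda(1.0e-3)
        .learningRate(50);
    lr.close();

    // serialize the model into a byte array ...
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    try (DataOutputStream out = new DataOutputStream(bytes)) {
      PolymorphicWritable.write(out, lr);
    }

    // ... and read it back; getLambda() should return the value set above
    try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
      OnlineLogisticRegression copy = PolymorphicWritable.read(in, OnlineLogisticRegression.class);
      System.out.println("lambda after round trip: " + copy.getLambda());
    }
  }
}
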
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
new file mode 100644
index 0000000..df97d38
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public final class PassiveAggressiveTest extends OnlineBaseTest {
+
+ @Test
+ public void testPassiveAggressive() throws IOException {
+ Vector target = readStandardData();
+ PassiveAggressive pa = new PassiveAggressive(2,8).learningRate(0.1);
+ train(getInput(), target, pa);
+ test(getInput(), target, pa, 0.11, 0.31);
+ }
+
+}
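
PassiveAggressive follows the same online train/classify pattern as the SGD learners: feed (label, feature vector) pairs to train() and read per-category scores from classifyFull(). Here is a minimal sketch on a toy two-feature problem, assuming the train(int, Vector) and classifyFull(Vector) methods shared by Mahout's online classifiers; the class name and the toy data are invented for illustration.

import org.apache.mahout.classifier.sgd.PassiveAggressive;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class PassiveAggressiveSketch {
  public static void main(String[] args) {
    // two categories, two features; learningRate as in the test above
    PassiveAggressive pa = new PassiveAggressive(2, 2).learningRate(0.1);

    Vector a = new DenseVector(new double[] {1.0, 0.0});  // class 0
    Vector b = new DenseVector(new double[] {0.0, 1.0});  // class 1

    // a few passes over the toy data
    for (int pass = 0; pass < 10; pass++) {
      pa.train(0, a);
      pa.train(1, b);
    }

    // classifyFull returns one score per category
    System.out.println("scores for a: " + pa.classifyFull(a));
    System.out.println("scores for b: " + pa.classifyFull(b));
  }
}
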
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
new file mode 100644
index 0000000..62e10c6
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
@@ -0,0 +1,152 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.IOException;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.stats.Sampler;
+
+public final class ClusteringTestUtils {
+
+ private ClusteringTestUtils() {
+ }
+
+ public static void writePointsToFile(Iterable<VectorWritable> points,
+ Path path,
+ FileSystem fs,
+ Configuration conf) throws IOException {
+ writePointsToFile(points, false, path, fs, conf);
+ }
+
+ public static void writePointsToFile(Iterable<VectorWritable> points,
+ boolean intWritable,
+ Path path,
+ FileSystem fs,
+ Configuration conf) throws IOException {
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs,
+ conf,
+ path,
+ intWritable ? IntWritable.class : LongWritable.class,
+ VectorWritable.class);
+ try {
+ int recNum = 0;
+ for (VectorWritable point : points) {
+ writer.append(intWritable ? new IntWritable(recNum++) : new LongWritable(recNum++), point);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ public static Matrix sampledCorpus(Matrix matrix, Random random,
+ int numDocs, int numSamples, int numTopicsPerDoc) {
+ Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols());
+ LDASampler modelSampler = new LDASampler(matrix, random);
+ Vector topicVector = new DenseVector(matrix.numRows());
+ for (int i = 0; i < numTopicsPerDoc; i++) {
+ int topic = random.nextInt(topicVector.size());
+ topicVector.set(topic, topicVector.get(topic) + 1);
+ }
+ for (int docId = 0; docId < numDocs; docId++) {
+ for (int sample : modelSampler.sample(topicVector, numSamples)) {
+ corpus.set(docId, sample, corpus.get(docId, sample) + 1);
+ }
+ }
+ return corpus;
+ }
+
+ public static Matrix randomStructuredModel(int numTopics, int numTerms) {
+ return randomStructuredModel(numTopics, numTerms, new DoubleFunction() {
+ @Override public double apply(double d) {
+ return 1.0 / (1 + Math.abs(d));
+ }
+ });
+ }
+
+ public static Matrix randomStructuredModel(int numTopics, int numTerms, DoubleFunction decay) {
+ Matrix model = new DenseMatrix(numTopics, numTerms);
+ int width = numTerms / numTopics;
+ for (int topic = 0; topic < numTopics; topic++) {
+ int topicCentroid = width * (1+topic);
+ for (int i = 0; i < numTerms; i++) {
+ int distance = Math.abs(topicCentroid - i);
+ if (distance > numTerms / 2) {
+ distance = numTerms - distance;
+ }
+ double v = decay.apply(distance);
+ model.set(topic, i, v);
+ }
+ }
+ return model;
+ }
+
+ /**
+ * Takes in a {@link Matrix} of topic distributions (such as generated by {@link org.apache.mahout.clustering.lda.cvb.CVB0Driver} or
+ * {@link org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0}) and constructs
+ * a set of samplers over this distribution, which may be sampled from by providing a distribution
+ * over topics and a desired number of samples.
+ */
+ static class LDASampler {
+ private final Random random;
+ private final Sampler[] samplers;
+
+ LDASampler(Matrix model, Random random) {
+ this.random = random;
+ samplers = new Sampler[model.numRows()];
+ for (int i = 0; i < samplers.length; i++) {
+ samplers[i] = new Sampler(random, model.viewRow(i));
+ }
+ }
+
+ /**
+ *
+ * @param topicDistribution vector of p(topicId) for all topicId < model.numTopics()
+ * @param numSamples the number of times to sample (with replacement) from the model
+ * @return array of length numSamples, with each entry being a sample from the model. There
+ * may be repeats
+ */
+ public int[] sample(Vector topicDistribution, int numSamples) {
+ Preconditions.checkNotNull(topicDistribution);
+ Preconditions.checkArgument(numSamples > 0, "numSamples must be positive");
+ Preconditions.checkArgument(topicDistribution.size() == samplers.length,
+ "topicDistribution must have same cardinality as the sampling model");
+ int[] samples = new int[numSamples];
+ Sampler topicSampler = new Sampler(random, topicDistribution);
+ for (int i = 0; i < numSamples; i++) {
+ samples[i] = samplers[topicSampler.sample()].sample();
+ }
+ return samples;
+ }
+ }
+}
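
ClusteringTestUtils wires together a small synthetic-LDA workflow: randomStructuredModel builds a banded topic/term matrix, sampledCorpus draws documents from it, and writePointsToFile persists vectors as a SequenceFile. A minimal usage sketch of those three static methods follows; it assumes a local file system, and the class name, output path, and sizes are illustrative only.

import java.util.List;
import java.util.Random;

import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.VectorWritable;

public class SyntheticCorpusSketch {
  public static void main(String[] args) throws Exception {
    Random random = RandomUtils.getRandom();

    // 5 topics over a 50-term vocabulary, using the default 1/(1 + |d|) decay
    Matrix model = ClusteringTestUtils.randomStructuredModel(5, 50);

    // 100 documents, 200 term samples per document, 3 topics per document
    Matrix corpus = ClusteringTestUtils.sampledCorpus(model, random, 100, 200, 3);

    // wrap each document row as a VectorWritable and write a SequenceFile
    List<VectorWritable> points = Lists.newArrayList();
    for (int row = 0; row < corpus.numRows(); row++) {
      points.add(new VectorWritable(corpus.viewRow(row)));
    }
    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.getLocal(conf);
    Path path = new Path("/tmp/synthetic-corpus/part-00000");  // illustrative path
    ClusteringTestUtils.writePointsToFile(points, true, path, fs, conf);
  }
}
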
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
new file mode 100644
index 0000000..1cbfb02
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public final class TestClusterInterface extends MahoutTestCase {
+
+ private static final DistanceMeasure measure = new ManhattanDistanceMeasure();
+
+ @Test
+ public void testClusterAsFormatString() {
+ double[] d = { 1.1, 2.2, 3.3 };
+ Vector m = new DenseVector(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[1.1,2.2,3.3]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringSparse() {
+ double[] d = { 1.1, 0.0, 3.3 };
+ Vector m = new SequentialAccessSparseVector(3);
+ m.assign(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringWithBindings() {
+ double[] d = { 1.1, 2.2, 3.3 };
+ Vector m = new DenseVector(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String[] bindings = { "fee", null, "foo" };
+ String formatString = cluster.asFormatString(bindings);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"fee\":1.1},{\"1\":2.2},{\"foo\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringSparseWithBindings() {
+ double[] d = { 1.1, 0.0, 3.3 };
+ Vector m = new SequentialAccessSparseVector(3);
+ m.assign(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+}
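
The assertions above pick apart the JSON produced by Kluster.asFormatString: "r" (radius), "c" (center), "n" (observation count) and "identifier", with dimension-name bindings substituted for indices where supplied. A minimal sketch that simply prints that JSON for the same cluster follows; the class name is made up and the values mirror the tests.

import org.apache.mahout.clustering.kmeans.Kluster;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class ClusterFormatSketch {
  public static void main(String[] args) {
    Vector center = new DenseVector(new double[] {1.1, 2.2, 3.3});
    Kluster cluster = new Kluster(center, 123, new ManhattanDistanceMeasure());

    // without bindings, center entries are keyed by index
    System.out.println(cluster.asFormatString(null));

    // with bindings, named dimensions replace indices ("fee", "foo");
    // a null binding falls back to the index, as asserted above
    System.out.println(cluster.asFormatString(new String[] {"fee", null, "foo"}));
  }
}
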
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
new file mode 100644
index 0000000..43417fc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.util.Collection;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class TestGaussianAccumulators extends MahoutTestCase {
+
+ private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);
+
+ private Collection<VectorWritable> sampleData = Lists.newArrayList();
+ private int sampleN;
+ private Vector sampleMean;
+ private Vector sampleStd;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ sampleData = Lists.newArrayList();
+ generateSamples();
+ sampleN = 0;
+ Vector sum = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ sum.assign(v.get(), Functions.PLUS);
+ sampleN++;
+ }
+ sampleMean = sum.divide(sampleN);
+
+ Vector sampleVar = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ Vector delta = v.get().minus(sampleMean);
+ sampleVar.assign(delta.times(delta), Functions.PLUS);
+ }
+ sampleVar = sampleVar.divide(sampleN - 1);
+ sampleStd = sampleVar.clone();
+ sampleStd.assign(new SquareRootFunction());
+ log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
+ sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1));
+ }
+
+ /**
+ * Generate random samples and add them to the sampleData
+ *
+ * @param num
+ * int number of samples to generate
+ * @param mx
+ * double x-value of the sample mean
+ * @param my
+ * double y-value of the sample mean
+ * @param sdx
+ * double x-value standard deviation of the samples
+ * @param sdy
+ * double y-value standard deviation of the samples
+ */
+ private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
+ log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
+ for (int i = 0; i < num; i++) {
+ sampleData.add(new VectorWritable(new DenseVector(new double[] { UncommonDistributions.rNorm(mx, sdx),
+ UncommonDistributions.rNorm(my, sdy) })));
+ }
+ }
+
+ private void generateSamples() {
+ generate2dSamples(50000, 1, 2, 3, 4);
+ }
+
+ @Test
+ public void testAccumulatorNoSamples() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+ assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+ }
+
+ @Test
+ public void testAccumulatorOneSample() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ Vector sample = new DenseVector(2);
+ accumulator0.observe(sample, 1.0);
+ accumulator1.observe(sample, 1.0);
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+ assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+ }
+
+ @Test
+ public void testOLAccumulatorResults() {
+ GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]",
+ accumulator.getN(),
+ accumulator.getMean().get(0),
+ accumulator.getMean().get(1),
+ accumulator.getStd().get(0),
+ accumulator.getStd().get(1));
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), EPSILON);
+ }
+
+ @Test
+ public void testRSAccumulatorResults() {
+ GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]",
+ (int) accumulator.getN(),
+ accumulator.getMean().get(0),
+ accumulator.getMean().get(1),
+ accumulator.getStd().get(0),
+ accumulator.getStd().get(1));
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), 0.0001);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator0.observe(vw.get(), 0.5);
+ accumulator1.observe(vw.get(), 0.5);
+ }
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+ assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults2() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator0.observe(vw.get(), 1.5);
+ accumulator1.observe(vw.get(), 1.5);
+ }
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+ assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+ }
+}
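
The tests above exercise two GaussianAccumulator implementations against the same sample stream: RunningSumsGaussianAccumulator keeps simple running sums, while OnlineGaussianAccumulator updates its statistics incrementally, and after compute() both should agree on mean and standard deviation to within a small tolerance. Below is a minimal sketch of that observe/compute/compare cycle, assuming only the accumulator API used in the tests; the class name and sample sizes are invented for illustration.

import org.apache.mahout.clustering.GaussianAccumulator;
import org.apache.mahout.clustering.OnlineGaussianAccumulator;
import org.apache.mahout.clustering.RunningSumsGaussianAccumulator;
import org.apache.mahout.clustering.UncommonDistributions;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class AccumulatorSketch {
  public static void main(String[] args) {
    GaussianAccumulator running = new RunningSumsGaussianAccumulator();
    GaussianAccumulator online = new OnlineGaussianAccumulator();

    // feed both accumulators the same 2-d normal samples with unit weight
    for (int i = 0; i < 10000; i++) {
      Vector sample = new DenseVector(new double[] {
          UncommonDistributions.rNorm(1, 3),
          UncommonDistributions.rNorm(2, 4)});
      running.observe(sample, 1.0);
      online.observe(sample, 1.0);
    }

    // finalize the statistics and compare the two implementations
    running.compute();
    online.compute();
    System.out.println("running sums: mean = " + running.getMean() + " std = " + running.getStd());
    System.out.println("online:       mean = " + online.getMean() + " std = " + online.getStd());
  }
}
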