Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/04 14:29:08 UTC
[06/53] [abbrv] [partial] mahout git commit: end of day 6-2-2018
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
new file mode 100644
index 0000000..2373b9d
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.junit.Test;
+
+public final class ModelSerializerTest extends MahoutTestCase {
+
+ private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException {
+ ByteArrayOutputStream buf = new ByteArrayOutputStream(1000);
+ DataOutputStream dos = new DataOutputStream(buf);
+ try {
+ PolymorphicWritable.write(dos, m);
+ } finally {
+ Closeables.close(dos, false);
+ }
+ return PolymorphicWritable.read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), clazz);
+ }
+
+ @Test
+ public void onlineAucRoundtrip() throws IOException {
+ RandomUtils.useTestSeed();
+ OnlineAuc auc1 = new GlobalOnlineAuc();
+ Random gen = RandomUtils.getRandom();
+ for (int i = 0; i < 10000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+ }
+ assertEquals(0.76, auc1.auc(), 0.01);
+
+ OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class);
+
+ assertEquals(auc1.auc(), auc3.auc(), 0);
+
+ for (int i = 0; i < 1000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+
+ auc3.addSample(0, gen.nextGaussian());
+ auc3.addSample(1, gen.nextGaussian() + 1);
+ }
+
+ assertEquals(auc1.auc(), auc3.auc(), 0.01);
+ }
+
+ @Test
+ public void onlineLogisticRegressionRoundTrip() throws IOException {
+ OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
+ train(olr, 100);
+ OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
+ assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+
+ train(olr, 100);
+ train(olr3, 100);
+
+ assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+ olr.close();
+ olr3.close();
+ }
+
+ @Test
+ public void crossFoldLearnerRoundTrip() throws IOException {
+ CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
+ train(learner, 100);
+ CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, learner.auc(), 1.0e-6);
+ assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+ train(learner, 100);
+ train(learner, 100);
+ train(olr3, 100);
+
+ assertEquals(learner.auc(), learner.auc(), 0.02);
+ assertEquals(learner.auc(), olr3.auc(), 0.02);
+ double auc2 = learner.auc();
+ assertTrue(auc2 > auc1);
+ learner.close();
+ olr3.close();
+ }
+
+ @ThreadLeakLingering(linger = 1000)
+ @Test
+ public void adaptiveLogisticRegressionRoundTrip() throws IOException {
+ AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
+ learner.setInterval(200);
+ train(learner, 400);
+ AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, learner.auc(), 1.0e-6);
+ assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+ train(learner, 1000);
+ train(learner, 1000);
+ train(olr3, 1000);
+
+ assertEquals(learner.auc(), learner.auc(), 0.005);
+ assertEquals(learner.auc(), olr3.auc(), 0.005);
+ double auc2 = learner.auc();
+ assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1);
+ learner.close();
+ olr3.close();
+ }
+
+ private static void train(OnlineLearner olr, int n) {
+ Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5});
+ Random gen = RandomUtils.getRandom();
+ for (int i = 0; i < n; i++) {
+ Vector x = randomVector(gen, 5);
+
+ int target = gen.nextDouble() < beta.dot(x) ? 1 : 0;
+ olr.train(target, x);
+ }
+ }
+
+ private static Vector randomVector(final Random gen, int n) {
+ Vector x = new DenseVector(n);
+ x.assign(new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return gen.nextGaussian();
+ }
+ });
+ return x;
+ }
+}
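
The roundTrip() helper above is the whole serialization story: PolymorphicWritable tags the concrete class on write and restores it on read. A minimal sketch of the same pattern against a file instead of an in-memory buffer (the file name is hypothetical; only the PolymorphicWritable.write/read calls used by the test itself are assumed):

    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.FileInputStream;
    import java.io.FileOutputStream;
    import java.io.IOException;

    import org.apache.mahout.classifier.sgd.L1;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
    import org.apache.mahout.classifier.sgd.PolymorphicWritable;

    public class ModelFileRoundTrip {
      public static void main(String[] args) throws IOException {
        // a fresh 2-class model over 5 features with an L1 prior; training elided
        OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());

        // write the model together with its concrete class tag ("model.bin" is a placeholder)
        try (DataOutputStream out = new DataOutputStream(new FileOutputStream("model.bin"))) {
          PolymorphicWritable.write(out, olr);
        }

        // read it back; the stored tag selects the right subclass
        try (DataInputStream in = new DataInputStream(new FileInputStream("model.bin"))) {
          OnlineLogisticRegression restored =
              PolymorphicWritable.read(in, OnlineLogisticRegression.class);
          System.out.println("categories: " + restored.numCategories());
        }
      }
    }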
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
new file mode 100644
index 0000000..e0a252c
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+public abstract class OnlineBaseTest extends MahoutTestCase {
+
+ private Matrix input;
+
+ Matrix getInput() {
+ return input;
+ }
+
+ Vector readStandardData() throws IOException {
+ // 60 test samples. First column is constant. Second and third are normally distributed from
+ // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a
+ // target variable of 0, the last 30 a target of 1. The remaining columns are random noise.
+ input = readCsv("sgd.csv");
+
+ // regenerate the target variable
+ Vector target = new DenseVector(60);
+ target.assign(0);
+ target.viewPart(30, 30).assign(1);
+ return target;
+ }
+
+ static void train(Matrix input, Vector target, OnlineLearner lr) {
+ RandomUtils.useTestSeed();
+ Random gen = RandomUtils.getRandom();
+
+ // train on samples in random order (but only one pass)
+ for (int row : permute(gen, 60)) {
+ lr.train((int) target.get(row), input.viewRow(row));
+ }
+ lr.close();
+ }
+
+ static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
+ double expected_mean_error, double expected_absolute_error) {
+ // now test the accuracy
+ Matrix tmp = lr.classify(input);
+ // mean(abs(tmp - target))
+ double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;
+
+ // max(abs(tmp - target))
+ double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
+
+ System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+ assertEquals(0, meanAbsoluteError, expected_mean_error);
+ assertEquals(0, maxAbsoluteError, expected_absolute_error);
+
+ // convenience methods should give the same results
+ Vector v = lr.classifyScalar(input);
+ assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5);
+ v = lr.classifyFull(input).viewColumn(1);
+ assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4);
+ }
+
+ /**
+ * Permute the integers from 0 ... max-1
+ *
+ * @param gen The random number generator to use.
+ * @param max The number of integers to permute
+ * @return An array of jumbled integer values
+ */
+ static int[] permute(Random gen, int max) {
+ int[] permutation = new int[max];
+ permutation[0] = 0;
+ for (int i = 1; i < max; i++) {
+ int n = gen.nextInt(i + 1);
+ if (n == i) {
+ permutation[i] = i;
+ } else {
+ permutation[i] = permutation[n];
+ permutation[n] = i;
+ }
+ }
+ return permutation;
+ }
+
+
+ /**
+ * Reads a file containing CSV data. This isn't implemented quite the way you might like for a
+ * real program, but does the job for reading test data. Most notably, it will only read numbers,
+ * not quoted strings.
+ *
+ * @param resourceName Where to get the data.
+ * @return A matrix of the results.
+ * @throws IOException If there is an error reading the data
+ */
+ static Matrix readCsv(String resourceName) throws IOException {
+ Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));
+
+ Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
+ List<String> data = CharStreams.readLines(isr);
+ String first = data.get(0);
+ data = data.subList(1, data.size());
+
+ List<String> values = Lists.newArrayList(onCommas.split(first));
+ Matrix r = new DenseMatrix(data.size(), values.size());
+
+ int column = 0;
+ Map<String, Integer> labels = Maps.newHashMap();
+ for (String value : values) {
+ labels.put(value, column);
+ column++;
+ }
+ r.setColumnLabelBindings(labels);
+
+ int row = 0;
+ for (String line : data) {
+ column = 0;
+ values = Lists.newArrayList(onCommas.split(line));
+ for (String value : values) {
+ r.set(row, column, Double.parseDouble(value));
+ column++;
+ }
+ row++;
+ }
+
+ return r;
+ }
+}
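
The permute() helper above is the "inside-out" variant of the Fisher-Yates shuffle: each new index either stays in place or displaces a uniformly chosen earlier entry, so a uniform random permutation emerges in a single pass with no separate initialization step. For comparison, a sketch of the equivalent two-pass construction using only the standard library:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;
    import java.util.Random;

    public class PermuteSketch {
      // same contract as OnlineBaseTest.permute(gen, max)
      static int[] permute(Random gen, int max) {
        List<Integer> order = new ArrayList<>(max);
        for (int i = 0; i < max; i++) {
          order.add(i);
        }
        // Collections.shuffle is the classic form: fill first, then Fisher-Yates swaps
        Collections.shuffle(order, gen);
        int[] permutation = new int[max];
        for (int i = 0; i < max; i++) {
          permutation[i] = order.get(i);
        }
        return permutation;
      }
    }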
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
new file mode 100644
index 0000000..44b7525
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Assert;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+
+public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
+
+ private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
+
+ /**
+ * The CrossFoldLearner is probably the best learner to use for new applications.
+ *
+ * @throws IOException If test resources aren't readable.
+ */
+ @Test
+ public void crossValidation() throws IOException {
+ Vector target = readStandardData();
+
+ CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
+ .lambda(1.0e-3)
+ .learningRate(50);
+
+
+ train(getInput(), target, lr);
+
+ System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
+ test(getInput(), target, lr, 0.05, 0.3);
+
+ }
+
+ @Test
+ public void crossValidatedAuc() throws IOException {
+ RandomUtils.useTestSeed();
+ Random gen = RandomUtils.getRandom();
+
+ Matrix data = readCsv("cancer.csv");
+ CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
+ .stepOffset(10)
+ .decayExponent(0.7)
+ .lambda(1.0e-3)
+ .learningRate(5);
+ int k = 0;
+ int[] ordering = permute(gen, data.numRows());
+ for (int epoch = 0; epoch < 100; epoch++) {
+ for (int row : ordering) {
+ lr.train(row, (int) data.get(row, 9), data.viewRow(row));
+ System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
+ }
+ assertEquals(1, lr.auc(), 0.2);
+ }
+ assertEquals(1, lr.auc(), 0.1);
+ }
+
+ /**
+ * Verifies that a classifier with known coefficients does the right thing.
+ */
+ @Test
+ public void testClassify() {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
+ // set up some internal coefficients as if we had learned them
+ lr.setBeta(0, 0, -1);
+ lr.setBeta(1, 0, -2);
+
+ // zero vector gives no information. All classes are equal.
+ Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+
+ v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-8);
+
+ // weights for second vector component are still zero so all classifications are equally likely
+ v = lr.classify(new DenseVector(new double[]{0, 1}));
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+
+ v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-3);
+
+ // but the weights on the first component are non-zero
+ v = lr.classify(new DenseVector(new double[]{1, 0}));
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);
+
+ lr.setBeta(0, 1, 1);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);
+
+ lr.setBeta(1, 1, 3);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
+ }
+
+ @Test
+ public void iris() throws IOException {
+ // this test trains a 3-way classifier on the famous Iris dataset.
+ // a similar exercise can be accomplished in R using this code:
+ // library(nnet)
+ // correct = rep(0,100)
+ // for (j in 1:100) {
+ // i = order(runif(150))
+ // train = iris[i[1:100],]
+ // test = iris[i[101:150],]
+ // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
+ // correct[j] = mean(predict(m, newdata=test) == test$Species)
+ // }
+ // hist(correct)
+ //
+ // Note that depending on the training/test split, performance can be better or worse.
+ // There is about a 5% chance of getting accuracy < 90% and about a 20% chance of getting accuracy
+ // of 100%.
+ //
+ // This test uses a deterministic split that is neither outstandingly good nor bad
+
+
+ RandomUtils.useTestSeed();
+ Splitter onComma = Splitter.on(",");
+
+ // read the data
+ List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);
+
+ // holds features
+ List<Vector> data = Lists.newArrayList();
+
+ // holds target variable
+ List<Integer> target = Lists.newArrayList();
+
+ // for decoding target values
+ Dictionary dict = new Dictionary();
+
+ // for permuting data later
+ List<Integer> order = Lists.newArrayList();
+
+ for (String line : raw.subList(1, raw.size())) {
+ // order gets a list of indexes
+ order.add(order.size());
+
+ // parse the predictor variables
+ Vector v = new DenseVector(5);
+ v.set(0, 1);
+ int i = 1;
+ Iterable<String> values = onComma.split(line);
+ for (String value : Iterables.limit(values, 4)) {
+ v.set(i++, Double.parseDouble(value));
+ }
+ data.add(v);
+
+ // and the target
+ target.add(dict.intern(Iterables.get(values, 4)));
+ }
+
+ // randomize the order ... original data has each species all together
+ // note that this randomization is deterministic
+ Random random = RandomUtils.getRandom();
+ Collections.shuffle(order, random);
+
+ // select training and test data
+ List<Integer> train = order.subList(0, 100);
+ List<Integer> test = order.subList(100, 150);
+ logger.warn("Training set = {}", train);
+ logger.warn("Test set = {}", test);
+
+ // now train many times and collect information on accuracy each time
+ int[] correct = new int[test.size() + 1];
+ for (int run = 0; run < 200; run++) {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
+ // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
+ for (int pass = 0; pass < 30; pass++) {
+ Collections.shuffle(train, random);
+ for (int k : train) {
+ lr.train(target.get(k), data.get(k));
+ }
+ }
+
+ // check the accuracy on held out data
+ int x = 0;
+ int[] count = new int[3];
+ for (Integer k : test) {
+ int r = lr.classifyFull(data.get(k)).maxValueIndex();
+ count[r]++;
+ x += r == target.get(k) ? 1 : 0;
+ }
+ correct[x]++;
+ }
+
+ // verify we never saw worse than 95% correct,
+ for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
+ assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]);
+ }
+ // nor perfect
+ assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size()]);
+ }
+
+ @Test
+ public void testTrain() throws Exception {
+ Vector target = readStandardData();
+
+
+ // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
+ // larger than usual because the data are dense. The learning rate doesn't matter too much
+ // for this example, but should generally be < 1
+ // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
+ // --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+ .lambda(1.0e-3)
+ .learningRate(50);
+
+ train(getInput(), target, lr);
+ test(getInput(), target, lr, 0.05, 0.3);
+ }
+
+ /**
+ * Test for Serialization/DeSerialization
+ *
+ */
+ @Test
+ public void testSerializationAndDeSerialization() throws Exception {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+ .lambda(1.0e-3)
+ .stepOffset(11)
+ .alpha(0.01)
+ .learningRate(50)
+ .decayExponent(-0.02);
+
+ lr.close();
+
+ byte[] output;
+
+ try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+ DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) {
+ PolymorphicWritable.write(dataOutputStream, lr);
+ output = byteArrayOutputStream.toByteArray();
+ }
+
+ OnlineLogisticRegression read;
+
+ try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
+ DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) {
+ read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
+ }
+
+ //lambda
+ Assert.assertEquals(1.0e-3, read.getLambda(), 1.0e-7);
+
+ // Reflection to get private variables
+ //stepOffset
+ Field stepOffset = lr.getClass().getDeclaredField("stepOffset");
+ stepOffset.setAccessible(true);
+ int stepOffsetVal = (Integer) stepOffset.get(lr);
+ Assert.assertEquals(11, stepOffsetVal);
+
+ //decayFactor (alpha)
+ Field decayFactor = lr.getClass().getDeclaredField("decayFactor");
+ decayFactor.setAccessible(true);
+ double decayFactorVal = (Double) decayFactor.get(lr);
+ Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);
+
+ //learning rate (mu0)
+ Field mu0 = lr.getClass().getDeclaredField("mu0");
+ mu0.setAccessible(true);
+ double mu0Val = (Double) mu0.get(lr);
+ Assert.assertEquals(50, mu0Val, 1.0e-7);
+
+ //forgettingExponent (decayExponent)
+ Field forgettingExponent = lr.getClass().getDeclaredField("forgettingExponent");
+ forgettingExponent.setAccessible(true);
+ double forgettingExponentVal = (Double) forgettingExponent.get(lr);
+ Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
+ }
+}
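
The serialization test above doubles as a catalog of OnlineLogisticRegression's fluent hyper-parameter API. A minimal end-to-end sketch on toy data (every method shown appears in the tests above; the values are illustrative, not recommendations):

    import org.apache.mahout.classifier.sgd.L1;
    import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class OlrSketch {
      public static void main(String[] args) {
        OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 3, new L1())
            .lambda(1.0e-3)        // regularization strength
            .stepOffset(11)        // delays the learning-rate decay schedule
            .alpha(0.01)           // per-step decay factor
            .learningRate(50)      // initial rate (mu0 internally)
            .decayExponent(-0.02); // forgetting exponent

        // one toy example per class; real loops shuffle many rows over several passes
        lr.train(0, new DenseVector(new double[] {1, 0, 0}));
        lr.train(1, new DenseVector(new double[] {1, 1, 1}));

        Vector p = lr.classifyFull(new DenseVector(new double[] {1, 1, 0}));
        System.out.printf("p(y=0)=%.3f, p(y=1)=%.3f%n", p.get(0), p.get(1));
        lr.close();
      }
    }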
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
new file mode 100644
index 0000000..df97d38
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public final class PassiveAggressiveTest extends OnlineBaseTest {
+
+ @Test
+ public void testPassiveAggressive() throws IOException {
+ Vector target = readStandardData();
+ PassiveAggressive pa = new PassiveAggressive(2, 8).learningRate(0.1);
+ train(getInput(), target, pa);
+ test(getInput(), target, pa, 0.11, 0.31);
+ }
+
+}
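
PassiveAggressive shares the OnlineLearner/AbstractVectorClassifier contract with the logistic learners, which is why the shared train()/test() harness above accepts it unchanged. A sketch of the same drop-in use outside the harness (toy data, illustrative values):

    import org.apache.mahout.classifier.sgd.PassiveAggressive;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class PaSketch {
      public static void main(String[] args) {
        // 2 classes over 8 features, as in the test; learningRate bounds the update step
        PassiveAggressive pa = new PassiveAggressive(2, 8).learningRate(0.1);
        Vector x = new DenseVector(new double[] {1, 0, 1, 0, 1, 0, 1, 0});
        pa.train(1, x);
        System.out.println("predicted class: " + pa.classifyFull(x).maxValueIndex());
      }
    }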
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
new file mode 100644
index 0000000..62e10c6
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
@@ -0,0 +1,152 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.IOException;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.stats.Sampler;
+
+public final class ClusteringTestUtils {
+
+ private ClusteringTestUtils() {
+ }
+
+ public static void writePointsToFile(Iterable<VectorWritable> points,
+ Path path,
+ FileSystem fs,
+ Configuration conf) throws IOException {
+ writePointsToFile(points, false, path, fs, conf);
+ }
+
+ public static void writePointsToFile(Iterable<VectorWritable> points,
+ boolean intWritable,
+ Path path,
+ FileSystem fs,
+ Configuration conf) throws IOException {
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs,
+ conf,
+ path,
+ intWritable ? IntWritable.class : LongWritable.class,
+ VectorWritable.class);
+ try {
+ int recNum = 0;
+ for (VectorWritable point : points) {
+ writer.append(intWritable ? new IntWritable(recNum++) : new LongWritable(recNum++), point);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ public static Matrix sampledCorpus(Matrix matrix, Random random,
+ int numDocs, int numSamples, int numTopicsPerDoc) {
+ Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols());
+ LDASampler modelSampler = new LDASampler(matrix, random);
+ Vector topicVector = new DenseVector(matrix.numRows());
+ for (int i = 0; i < numTopicsPerDoc; i++) {
+ int topic = random.nextInt(topicVector.size());
+ topicVector.set(topic, topicVector.get(topic) + 1);
+ }
+ for (int docId = 0; docId < numDocs; docId++) {
+ for (int sample : modelSampler.sample(topicVector, numSamples)) {
+ corpus.set(docId, sample, corpus.get(docId, sample) + 1);
+ }
+ }
+ return corpus;
+ }
+
+ public static Matrix randomStructuredModel(int numTopics, int numTerms) {
+ return randomStructuredModel(numTopics, numTerms, new DoubleFunction() {
+ @Override public double apply(double d) {
+ return 1.0 / (1 + Math.abs(d));
+ }
+ });
+ }
+
+ public static Matrix randomStructuredModel(int numTopics, int numTerms, DoubleFunction decay) {
+ Matrix model = new DenseMatrix(numTopics, numTerms);
+ int width = numTerms / numTopics;
+ for (int topic = 0; topic < numTopics; topic++) {
+ int topicCentroid = width * (1+topic);
+ for (int i = 0; i < numTerms; i++) {
+ int distance = Math.abs(topicCentroid - i);
+ if (distance > numTerms / 2) {
+ distance = numTerms - distance;
+ }
+ double v = decay.apply(distance);
+ model.set(topic, i, v);
+ }
+ }
+ return model;
+ }
+
+ /**
+ * Takes in a {@link Matrix} of topic distributions (such as generated by {@link org.apache.mahout.clustering.lda.cvb.CVB0Driver} or
+ * {@link org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0}), and constructs
+ * a set of samplers over this distribution, which may be sampled from by providing a distribution
+ * over topics and a desired number of samples
+ */
+ static class LDASampler {
+ private final Random random;
+ private final Sampler[] samplers;
+
+ LDASampler(Matrix model, Random random) {
+ this.random = random;
+ samplers = new Sampler[model.numRows()];
+ for (int i = 0; i < samplers.length; i++) {
+ samplers[i] = new Sampler(random, model.viewRow(i));
+ }
+ }
+
+ /**
+ *
+ * @param topicDistribution vector of p(topicId) for all topicId < model.numTopics()
+ * @param numSamples the number of times to sample (with replacement) from the model
+ * @return array of length numSamples, with each entry being a sample from the model. There
+ * may be repeats.
+ */
+ public int[] sample(Vector topicDistribution, int numSamples) {
+ Preconditions.checkNotNull(topicDistribution);
+ Preconditions.checkArgument(numSamples > 0, "numSamples must be positive");
+ Preconditions.checkArgument(topicDistribution.size() == samplers.length,
+ "topicDistribution must have same cardinality as the sampling model");
+ int[] samples = new int[numSamples];
+ Sampler topicSampler = new Sampler(random, topicDistribution);
+ for (int i = 0; i < numSamples; i++) {
+ samples[i] = samplers[topicSampler.sample()].sample();
+ }
+ return samples;
+ }
+ }
+}
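
Taken together, the helpers above form a small synthetic-LDA pipeline: randomStructuredModel() builds topics peaked around evenly spaced centroids, and sampledCorpus() draws documents from mixtures of them. A usage sketch (parameter values are illustrative, not prescribed anywhere in this file):

    import java.util.Random;

    import org.apache.mahout.clustering.ClusteringTestUtils;
    import org.apache.mahout.common.RandomUtils;
    import org.apache.mahout.math.Matrix;

    public class CorpusSketch {
      public static void main(String[] args) {
        Random random = RandomUtils.getRandom();
        // 10 topics over a 100-term vocabulary, each peaked around its own centroid
        Matrix model = ClusteringTestUtils.randomStructuredModel(10, 100);
        // 500 documents, 50 term draws each, mixing 3 topics per document
        Matrix corpus = ClusteringTestUtils.sampledCorpus(model, random, 500, 50, 3);
        System.out.println("corpus: " + corpus.numRows() + " x " + corpus.numCols());
      }
    }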
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
new file mode 100644
index 0000000..1cbfb02
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public final class TestClusterInterface extends MahoutTestCase {
+
+ private static final DistanceMeasure measure = new ManhattanDistanceMeasure();
+
+ @Test
+ public void testClusterAsFormatString() {
+ double[] d = { 1.1, 2.2, 3.3 };
+ Vector m = new DenseVector(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[1.1,2.2,3.3]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringSparse() {
+ double[] d = { 1.1, 0.0, 3.3 };
+ Vector m = new SequentialAccessSparseVector(3);
+ m.assign(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringWithBindings() {
+ double[] d = { 1.1, 2.2, 3.3 };
+ Vector m = new DenseVector(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String[] bindings = { "fee", null, "foo" };
+ String formatString = cluster.asFormatString(bindings);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"fee\":1.1},{\"1\":2.2},{\"foo\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringSparseWithBindings() {
+ double[] d = { 1.1, 0.0, 3.3 };
+ Vector m = new SequentialAccessSparseVector(3);
+ m.assign(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
new file mode 100644
index 0000000..43417fc
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.util.Collection;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class TestGaussianAccumulators extends MahoutTestCase {
+
+ private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);
+
+ private Collection<VectorWritable> sampleData = Lists.newArrayList();
+ private int sampleN;
+ private Vector sampleMean;
+ private Vector sampleStd;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ sampleData = Lists.newArrayList();
+ generateSamples();
+ sampleN = 0;
+ Vector sum = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ sum.assign(v.get(), Functions.PLUS);
+ sampleN++;
+ }
+ sampleMean = sum.divide(sampleN);
+
+ Vector sampleVar = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ Vector delta = v.get().minus(sampleMean);
+ sampleVar.assign(delta.times(delta), Functions.PLUS);
+ }
+ sampleVar = sampleVar.divide(sampleN - 1);
+ sampleStd = sampleVar.clone();
+ sampleStd.assign(new SquareRootFunction());
+ log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
+ sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1));
+ }
+
+ /**
+ * Generate random samples and add them to the sampleData
+ *
+ * @param num
+ * int number of samples to generate
+ * @param mx
+ * double x-value of the sample mean
+ * @param my
+ * double y-value of the sample mean
+ * @param sdx
+ * double x-value standard deviation of the samples
+ * @param sdy
+ * double y-value standard deviation of the samples
+ */
+ private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
+ log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
+ for (int i = 0; i < num; i++) {
+ sampleData.add(new VectorWritable(new DenseVector(new double[] { UncommonDistributions.rNorm(mx, sdx),
+ UncommonDistributions.rNorm(my, sdy) })));
+ }
+ }
+
+ private void generateSamples() {
+ generate2dSamples(50000, 1, 2, 3, 4);
+ }
+
+ @Test
+ public void testAccumulatorNoSamples() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+ assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+ }
+
+ @Test
+ public void testAccumulatorOneSample() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ Vector sample = new DenseVector(2);
+ accumulator0.observe(sample, 1.0);
+ accumulator1.observe(sample, 1.0);
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+ assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+ }
+
+ @Test
+ public void testOLAccumulatorResults() {
+ GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]",
+ accumulator.getN(),
+ accumulator.getMean().get(0),
+ accumulator.getMean().get(1),
+ accumulator.getStd().get(0),
+ accumulator.getStd().get(1));
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), EPSILON);
+ }
+
+ @Test
+ public void testRSAccumulatorResults() {
+ GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]",
+ (int) accumulator.getN(),
+ accumulator.getMean().get(0),
+ accumulator.getMean().get(1),
+ accumulator.getStd().get(0),
+ accumulator.getStd().get(1));
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), 0.0001);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator0.observe(vw.get(), 0.5);
+ accumulator1.observe(vw.get(), 0.5);
+ }
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+ assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults2() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator0.observe(vw.get(), 1.5);
+ accumulator1.observe(vw.get(), 1.5);
+ }
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+ assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+ }
+}
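
The setUp() above computes the Bessel-corrected sample statistics that both accumulators must reproduce. For intuition about what an online accumulator has to get right numerically, here is the standard single-pass technique (Welford's update), as a generic sketch of the technique rather than Mahout's actual OnlineGaussianAccumulator internals:

    // Welford's single-pass mean/variance update: numerically stable because it
    // accumulates squared deviations from the *running* mean, never raw sums of squares.
    public class WelfordSketch {
      private long n;
      private double mean;
      private double m2; // sum of squared deviations from the running mean

      void observe(double x) {
        n++;
        double delta = x - mean;
        mean += delta / n;
        m2 += delta * (x - mean); // second factor uses the updated mean
      }

      double mean() { return mean; }

      double sampleVariance() {
        return n > 1 ? m2 / (n - 1) : 0; // Bessel-corrected, matching setUp() above
      }

      public static void main(String[] args) {
        WelfordSketch w = new WelfordSketch();
        for (double x : new double[] {1, 2, 3, 4}) {
          w.observe(x);
        }
        System.out.printf("mean=%.2f var=%.2f%n", w.mean(), w.sampleVariance()); // 2.50, 1.67
      }
    }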
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
new file mode 100644
index 0000000..097fd74
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
@@ -0,0 +1,674 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.collect.Iterables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.DummyRecordWriter;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+@Deprecated
+public final class TestCanopyCreation extends MahoutTestCase {
+
+ private static final double[][] RAW = { { 1, 1 }, { 2, 1 }, { 1, 2 },
+ { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
+
+ private List<Canopy> referenceManhattan;
+
+ private final DistanceMeasure manhattanDistanceMeasure = new ManhattanDistanceMeasure();
+
+ private List<Vector> manhattanCentroids;
+
+ private List<Canopy> referenceEuclidean;
+
+ private final DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure();
+
+ private List<Vector> euclideanCentroids;
+
+ private FileSystem fs;
+
+ private static List<VectorWritable> getPointsWritable() {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : RAW) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ private static List<Vector> getPoints() {
+ List<Vector> points = Lists.newArrayList();
+ for (double[] fr : RAW) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(vec);
+ }
+ return points;
+ }
+
+ /**
+ * Print the canopies to the transcript
+ *
+ * @param canopies
+ * a List<Canopy>
+ */
+ private static void printCanopies(Iterable<Canopy> canopies) {
+ for (Canopy canopy : canopies) {
+ System.out.println(canopy.asFormatString(null));
+ }
+ }
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ fs = FileSystem.get(getConfiguration());
+ referenceManhattan = CanopyClusterer.createCanopies(getPoints(),
+ manhattanDistanceMeasure, 3.1, 2.1);
+ manhattanCentroids = CanopyClusterer.getCenters(referenceManhattan);
+ referenceEuclidean = CanopyClusterer.createCanopies(getPoints(),
+ euclideanDistanceMeasure, 3.1, 2.1);
+ euclideanCentroids = CanopyClusterer.getCenters(referenceEuclidean);
+ }
+
+ /**
+ * Story: User can cluster points using a ManhattanDistanceMeasure and a
+ * reference implementation
+ */
+ @Test
+ public void testReferenceManhattan() throws Exception {
+ // see setUp for cluster creation
+ printCanopies(referenceManhattan);
+ assertEquals("number of canopies", 3, referenceManhattan.size());
+ for (int canopyIx = 0; canopyIx < referenceManhattan.size(); canopyIx++) {
+ Canopy testCanopy = referenceManhattan.get(canopyIx);
+ int[] expectedNumPoints = { 4, 4, 3 };
+ double[][] expectedCentroids = { { 1.5, 1.5 }, { 4.0, 4.0 },
+ { 4.666666666666667, 4.666666666666667 } };
+ assertEquals("canopy points " + canopyIx, testCanopy.getNumObservations(),
+ expectedNumPoints[canopyIx]);
+ double[] refCentroid = expectedCentroids[canopyIx];
+ Vector testCentroid = testCanopy.computeCentroid();
+ for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) {
+ assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']',
+ refCentroid[pointIx], testCentroid.get(pointIx), EPSILON);
+ }
+ }
+ }
+
+ /**
+ * Story: User can cluster points using an EuclideanDistanceMeasure and a
+ * reference implementation
+ */
+ @Test
+ public void testReferenceEuclidean() throws Exception {
+ // see setUp for cluster creation
+ printCanopies(referenceEuclidean);
+ assertEquals("number of canopies", 3, referenceEuclidean.size());
+ int[] expectedNumPoints = { 5, 5, 3 };
+ double[][] expectedCentroids = { { 1.8, 1.8 }, { 4.2, 4.2 },
+ { 4.666666666666667, 4.666666666666667 } };
+ for (int canopyIx = 0; canopyIx < referenceEuclidean.size(); canopyIx++) {
+ Canopy testCanopy = referenceEuclidean.get(canopyIx);
+ assertEquals("canopy points " + canopyIx, testCanopy.getNumObservations(),
+ expectedNumPoints[canopyIx]);
+ double[] refCentroid = expectedCentroids[canopyIx];
+ Vector testCentroid = testCanopy.computeCentroid();
+ for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) {
+ assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']',
+ refCentroid[pointIx], testCentroid.get(pointIx), EPSILON);
+ }
+ }
+ }
+
+ /**
+ * Story: User can produce initial canopy centers using a
+ * ManhattanDistanceMeasure and a CanopyMapper which clusters input points to
+ * produce an output set of canopy centroid points.
+ */
+ @Test
+ public void testCanopyMapperManhattan() throws Exception {
+ CanopyMapper mapper = new CanopyMapper();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, manhattanDistanceMeasure
+ .getClass().getName());
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<>();
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter
+ .build(mapper, conf, writer);
+ mapper.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ // map the data
+ for (VectorWritable point : points) {
+ mapper.map(new Text(), point, context);
+ }
+ mapper.cleanup(context);
+ assertEquals("Number of map results", 1, writer.getData().size());
+ // now verify the output
+ List<VectorWritable> data = writer.getValue(new Text("centroid"));
+ assertEquals("Number of centroids", 3, data.size());
+ for (int i = 0; i < data.size(); i++) {
+ assertEquals("Centroid error",
+ manhattanCentroids.get(i).asFormatString(), data.get(i).get()
+ .asFormatString());
+ }
+ }
+
+ /**
+ * Story: User can produce initial canopy centers using an
+ * EuclideanDistanceMeasure and a CanopyMapper/Combiner which clusters input
+ * points to produce an output set of canopy centroid points.
+ */
+ @Test
+ public void testCanopyMapperEuclidean() throws Exception {
+ CanopyMapper mapper = new CanopyMapper();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, euclideanDistanceMeasure
+ .getClass().getName());
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<>();
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter
+ .build(mapper, conf, writer);
+ mapper.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ // map the data
+ for (VectorWritable point : points) {
+ mapper.map(new Text(), point, context);
+ }
+ mapper.cleanup(context);
+ assertEquals("Number of map results", 1, writer.getData().size());
+ // now verify the output
+ List<VectorWritable> data = writer.getValue(new Text("centroid"));
+ assertEquals("Number of centroids", 3, data.size());
+ for (int i = 0; i < data.size(); i++) {
+ assertEquals("Centroid error",
+ euclideanCentroids.get(i).asFormatString(), data.get(i).get()
+ .asFormatString());
+ }
+ }
+
+ /**
+ * Story: User can produce final canopy centers using a
+ * ManhattanDistanceMeasure and a CanopyReducer which clusters input centroid
+ * points to produce an output set of final canopy centroid points.
+ */
+ @Test
+ public void testCanopyReducerManhattan() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY,
+ "org.apache.mahout.common.distance.ManhattanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter
+ .build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ reducer.reduce(new Text("centroid"), points, context);
+ Iterable<Text> keys = writer.getKeysInInsertionOrder();
+ assertEquals("Number of centroids", 3, Iterables.size(keys));
+ int i = 0;
+ for (Text key : keys) {
+ List<ClusterWritable> data = writer.getValue(key);
+ ClusterWritable clusterWritable = data.get(0);
+ Canopy canopy = (Canopy) clusterWritable.getValue();
+ assertEquals(manhattanCentroids.get(i).asFormatString() + " is not equal to "
+ + canopy.computeCentroid().asFormatString(),
+ manhattanCentroids.get(i), canopy.computeCentroid());
+ i++;
+ }
+ }
+
+ /**
+ * Story: User can produce final canopy centers using an
+ * EuclideanDistanceMeasure and a CanopyReducer which clusters input centroid
+ * points to produce an output set of final canopy centroid points.
+ */
+ @Test
+ public void testCanopyReducerEuclidean() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context =
+ DummyRecordWriter.build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ reducer.reduce(new Text("centroid"), points, context);
+ Iterable<Text> keys = writer.getKeysInInsertionOrder();
+ assertEquals("Number of centroids", 3, Iterables.size(keys));
+ int i = 0;
+ for (Text key : keys) {
+ List<ClusterWritable> data = writer.getValue(key);
+ ClusterWritable clusterWritable = data.get(0);
+ Canopy canopy = (Canopy) clusterWritable.getValue();
+ assertEquals(euclideanCentroids.get(i).asFormatString() + " is not equal to "
+ + canopy.computeCentroid().asFormatString(),
+ euclideanCentroids.get(i), canopy.computeCentroid());
+ i++;
+ }
+ }
+
+ /**
+ * Story: User can produce final canopy centers using a Hadoop map/reduce job
+ * and a ManhattanDistanceMeasure.
+ */
+ @Test
+ public void testCanopyGenManhattanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file2"), fs, config);
+ // now run the Canopy Driver
+ Path output = getTestTempDirPath("output");
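+ // CanopyDriver.run(conf, input, output, measure, t1, t2, runClustering,
+ // outlierThreshold, runSequential) - here t1 = 3.1, t2 = 2.1, no clustering
+ // pass, no outlier removal, map/reduce execution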
+ CanopyDriver.run(config, getTestTempDirPath("testdata"), output,
+ manhattanDistanceMeasure, 3.1, 2.1, false, 0.0, false);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+ FileSystem fs = FileSystem.get(path.toUri(), config);
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
+ try {
+ Writable key = new Text();
+ ClusterWritable clusterWritable = new ClusterWritable();
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("1st key", "C-0", key.toString());
+
+ List<Pair<Double,Double>> refCenters = Lists.newArrayList();
+ refCenters.add(new Pair<>(1.5,1.5));
+ refCenters.add(new Pair<>(4.333333333333334,4.333333333333334));
+ Pair<Double,Double> c = new Pair<>(clusterWritable.getValue().getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center " + c + " not found", findAndRemove(c, refCenters, EPSILON));
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("2nd key", "C-1", key.toString());
+ c = new Pair<>(clusterWritable.getValue().getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center " + c + " not found", findAndRemove(c, refCenters, EPSILON));
+ assertFalse("more to come", reader.next(key, clusterWritable));
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+
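+ /**
+ * Returns true and removes the first pair in {@code list} whose coordinates
+ * both match {@code target} to within {@code epsilon}; returns false when no
+ * pair matches.
+ */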
+ static boolean findAndRemove(Pair<Double, Double> target, Collection<Pair<Double, Double>> list, double epsilon) {
+ for (Pair<Double, Double> curr : list) {
+ if (Math.abs(target.getFirst() - curr.getFirst()) < epsilon
+ && Math.abs(target.getSecond() - curr.getSecond()) < epsilon) {
+ list.remove(curr);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Story: User can produce final canopy centers using a Hadoop map/reduce job
+ * and an EuclideanDistanceMeasure.
+ */
+ @Test
+ public void testCanopyGenEuclideanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file2"), fs, config);
+ // now run the Canopy Driver
+ Path output = getTestTempDirPath("output");
+ CanopyDriver.run(config, getTestTempDirPath("testdata"), output,
+ euclideanDistanceMeasure, 3.1, 2.1, false, 0.0, false);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+ FileSystem fs = FileSystem.get(path.toUri(), config);
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
+ try {
+ Writable key = new Text();
+ ClusterWritable clusterWritable = new ClusterWritable();
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("1st key", "C-0", key.toString());
+
+ List<Pair<Double,Double>> refCenters = Lists.newArrayList();
+ refCenters.add(new Pair<>(1.8,1.8));
+ refCenters.add(new Pair<>(4.433333333333334, 4.433333333333334));
+ Pair<Double,Double> c = new Pair<>(clusterWritable.getValue().getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON));
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("2nd key", "C-1", key.toString());
+ c = new Pair<>(clusterWritable.getValue().getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON));
+ assertFalse("more to come", reader.next(key, clusterWritable));
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+
+ /** Story: User can cluster points using sequential execution */
+ @Test
+ public void testClusteringManhattanSeq() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ // now run the Canopy Driver in sequential mode
+ Path output = getTestTempDirPath("output");
+ CanopyDriver.run(config, getTestTempDirPath("testdata"), output,
+ manhattanDistanceMeasure, 3.1, 2.1, true, 0.0, true);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+ int ix = 0;
+ for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true,
+ config)) {
+ assertEquals("Center [" + ix + ']', manhattanCentroids.get(ix), clusterWritable.getValue()
+ .getCenter());
+ ix++;
+ }
+
+ path = new Path(output, "clusteredPoints/part-m-0");
+ long count = HadoopUtil.countRecords(path, config);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /** Story: User can cluster points using sequential execution */
+ @Test
+ public void testClusteringEuclideanSeq() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ // now run the Canopy Driver in sequential mode
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION),
+ optKey(DefaultOptionCreator.METHOD_OPTION),
+ DefaultOptionCreator.SEQUENTIAL_METHOD };
+ ToolRunner.run(config, new CanopyDriver(), args);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+
+ int ix = 0;
+ for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true,
+ config)) {
+ assertEquals("Center [" + ix + ']', euclideanCentroids.get(ix), clusterWritable.getValue()
+ .getCenter());
+ ix++;
+ }
+
+ path = new Path(output, "clusteredPoints/part-m-0");
+ long count = HadoopUtil.countRecords(path, config);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /** Story: User can remove outliers while clustering points using sequential execution */
+ @Test
+ public void testClusteringEuclideanWithOutlierRemovalSeq() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ // now run the Canopy Driver in sequential mode
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.OUTLIER_THRESHOLD), "0.5",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION),
+ optKey(DefaultOptionCreator.METHOD_OPTION),
+ DefaultOptionCreator.SEQUENTIAL_METHOD };
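+ // with an outlier threshold of 0.5, only points whose cluster pdf reaches
+ // the threshold are written to clusteredPoints (6 points, per the assertion
+ // below)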
+ ToolRunner.run(config, new CanopyDriver(), args);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+
+ int ix = 0;
+ for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true,
+ config)) {
+ assertEquals("Center [" + ix + ']', euclideanCentroids.get(ix), clusterWritable.getValue()
+ .getCenter());
+ ix++;
+ }
+
+ path = new Path(output, "clusteredPoints/part-m-0");
+ long count = HadoopUtil.countRecords(path, config);
+ int expectedPointsHavingPDFGreaterThanThreshold = 6;
+ assertEquals("number of points", expectedPointsHavingPDFGreaterThanThreshold, count);
+ }
+
+ /**
+ * Story: User can produce final point clustering using a Hadoop map/reduce
+ * job and a ManhattanDistanceMeasure.
+ */
+ @Test
+ public void testClusteringManhattanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file2"), fs, conf);
+ // now run the Job
+ Path output = getTestTempDirPath("output");
+ CanopyDriver.run(conf, getTestTempDirPath("testdata"), output,
+ manhattanDistanceMeasure, 3.1, 2.1, true, 0.0, false);
+ Path path = new Path(output, "clusteredPoints/part-m-00000");
+ long count = HadoopUtil.countRecords(path, conf);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /**
+ * Story: User can produce final point clustering using a Hadoop map/reduce
+ * job and an EuclideanDistanceMeasure.
+ */
+ @Test
+ public void testClusteringEuclideanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file2"), fs, conf);
+ // now run the job via the command-line driver (ToolRunner); CanopyDriver.run() can be called directly instead
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION) };
+ ToolRunner.run(getConfiguration(), new CanopyDriver(), args);
+ Path path = new Path(output, "clusteredPoints/part-m-00000");
+ long count = HadoopUtil.countRecords(path, conf);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /**
+ * Story: User can produce final point clustering using a Hadoop map/reduce
+ * job and an EuclideanDistanceMeasure and an outlier removal threshold.
+ */
+ @Test
+ public void testClusteringEuclideanWithOutlierRemovalMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file2"), fs, conf);
+ // now run the job via the command-line driver (ToolRunner); CanopyDriver.run() can be called directly instead
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.OUTLIER_THRESHOLD), "0.7",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION) };
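+ // the 0.7 outlier threshold drops the lowest-pdf points during the
+ // clustering pass, leaving the 8 points asserted below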
+ ToolRunner.run(getConfiguration(), new CanopyDriver(), args);
+ Path path = new Path(output, "clusteredPoints/part-m-00000");
+ long count = HadoopUtil.countRecords(path, conf);
+ int expectedPointsAfterOutlierRemoval = 8;
+ assertEquals("number of points", expectedPointsAfterOutlierRemoval, count);
+ }
+
+ /**
+ * Story: User can set T3 and T4 values to be used by the reducer for its T1
+ * and T2 thresholds
+ */
+ @Test
+ public void testCanopyReducerT3T4Configuration() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY,
+ "org.apache.mahout.common.distance.ManhattanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
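+ // when T3/T4 are present they replace T1/T2 on the reducer side, as the
+ // assertions below verify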
+ conf.set(CanopyConfigKeys.T3_KEY, String.valueOf(1.1));
+ conf.set(CanopyConfigKeys.T4_KEY, String.valueOf(0.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter
+ .build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+ assertEquals(1.1, reducer.getCanopyClusterer().getT1(), EPSILON);
+ assertEquals(0.1, reducer.getCanopyClusterer().getT2(), EPSILON);
+ }
+
+ /**
+ * Story: User can specify a clustering limit that prevents output of small
+ * clusters
+ */
+ @Test
+ public void testCanopyMapperClusterFilter() throws Exception {
+ CanopyMapper mapper = new CanopyMapper();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, manhattanDistanceMeasure
+ .getClass().getName());
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
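+ // CF = 3 suppresses canopies with too few points, so only 2 of the usual 3
+ // centroids survive (see the assertion below)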
+ conf.set(CanopyConfigKeys.CF_KEY, "3");
+ DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<>();
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter
+ .build(mapper, conf, writer);
+ mapper.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ // map the data
+ for (VectorWritable point : points) {
+ mapper.map(new Text(), point, context);
+ }
+ mapper.cleanup(context);
+ assertEquals("Number of map results", 1, writer.getData().size());
+ // now verify the output
+ List<VectorWritable> data = writer.getValue(new Text("centroid"));
+ assertEquals("Number of centroids", 2, data.size());
+ }
+
+ /**
+ * Story: User can specify a cluster filter that limits the minimum size of
+ * canopies produced by the reducer
+ */
+ @Test
+ public void testCanopyReducerClusterFilter() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY,
+ "org.apache.mahout.common.distance.ManhattanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
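+ // the same CF = 3 filter on the reducer side again leaves only 2 canopies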
+ conf.set(CanopyConfigKeys.CF_KEY, "3");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter
+ .build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ reducer.reduce(new Text("centroid"), points, context);
+ Set<Text> keys = writer.getKeys();
+ assertEquals("Number of centroids", 2, keys.size());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
new file mode 100644
index 0000000..cbf0e55
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
@@ -0,0 +1,255 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+public class ClusterClassificationDriverTest extends MahoutTestCase {
+
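+ // nine 2-d points in three well-separated groups: around (1,1), (4.5,4.5)
+ // and (8.5,8.5)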
+ private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4},
+ {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
+
+ private FileSystem fs;
+ private Path clusteringOutputPath;
+ private Configuration conf;
+ private Path pointsPath;
+ private Path classifiedOutputPath;
+ private List<Vector> firstCluster;
+ private List<Vector> secondCluster;
+ private List<Vector> thirdCluster;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ // assign the shared conf field (a local variable here would shadow it and
+ // leave the field null until each test reassigns it)
+ conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ firstCluster = Lists.newArrayList();
+ secondCluster = Lists.newArrayList();
+ thirdCluster = Lists.newArrayList();
+ }
+
+ private static List<VectorWritable> getPointsWritable(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ @Test
+ public void testVectorClassificationWithOutlierRemovalMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ pointsPath = getTestTempDirPath("points");
+ clusteringOutputPath = getTestTempDirPath("output");
+ classifiedOutputPath = getTestTempDirPath("classifiedClusters");
+ conf = getConfiguration();
+ HadoopUtil.delete(conf, classifiedOutputPath);
+
+ ClusteringTestUtils.writePointsToFile(points, true,
+ new Path(pointsPath, "file1"), fs, conf);
+ runClustering(pointsPath, conf, false);
+ runClassificationWithOutlierRemoval(false);
+ collectVectorsForAssertion();
+ assertVectorsWithOutlierRemoval();
+ }
+
+ @Test
+ public void testVectorClassificationWithoutOutlierRemoval() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ pointsPath = getTestTempDirPath("points");
+ clusteringOutputPath = getTestTempDirPath("output");
+ classifiedOutputPath = getTestTempDirPath("classify");
+
+ conf = getConfiguration();
+
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
+ runClustering(pointsPath, conf, true);
+ runClassificationWithoutOutlierRemoval();
+ collectVectorsForAssertion();
+ assertVectorsWithoutOutlierRemoval();
+ }
+
+ @Test
+ public void testVectorClassificationWithOutlierRemoval() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ pointsPath = getTestTempDirPath("points");
+ clusteringOutputPath = getTestTempDirPath("output");
+ classifiedOutputPath = getTestTempDirPath("classify");
+
+ conf = getConfiguration();
+
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
+ runClustering(pointsPath, conf, true);
+ runClassificationWithOutlierRemoval(true);
+ collectVectorsForAssertion();
+ assertVectorsWithOutlierRemoval();
+ }
+
+ private void runClustering(Path pointsPath, Configuration conf,
+ boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ CanopyDriver.run(conf, pointsPath, clusteringOutputPath,
+ new ManhattanDistanceMeasure(), 3.1, 2.1, false, 0.0, runSequential);
+ Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final");
+ ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
+ finalClustersPath);
+ }
+
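+ /**
+ * Classifies with a 0.0 threshold, i.e. keeps every vector; the trailing
+ * booleans passed to the driver are emitMostLikely and runSequential.
+ */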
+ private void runClassificationWithoutOutlierRemoval()
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath, classifiedOutputPath, 0.0, true, true);
+ }
+
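+ /**
+ * Classifies with a 0.73 outlier threshold: only vectors whose top cluster
+ * pdf reaches the threshold are emitted, which leaves the singleton clusters
+ * expected by checkClustersWithOutlierRemoval().
+ */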
+ private void runClassificationWithOutlierRemoval(boolean runSequential)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath, classifiedOutputPath, 0.73, true, runSequential);
+ }
+
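+ /**
+ * Reads every part file under classifiedOutputPath and buckets each
+ * classified vector into firstCluster/secondCluster/thirdCluster by its
+ * cluster id key.
+ */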
+ private void collectVectorsForAssertion() throws IOException {
+ Path[] partFilePaths = FileUtil.stat2Paths(fs
+ .globStatus(classifiedOutputPath));
+ FileStatus[] listStatus = fs.listStatus(partFilePaths,
+ PathFilters.partFilter());
+ for (FileStatus partFile : listStatus) {
+ SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs,
+ partFile.getPath(), conf);
+ Writable clusterIdAsKey = new IntWritable();
+ WeightedPropertyVectorWritable point = new WeightedPropertyVectorWritable();
+ while (classifiedVectors.next(clusterIdAsKey, point)) {
+ collectVector(clusterIdAsKey.toString(), point.getVector());
+ }
+ classifiedVectors.close(); // close each part-file reader instead of leaking it
+ }
+ }
+
+ private void collectVector(String clusterId, Vector vector) {
+ if ("0".equals(clusterId)) {
+ firstCluster.add(vector);
+ } else if ("1".equals(clusterId)) {
+ secondCluster.add(vector);
+ } else if ("2".equals(clusterId)) {
+ thirdCluster.add(vector);
+ }
+ }
+
+ private void assertVectorsWithOutlierRemoval() {
+ checkClustersWithOutlierRemoval();
+ }
+
+ private void assertVectorsWithoutOutlierRemoval() {
+ assertFirstClusterWithoutOutlierRemoval();
+ assertSecondClusterWithoutOutlierRemoval();
+ assertThirdClusterWithoutOutlierRemoval();
+ }
+
+ private void assertThirdClusterWithoutOutlierRemoval() {
+ Assert.assertEquals(2, thirdCluster.size());
+ for (Vector vector : thirdCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:9.0,1:9.0}",
+ "{0:8.0,1:8.0}"}, vector.asFormatString()));
+ }
+ }
+
+ private void assertSecondClusterWithoutOutlierRemoval() {
+ Assert.assertEquals(4, secondCluster.size());
+ for (Vector vector : secondCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:4.0,1:4.0}",
+ "{0:5.0,1:4.0}", "{0:4.0,1:5.0}", "{0:5.0,1:5.0}"},
+ vector.asFormatString()));
+ }
+ }
+
+ private void assertFirstClusterWithoutOutlierRemoval() {
+ Assert.assertEquals(3, firstCluster.size());
+ for (Vector vector : firstCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:1.0,1:1.0}",
+ "{0:2.0,1:1.0}", "{0:1.0,1:2.0}"}, vector.asFormatString()));
+ }
+ }
+
+ private void checkClustersWithOutlierRemoval() {
+ Set<String> reference = Sets.newHashSet("{0:9.0,1:9.0}", "{0:1.0,1:1.0}");
+
+ List<List<Vector>> clusters = Lists.newArrayList();
+ clusters.add(firstCluster);
+ clusters.add(secondCluster);
+ clusters.add(thirdCluster);
+
+ int singletonCnt = 0;
+ int emptyCnt = 0;
+ for (List<Vector> vList : clusters) {
+ if (vList.isEmpty()) {
+ emptyCnt++;
+ } else {
+ singletonCnt++;
+ assertEquals("expecting only singleton clusters; got size=" + vList.size(), 1, vList.size());
+ if (vList.get(0).getClass().equals(NamedVector.class)) {
+ Assert.assertTrue("not expecting cluster:" + ((NamedVector) vList.get(0)).getDelegate().asFormatString(),
+ reference.contains(((NamedVector) vList.get(0)).getDelegate().asFormatString()));
+ reference.remove(((NamedVector)vList.get(0)).getDelegate().asFormatString());
+ } else if (vList.get(0).getClass().equals(RandomAccessSparseVector.class)) {
+ Assert.assertTrue("not expecting cluster:" + vList.get(0).asFormatString(),
+ reference.contains(vList.get(0).asFormatString()));
+ reference.remove(vList.get(0).asFormatString());
+ }
+ }
+ }
+ Assert.assertEquals("Different number of empty clusters than expected!", 1, emptyCnt);
+ Assert.assertEquals("Different number of singletons than expected!", 2, singletonCnt);
+ Assert.assertEquals("Didn't match all reference clusters!", 0, reference.size());
+ }
+
+}