You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/09/03 14:28:41 UTC
svn commit: r992277 [3/9] - in /mahout/trunk:
core/src/main/java/org/apache/mahout/ep/
core/src/main/java/org/apache/mahout/fpm/pfpgrowth/
core/src/test/java/org/apache/mahout/cf/taste/common/
core/src/test/java/org/apache/mahout/cf/taste/hadoop/ core/...
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/PerceptronTrainerTest.java Fri Sep 3 12:28:34 2010
@@ -22,18 +22,22 @@ import org.apache.mahout.math.DenseMatri
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
+import org.junit.Before;
+import org.junit.Test;
-public class PerceptronTrainerTest extends MahoutTestCase {
+public final class PerceptronTrainerTest extends MahoutTestCase {
private PerceptronTrainer trainer;
@Override
- protected void setUp() throws Exception {
+ @Before
+ public void setUp() throws Exception {
super.setUp();
trainer = new PerceptronTrainer(3, 0.5, 0.1, 1.0, 1.0);
}
- public void testUpdate() throws TrainingException {
+ @Test
+ public void testUpdate() throws Exception {
double[] labels = { 1.0, 1.0, 1.0, 0.0 };
Vector labelset = new DenseVector(labels);
double[][] values = new double[3][4];
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/discriminative/WinnowTrainerTest.java Fri Sep 3 12:28:34 2010
@@ -22,17 +22,21 @@ import org.apache.mahout.math.DenseMatri
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
+import org.junit.Before;
+import org.junit.Test;
-public class WinnowTrainerTest extends MahoutTestCase {
+public final class WinnowTrainerTest extends MahoutTestCase {
private WinnowTrainer trainer;
@Override
- protected void setUp() throws Exception {
+ @Before
+ public void setUp() throws Exception {
super.setUp();
trainer = new WinnowTrainer(3);
}
+ @Test
public void testUpdate() throws Exception {
double[] labels = { 0.0, 0.0, 0.0, 1.0 };
Vector labelset = new DenseVector(labels);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java Fri Sep 3 12:28:34 2010
@@ -17,16 +17,17 @@
package org.apache.mahout.classifier.sgd;
+import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.jet.random.Exponential;
-import org.junit.Assert;
import org.junit.Test;
import java.util.Random;
-public class AdaptiveLogisticRegressionTest {
+public final class AdaptiveLogisticRegressionTest extends MahoutTestCase {
+
@Test
public void testTrain() {
// we make up data for a simple model
@@ -87,8 +88,6 @@ public class AdaptiveLogisticRegressionT
@Test
public void copyLearnsAsExpected() {
- RandomUtils.useTestSeed();
-
Random gen = RandomUtils.getRandom();
Exponential exp = new Exponential(0.5, gen);
Vector beta = new DenseVector(200);
@@ -118,24 +117,24 @@ public class AdaptiveLogisticRegressionT
for (int i = 0; i < 5000; i++) {
if (i % 1000 == 0) {
if (i == 0) {
- Assert.assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001);
+ assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001);
}
if (i == 1000) {
double auc2 = w2.getLearner().auc();
- Assert.assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1);
- Assert.assertTrue("AUC should improve quickly on copy", auc1 < auc2);
+ assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1);
+ assertTrue("AUC should improve quickly on copy", auc1 < auc2);
}
System.out.printf("%10d %.3f\n", i, w2.getLearner().auc());
}
AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
w2.train(r);
}
- Assert.assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5);
+ assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5);
// this improvement is really quite lenient
- Assert.assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05);
+ assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05);
// make sure that the copy didn't lose anything
- Assert.assertEquals(auc1, w.getLearner().auc(), 0);
+ assertEquals(auc1, w.getLearner().auc(), 0);
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java Fri Sep 3 12:28:34 2010
@@ -18,13 +18,14 @@
package org.apache.mahout.classifier.sgd;
import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.vectors.Dictionary;
-import org.junit.Assert;
import org.junit.Test;
-public class CsvRecordFactoryTest {
+public final class CsvRecordFactoryTest extends MahoutTestCase {
+
@Test
public void testAddToVector() {
RecordFactory csv = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t"));
@@ -33,41 +34,41 @@ public class CsvRecordFactoryTest {
Vector v = new DenseVector(2000);
int t = csv.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", v);
- Assert.assertEquals(0, t);
+ assertEquals(0, t);
// should have 9 values set
- Assert.assertEquals(9.0, v.norm(0), 0);
+ assertEquals(9.0, v.norm(0), 0);
// all should be = 1 except for the 3.1
- Assert.assertEquals(3.1, v.maxValue(), 0);
+ assertEquals(3.1, v.maxValue(), 0);
v.set(v.maxValueIndex(), 0);
- Assert.assertEquals(8.0, v.norm(0), 0);
- Assert.assertEquals(8.0, v.norm(1), 0);
- Assert.assertEquals(1.0, v.maxValue(), 0);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(8.0, v.norm(1), 0);
+ assertEquals(1.0, v.maxValue(), 0);
v.assign(0);
t = csv.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", v);
- Assert.assertEquals(1, t);
+ assertEquals(1, t);
// should have 9 values set
- Assert.assertEquals(9.0, v.norm(0), 0);
+ assertEquals(9.0, v.norm(0), 0);
// all should be = 1 except for the 3.1
- Assert.assertEquals(5.3, v.maxValue(), 0);
+ assertEquals(5.3, v.maxValue(), 0);
v.set(v.maxValueIndex(), 0);
- Assert.assertEquals(8.0, v.norm(0), 0);
- Assert.assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
- Assert.assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+ assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
v.assign(0);
t = csv.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", v);
- Assert.assertEquals(1, t);
+ assertEquals(1, t);
// should have 9 values set
- Assert.assertEquals(9.0, v.norm(0), 0);
+ assertEquals(9.0, v.norm(0), 0);
// all should be = 1 except for the 3.1
- Assert.assertEquals(5.3, v.maxValue(), 0);
+ assertEquals(5.3, v.maxValue(), 0);
v.set(v.maxValueIndex(), 0);
- Assert.assertEquals(8.0, v.norm(0), 0);
- Assert.assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
- Assert.assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+ assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
}
@Test
@@ -80,10 +81,10 @@ public class CsvRecordFactoryTest {
dict.intern("b");
dict.intern("qrz");
- Assert.assertEquals("[a, d, c, b, qrz]", dict.values().toString());
+ assertEquals("[a, d, c, b, qrz]", dict.values().toString());
Dictionary dict2 = Dictionary.fromList(dict.values());
- Assert.assertEquals("[a, d, c, b, qrz]", dict2.values().toString());
+ assertEquals("[a, d, c, b, qrz]", dict2.values().toString());
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java Fri Sep 3 12:28:34 2010
@@ -21,6 +21,7 @@ import com.google.common.collect.Lists;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.ep.Mapping;
import org.apache.mahout.math.DenseVector;
@@ -28,7 +29,6 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.UnaryFunction;
import org.apache.mahout.math.stats.OnlineAuc;
-import org.junit.Before;
import org.junit.Test;
import java.io.StringReader;
@@ -37,15 +37,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Random;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-
-public class ModelSerializerTest {
-
- @Before
- public void setUp() {
- RandomUtils.useTestSeed();
- }
+public final class ModelSerializerTest extends MahoutTestCase {
@Test
public void testSoftLimitDeserialization() {
@@ -179,7 +171,7 @@ public class ModelSerializerTest {
private static void train(OnlineLearner olr, int n) {
Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5});
- final Random gen = new Random(1);
+ Random gen = new Random(1);
for (int i = 0; i < n; i++) {
Vector x = randomVector(gen, 5);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Fri Sep 3 12:28:34 2010
@@ -25,13 +25,13 @@ import com.google.common.io.CharStreams;
import com.google.common.io.Resources;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
-import org.junit.Assert;
import org.junit.Test;
import java.io.IOException;
@@ -40,7 +40,8 @@ import java.util.List;
import java.util.Map;
import java.util.Random;
-public class OnlineLogisticRegressionTest {
+public final class OnlineLogisticRegressionTest extends MahoutTestCase {
+
private Matrix input;
/**
@@ -96,56 +97,56 @@ public class OnlineLogisticRegressionTes
// zero vector gives no information. All classes are equal.
Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
- Assert.assertEquals(1 / 3.0, v.get(0), 1.0e-8);
- Assert.assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
- Assert.assertEquals(1.0, v.zSum(), 1.0e-8);
- Assert.assertEquals(1 / 3.0, v.get(0), 1.0e-8);
- Assert.assertEquals(1 / 3.0, v.get(1), 1.0e-8);
- Assert.assertEquals(1 / 3.0, v.get(2), 1.0e-8);
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-8);
// weights for second vector component are still zero so all classifications are equally likely
v = lr.classify(new DenseVector(new double[]{0, 1}));
- Assert.assertEquals(1 / 3.0, v.get(0), 1.0e-3);
- Assert.assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
- Assert.assertEquals(1.0, v.zSum(), 1.0e-8);
- Assert.assertEquals(1 / 3.0, v.get(0), 1.0e-3);
- Assert.assertEquals(1 / 3.0, v.get(1), 1.0e-3);
- Assert.assertEquals(1 / 3.0, v.get(2), 1.0e-3);
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-3);
// but the weights on the first component are non-zero
v = lr.classify(new DenseVector(new double[]{1, 0}));
- Assert.assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
- Assert.assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
- Assert.assertEquals(1.0, v.zSum(), 1.0e-8);
- Assert.assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
- Assert.assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
- Assert.assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);
lr.setBeta(0, 1, 1);
v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
- Assert.assertEquals(1.0, v.zSum(), 1.0e-8);
- Assert.assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
- Assert.assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
- Assert.assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);
lr.setBeta(1, 1, 3);
v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
- Assert.assertEquals(1.0, v.zSum(), 1.0e-8);
- Assert.assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
- Assert.assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
- Assert.assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
}
@Test
- public void testTrain() throws IOException {
+ public void testTrain() throws Exception {
Vector target = readStandardData();
@@ -196,14 +197,14 @@ public class OnlineLogisticRegressionTes
double maxAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
- Assert.assertEquals(0, meanAbsoluteError , 0.05);
- Assert.assertEquals(0, maxAbsoluteError, 0.3);
+ assertEquals(0, meanAbsoluteError , 0.05);
+ assertEquals(0, maxAbsoluteError, 0.3);
// convenience methods should give the same results
Vector v = lr.classifyScalar(input);
- Assert.assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5);
+ assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5);
v = lr.classifyFull(input).getColumn(1);
- Assert.assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4);
+ assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4);
}
/**
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java Fri Sep 3 12:28:34 2010
@@ -25,10 +25,10 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.SequenceFile;
import org.apache.mahout.math.VectorWritable;
-import java.io.File;
import java.io.IOException;
-public class ClusteringTestUtils {
+public final class ClusteringTestUtils {
+
private ClusteringTestUtils() {
}
@@ -56,16 +56,4 @@ public class ClusteringTestUtils {
writer.close();
}
- public static void rmr(String path) {
- File f = new File(path);
- if (f.exists()) {
- if (f.isDirectory()) {
- String[] contents = f.list();
- for (String content : contents) {
- rmr(f.toString() + File.separator + content);
- }
- }
- f.delete();
- }
- }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java Fri Sep 3 12:28:34 2010
@@ -37,13 +37,15 @@ import org.apache.mahout.math.Vector;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
+import org.junit.Test;
-public class TestClusterInterface extends MahoutTestCase {
+public final class TestClusterInterface extends MahoutTestCase {
private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {}.getType();
private static final Type CLUSTER_TYPE = new TypeToken<DirichletCluster>() {}.getType();
private static final DistanceMeasure measure = new ManhattanDistanceMeasure();
+ @Test
public void testDirichletNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -58,6 +60,7 @@ public class TestClusterInterface extend
assertEquals("Json", format, model2.asFormatString(null));
}
+ @Test
public void testDirichletSampledNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -72,6 +75,7 @@ public class TestClusterInterface extend
assertEquals("Json", format, model2.asFormatString(null));
}
+ @Test
public void testDirichletASNormalModel() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -86,6 +90,7 @@ public class TestClusterInterface extend
assertEquals("Json", format, model2.asFormatString(null));
}
+ @Test
public void testDirichletL1Model() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -100,6 +105,7 @@ public class TestClusterInterface extend
assertEquals("Json", format, model2.asFormatString(null));
}
+ @Test
public void testDirichletNormalModelClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -109,6 +115,7 @@ public class TestClusterInterface extend
assertEquals("format", "C-5: nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format);
}
+ @Test
public void testDirichletNormalModelClusterAsJsonString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -123,6 +130,7 @@ public class TestClusterInterface extend
assertEquals("model", cluster.asFormatString(null), result.asFormatString(null));
}
+ @Test
public void testDirichletAsymmetricSampledNormalModelClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -132,6 +140,7 @@ public class TestClusterInterface extend
assertEquals("format", "C-5: asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format);
}
+ @Test
public void testDirichletAsymmetricSampledNormalModelClusterAsJsonString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -147,6 +156,7 @@ public class TestClusterInterface extend
assertEquals("model", cluster.asFormatString(null), result.asFormatString(null));
}
+ @Test
public void testDirichletL1ModelClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -156,6 +166,7 @@ public class TestClusterInterface extend
assertEquals("format", "C-5: l1m{n=0 c=[1.100, 2.200, 3.300]}", format);
}
+ @Test
public void testDirichletL1ModelClusterAsJsonString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -171,6 +182,7 @@ public class TestClusterInterface extend
assertEquals("model", cluster.asFormatString(null), result.asFormatString(null));
}
+ @Test
public void testCanopyAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -180,6 +192,7 @@ public class TestClusterInterface extend
assertEquals("format", "C-123{n=0 c=[1.100, 2.200, 3.300] r=[]}", formatString);
}
+ @Test
public void testCanopyAsFormatStringSparse() {
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
@@ -190,6 +203,7 @@ public class TestClusterInterface extend
assertEquals("format", "C-123{n=0 c=[0:1.100, 2:3.300] r=[]}", formatString);
}
+ @Test
public void testCanopyAsFormatStringWithBindings() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -200,6 +214,7 @@ public class TestClusterInterface extend
assertEquals("format", "C-123{n=0 c=[fee:1.100, 1:2.200, 2:3.300] r=[]}", formatString);
}
+ @Test
public void testCanopyAsFormatStringSparseWithBindings() {
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
@@ -210,6 +225,7 @@ public class TestClusterInterface extend
assertEquals("format", "C-123{n=0 c=[0:1.100, 2:3.300] r=[]}", formatString);
}
+ @Test
public void testClusterAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -219,6 +235,7 @@ public class TestClusterInterface extend
assertEquals("format", "CL-123{n=0 c=[1.100, 2.200, 3.300] r=[]}", formatString);
}
+ @Test
public void testClusterAsFormatStringSparse() {
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
@@ -229,6 +246,7 @@ public class TestClusterInterface extend
assertEquals("format", "CL-123{n=0 c=[0:1.100, 2:3.300] r=[]}", formatString);
}
+ @Test
public void testClusterAsFormatStringWithBindings() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -239,6 +257,7 @@ public class TestClusterInterface extend
assertEquals("format", "CL-123{n=0 c=[fee:1.100, 1:2.200, foo:3.300] r=[]}", formatString);
}
+ @Test
public void testClusterAsFormatStringSparseWithBindings() {
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
@@ -249,6 +268,7 @@ public class TestClusterInterface extend
assertEquals("format", "CL-123{n=0 c=[0:1.100, 2:3.300] r=[]}", formatString);
}
+ @Test
public void testMSCanopyAsFormatString() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -258,6 +278,7 @@ public class TestClusterInterface extend
assertEquals("format", "MSC-123{n=0 c=[1.100, 2.200, 3.300] r=[]}", formatString);
}
+ @Test
public void testMSCanopyAsFormatStringSparse() {
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
@@ -268,6 +289,7 @@ public class TestClusterInterface extend
assertEquals("format", "MSC-123{n=0 c=[0:1.100, 2:3.300] r=[]}", formatString);
}
+ @Test
public void testMSCanopyAsFormatStringWithBindings() {
double[] d = { 1.1, 2.2, 3.3 };
Vector m = new DenseVector(d);
@@ -278,6 +300,7 @@ public class TestClusterInterface extend
assertEquals("format", "MSC-123{n=0 c=[fee:1.100, 1:2.200, foo:3.300] r=[]}", formatString);
}
+ @Test
public void testMSCanopyAsFormatStringSparseWithBindings() {
double[] d = { 1.1, 0.0, 3.3 };
Vector m = new SequentialAccessSparseVector(3);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestModelDistributionSerialization.java Fri Sep 3 12:28:34 2010
@@ -28,9 +28,11 @@ import org.apache.mahout.math.VectorWrit
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
+import org.junit.Test;
-public class TestModelDistributionSerialization extends MahoutTestCase {
+public final class TestModelDistributionSerialization extends MahoutTestCase {
+ @Test
public void testGaussianClusterDistribution() {
GaussianClusterDistribution dist = new GaussianClusterDistribution(new VectorWritable(new DenseVector(2)));
String json = dist.asJsonString();
@@ -40,11 +42,13 @@ public class TestModelDistributionSerial
Gson gson = builder.create();
GaussianClusterDistribution dist1 = (GaussianClusterDistribution) gson
.fromJson(json, AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
- assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+ assertSame("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
}
+ @Test
public void testDMClusterDistribution() {
- DistanceMeasureClusterDistribution dist = new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)));
+ DistanceMeasureClusterDistribution dist =
+ new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)));
String json = dist.asJsonString();
GsonBuilder builder = new GsonBuilder();
builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
@@ -52,13 +56,14 @@ public class TestModelDistributionSerial
Gson gson = builder.create();
DistanceMeasureClusterDistribution dist1 = (DistanceMeasureClusterDistribution) gson
.fromJson(json, AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
- assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
- assertEquals("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
+ assertSame("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+ assertSame("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
}
+ @Test
public void testDMClusterDistribution2() {
- DistanceMeasureClusterDistribution dist = new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)),
- new EuclideanDistanceMeasure());
+ DistanceMeasureClusterDistribution dist =
+ new DistanceMeasureClusterDistribution(new VectorWritable(new DenseVector(2)), new EuclideanDistanceMeasure());
String json = dist.asJsonString();
GsonBuilder builder = new GsonBuilder();
builder.registerTypeAdapter(ModelDistribution.class, new JsonModelDistributionAdapter());
@@ -66,7 +71,7 @@ public class TestModelDistributionSerial
Gson gson = builder.create();
DistanceMeasureClusterDistribution dist1 = (DistanceMeasureClusterDistribution) gson
.fromJson(json, AbstractVectorModelDistribution.MODEL_DISTRIBUTION_TYPE);
- assertEquals("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
- assertEquals("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
+ assertSame("prototype", dist.getModelPrototype().getClass(), dist1.getModelPrototype().getClass());
+ assertSame("measure", dist.getMeasure().getClass(), dist1.getMeasure().getClass());
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java Fri Sep 3 12:28:34 2010
@@ -1,3 +1,20 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.mahout.clustering;
import java.util.ArrayList;
@@ -17,9 +34,11 @@ import org.apache.mahout.common.distance
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.junit.Test;
-public class TestVectorModelClassifier extends MahoutTestCase {
+public final class TestVectorModelClassifier extends MahoutTestCase {
+ @Test
public void testDMClusterClassification() {
List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -33,6 +52,7 @@ public class TestVectorModelClassifier e
assertEquals("[2,2]", "[0.867, 0.117, 0.016]", AbstractCluster.formatVector(pdf, null));
}
+ @Test
public void testCanopyClassification() {
List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -46,6 +66,7 @@ public class TestVectorModelClassifier e
assertEquals("[2,2]", "[0.867, 0.117, 0.016]", AbstractCluster.formatVector(pdf, null));
}
+ @Test
public void testClusterClassification() {
List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -59,6 +80,7 @@ public class TestVectorModelClassifier e
assertEquals("[2,2]", "[0.867, 0.117, 0.016]", AbstractCluster.formatVector(pdf, null));
}
+ @Test
public void testMSCanopyClassification() {
List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -74,6 +96,7 @@ public class TestVectorModelClassifier e
}
}
+ @Test
public void testSoftClusterClassification() {
List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -87,6 +110,7 @@ public class TestVectorModelClassifier e
assertEquals("[2,2]", "[0.735, 0.184, 0.082]", AbstractCluster.formatVector(pdf, null));
}
+ @Test
public void testGaussianClusterClassification() {
List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
models.add(new GaussianCluster(new DenseVector(2).assign(1), new DenseVector(2).assign(1), 0));
@@ -99,6 +123,7 @@ public class TestVectorModelClassifier e
assertEquals("[2,2]", "[0.806, 0.180, 0.015]", AbstractCluster.formatVector(pdf, null));
}
+ @Test
public void testASNClusterClassification() {
List<Model<VectorWritable>> models = new ArrayList<Model<VectorWritable>>();
models.add(new AsymmetricSampledNormalModel(0, new DenseVector(2).assign(1), new DenseVector(2).assign(1)));
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java Fri Sep 3 12:28:34 2010
@@ -18,6 +18,7 @@
package org.apache.mahout.clustering.canopy;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -29,6 +30,7 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
@@ -45,28 +47,26 @@ import org.apache.mahout.common.distance
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
-public class TestCanopyCreation extends MahoutTestCase {
+public final class TestCanopyCreation extends MahoutTestCase {
- private static final double[][] raw = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
+ private static final double[][] RAW = {
+ { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 }
+ };
private List<Canopy> referenceManhattan;
-
private final DistanceMeasure manhattanDistanceMeasure = new ManhattanDistanceMeasure();
-
private List<Vector> manhattanCentroids;
-
private List<Canopy> referenceEuclidean;
-
private final DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure();
-
private List<Vector> euclideanCentroids;
-
private FileSystem fs;
private static List<VectorWritable> getPointsWritable() {
List<VectorWritable> points = new ArrayList<VectorWritable>();
- for (double[] fr : raw) {
+ for (double[] fr : RAW) {
Vector vec = new RandomAccessSparseVector(fr.length);
vec.assign(fr);
points.add(new VectorWritable(vec));
@@ -76,7 +76,7 @@ public class TestCanopyCreation extends
private static List<Vector> getPoints() {
List<Vector> points = new ArrayList<Vector>();
- for (double[] fr : raw) {
+ for (double[] fr : RAW) {
Vector vec = new RandomAccessSparseVector(fr.length);
vec.assign(fr);
points.add(vec);
@@ -90,13 +90,13 @@ public class TestCanopyCreation extends
* @param canopies
* a List<Canopy>
*/
- private static void printCanopies(List<Canopy> canopies) {
+ private static void printCanopies(Iterable<Canopy> canopies) {
for (Canopy canopy : canopies) {
System.out.println(canopy.asFormatString(null));
}
}
- private static Canopy findCanopy(Integer key, List<Canopy> canopies) {
+ private static Canopy findCanopy(Integer key, Iterable<Canopy> canopies) {
for (Canopy c : canopies) {
if (c.getId() == key) {
return c;
@@ -106,7 +106,8 @@ public class TestCanopyCreation extends
}
@Override
- protected void setUp() throws Exception {
+ @Before
+ public void setUp() throws Exception {
super.setUp();
fs = FileSystem.get(new Configuration());
referenceManhattan = CanopyClusterer.createCanopies(getPoints(), manhattanDistanceMeasure, 3.1, 2.1);
@@ -116,8 +117,8 @@ public class TestCanopyCreation extends
}
/** Story: User can cluster points using a ManhattanDistanceMeasure and a reference implementation */
+ @Test
public void testReferenceManhattan() throws Exception {
- System.out.println("testReferenceManhattan");
// see setUp for cluster creation
printCanopies(referenceManhattan);
assertEquals("number of canopies", 3, referenceManhattan.size());
@@ -129,14 +130,17 @@ public class TestCanopyCreation extends
double[] refCentroid = expectedCentroids[canopyIx];
Vector testCentroid = testCanopy.computeCentroid();
for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) {
- assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']', refCentroid[pointIx], testCentroid.get(pointIx));
+ assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']',
+ refCentroid[pointIx],
+ testCentroid.get(pointIx),
+ EPSILON);
}
}
}
/** Story: User can cluster points using a EuclideanDistanceMeasure and a reference implementation */
+ @Test
public void testReferenceEuclidean() throws Exception {
- System.out.println("testReferenceEuclidean()");
// see setUp for cluster creation
printCanopies(referenceEuclidean);
assertEquals("number of canopies", 3, referenceEuclidean.size());
@@ -148,7 +152,9 @@ public class TestCanopyCreation extends
double[] refCentroid = expectedCentroids[canopyIx];
Vector testCentroid = testCanopy.computeCentroid();
for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) {
- assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']', refCentroid[pointIx], testCentroid.get(pointIx));
+ assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']',
+ refCentroid[pointIx], testCentroid.get(pointIx),
+ EPSILON);
}
}
}
@@ -157,6 +163,7 @@ public class TestCanopyCreation extends
* Story: User can produce initial canopy centers using a ManhattanDistanceMeasure and a
* CanopyMapper which clusters input points to produce an output set of canopy centroid points.
*/
+ @Test
public void testCanopyMapperManhattan() throws Exception {
CanopyMapper mapper = new CanopyMapper();
Configuration conf = new Configuration();
@@ -164,9 +171,8 @@ public class TestCanopyCreation extends
conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<Text, VectorWritable>();
- Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter.build(mapper,
- conf,
- writer);
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context =
+ DummyRecordWriter.build(mapper, conf, writer);
mapper.setup(context);
List<VectorWritable> points = getPointsWritable();
@@ -188,6 +194,7 @@ public class TestCanopyCreation extends
* Story: User can produce initial canopy centers using a EuclideanDistanceMeasure and a
* CanopyMapper/Combiner which clusters input points to produce an output set of canopy centroid points.
*/
+ @Test
public void testCanopyMapperEuclidean() throws Exception {
CanopyMapper mapper = new CanopyMapper();
Configuration conf = new Configuration();
@@ -195,9 +202,8 @@ public class TestCanopyCreation extends
conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<Text, VectorWritable>();
- Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter.build(mapper,
- conf,
- writer);
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context =
+ DummyRecordWriter.build(mapper, conf, writer);
mapper.setup(context);
List<VectorWritable> points = getPointsWritable();
@@ -219,6 +225,7 @@ public class TestCanopyCreation extends
* Story: User can produce final canopy centers using a ManhattanDistanceMeasure and a CanopyReducer which
* clusters input centroid points to produce an output set of final canopy centroid points.
*/
+ @Test
public void testCanopyReducerManhattan() throws Exception {
CanopyReducer reducer = new CanopyReducer();
Configuration conf = new Configuration();
@@ -250,6 +257,7 @@ public class TestCanopyCreation extends
* Story: User can produce final canopy centers using a EuclideanDistanceMeasure and a CanopyReducer which
* clusters input centroid points to produce an output set of final canopy centroid points.
*/
+ @Test
public void testCanopyReducerEuclidean() throws Exception {
CanopyReducer reducer = new CanopyReducer();
Configuration conf = new Configuration();
@@ -281,6 +289,7 @@ public class TestCanopyCreation extends
* Story: User can produce final canopy centers using a Hadoop map/reduce job and a
* ManhattanDistanceMeasure.
*/
+ @Test
public void testCanopyGenManhattanMR() throws Exception {
List<VectorWritable> points = getPointsWritable();
Configuration config = new Configuration();
@@ -294,16 +303,16 @@ public class TestCanopyCreation extends
Path path = new Path(output, "clusters-0/part-r-00000");
FileSystem fs = FileSystem.get(path.toUri(), config);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
- Text key = new Text();
+ Writable key = new Text();
Canopy canopy = new Canopy();
assertTrue("more to come", reader.next(key, canopy));
assertEquals("1st key", "C-0", key.toString());
- assertEquals("1st x value", 1.5, canopy.getCenter().get(0));
- assertEquals("1st y value", 1.5, canopy.getCenter().get(1));
+ assertEquals("1st x value", 1.5, canopy.getCenter().get(0), EPSILON);
+ assertEquals("1st y value", 1.5, canopy.getCenter().get(1), EPSILON);
assertTrue("more to come", reader.next(key, canopy));
assertEquals("2nd key", "C-1", key.toString());
- assertEquals("2nd x value", 4.333333333333334, canopy.getCenter().get(0));
- assertEquals("2nd y value", 4.333333333333334, canopy.getCenter().get(1));
+ assertEquals("2nd x value", 4.333333333333334, canopy.getCenter().get(0), EPSILON);
+ assertEquals("2nd y value", 4.333333333333334, canopy.getCenter().get(1), EPSILON);
assertFalse("more to come", reader.next(key, canopy));
reader.close();
}
@@ -312,6 +321,7 @@ public class TestCanopyCreation extends
* Story: User can produce final canopy centers using a Hadoop map/reduce job and a
* EuclideanDistanceMeasure.
*/
+ @Test
public void testCanopyGenEuclideanMR() throws Exception {
List<VectorWritable> points = getPointsWritable();
Configuration job = new Configuration();
@@ -325,21 +335,22 @@ public class TestCanopyCreation extends
Path path = new Path(output, "clusters-0/part-r-00000");
FileSystem fs = FileSystem.get(path.toUri(), job);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
- Text key = new Text();
+ Writable key = new Text();
Canopy value = new Canopy();
assertTrue("more to come", reader.next(key, value));
assertEquals("1st key", "C-0", key.toString());
- assertEquals("1st x value", 1.8, value.getCenter().get(0));
- assertEquals("1st y value", 1.8, value.getCenter().get(1));
+ assertEquals("1st x value", 1.8, value.getCenter().get(0), EPSILON);
+ assertEquals("1st y value", 1.8, value.getCenter().get(1), EPSILON);
assertTrue("more to come", reader.next(key, value));
assertEquals("2nd key", "C-1", key.toString());
- assertEquals("2nd x value", 4.433333333333334, value.getCenter().get(0));
- assertEquals("2nd y value", 4.433333333333334, value.getCenter().get(1));
+ assertEquals("2nd x value", 4.433333333333334, value.getCenter().get(0), EPSILON);
+ assertEquals("2nd y value", 4.433333333333334, value.getCenter().get(1), EPSILON);
assertFalse("more to come", reader.next(key, value));
reader.close();
}
/** Story: User can cluster a subset of the points using a ClusterMapper and a ManhattanDistanceMeasure. */
+ @Test
public void testClusterMapperManhattan() throws Exception {
ClusterMapper mapper = new ClusterMapper();
Configuration conf = new Configuration();
@@ -351,7 +362,7 @@ public class TestCanopyCreation extends
.build(mapper, conf, writer);
mapper.setup(context);
- List<Canopy> canopies = new ArrayList<Canopy>();
+ Collection<Canopy> canopies = new ArrayList<Canopy>();
int nextCanopyId = 0;
for (Vector centroid : manhattanCentroids) {
canopies.add(new Canopy(centroid, nextCanopyId++, manhattanDistanceMeasure));
@@ -375,6 +386,7 @@ public class TestCanopyCreation extends
}
/** Story: User can cluster a subset of the points using a ClusterMapper and a EuclideanDistanceMeasure. */
+ @Test
public void testClusterMapperEuclidean() throws Exception {
ClusterMapper mapper = new ClusterMapper();
Configuration conf = new Configuration();
@@ -386,7 +398,7 @@ public class TestCanopyCreation extends
.build(mapper, conf, writer);
mapper.setup(context);
- List<Canopy> canopies = new ArrayList<Canopy>();
+ Collection<Canopy> canopies = new ArrayList<Canopy>();
int nextCanopyId = 0;
for (Vector centroid : euclideanCentroids) {
canopies.add(new Canopy(centroid, nextCanopyId++, euclideanDistanceMeasure));
@@ -410,6 +422,7 @@ public class TestCanopyCreation extends
}
/** Story: User can cluster points using sequential execution */
+ @Test
public void testClusteringManhattanSeq() throws Exception {
List<VectorWritable> points = getPointsWritable();
Configuration config = new Configuration();
@@ -422,18 +435,18 @@ public class TestCanopyCreation extends
Path path = new Path(output, "clusters-0/part-r-00000");
FileSystem fs = FileSystem.get(path.toUri(), config);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
- Text key = new Text();
+ Writable key = new Text();
Canopy canopy = new Canopy();
int ix = 0;
while (reader.next(key, canopy)) {
- assertEquals("Center [" + ix + "]", manhattanCentroids.get(ix), canopy.getCenter());
+ assertEquals("Center [" + ix + ']', manhattanCentroids.get(ix), canopy.getCenter());
ix++;
}
reader.close();
path = new Path(output, "clusteredPoints/part-m-0");
reader = new SequenceFile.Reader(fs, path, config);
int count = 0;
- IntWritable clusterId = new IntWritable(0);
+ Writable clusterId = new IntWritable(0);
WeightedVectorWritable vector = new WeightedVectorWritable();
while (reader.next(clusterId, vector)) {
count++;
@@ -444,6 +457,7 @@ public class TestCanopyCreation extends
}
/** Story: User can cluster points using sequential execution */
+ @Test
public void testClusteringEuclideanSeq() throws Exception {
List<VectorWritable> points = getPointsWritable();
Configuration config = new Configuration();
@@ -462,18 +476,18 @@ public class TestCanopyCreation extends
Path path = new Path(output, "clusters-0/part-r-00000");
FileSystem fs = FileSystem.get(path.toUri(), config);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
- Text key = new Text();
+ Writable key = new Text();
Canopy canopy = new Canopy();
int ix = 0;
while (reader.next(key, canopy)) {
- assertEquals("Center [" + ix + "]", euclideanCentroids.get(ix), canopy.getCenter());
+ assertEquals("Center [" + ix + ']', euclideanCentroids.get(ix), canopy.getCenter());
ix++;
}
reader.close();
path = new Path(output, "clusteredPoints/part-m-0");
reader = new SequenceFile.Reader(fs, path, config);
int count = 0;
- IntWritable clusterId = new IntWritable(0);
+ Writable clusterId = new IntWritable(0);
WeightedVectorWritable vector = new WeightedVectorWritable();
while (reader.next(clusterId, vector)) {
count++;
@@ -487,6 +501,7 @@ public class TestCanopyCreation extends
* Story: User can produce final point clustering using a Hadoop map/reduce job and a
* ManhattanDistanceMeasure.
*/
+ @Test
public void testClusteringManhattanMR() throws Exception {
List<VectorWritable> points = getPointsWritable();
Configuration conf = new Configuration();
@@ -498,7 +513,7 @@ public class TestCanopyCreation extends
Path path = new Path(output, "clusteredPoints/part-m-00000");
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
int count = 0;
- IntWritable clusterId = new IntWritable(0);
+ Writable clusterId = new IntWritable(0);
WeightedVectorWritable vector = new WeightedVectorWritable();
while (reader.next(clusterId, vector)) {
count++;
@@ -512,6 +527,7 @@ public class TestCanopyCreation extends
* Story: User can produce final point clustering using a Hadoop map/reduce job and a
* EuclideanDistanceMeasure.
*/
+ @Test
public void testClusteringEuclideanMR() throws Exception {
List<VectorWritable> points = getPointsWritable();
Configuration conf = new Configuration();
@@ -520,7 +536,8 @@ public class TestCanopyCreation extends
// now run the Job using the run() command. Others can use runJob().
Path output = getTestTempDirPath("output");
String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), getTestTempDirPath("testdata").toString(),
- optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
EuclideanDistanceMeasure.class.getName(), optKey(DefaultOptionCreator.T1_OPTION), "3.1",
optKey(DefaultOptionCreator.T2_OPTION), "2.1", optKey(DefaultOptionCreator.CLUSTERING_OPTION),
optKey(DefaultOptionCreator.OVERWRITE_OPTION) };
@@ -528,17 +545,19 @@ public class TestCanopyCreation extends
Path path = new Path(output, "clusteredPoints/part-m-00000");
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
int count = 0;
- IntWritable canopyId = new IntWritable(0);
+ Writable canopyId = new IntWritable(0);
WeightedVectorWritable vw = new WeightedVectorWritable();
while (reader.next(canopyId, vw)) {
count++;
- System.out.println("Txt: " + canopyId.toString() + " Vec: " + AbstractCluster.formatVector(vw.getVector().get(), null));
+ System.out.println("Txt: " + canopyId.toString() + " Vec: "
+ + AbstractCluster.formatVector(vw.getVector().get(), null));
}
assertEquals("number of points", points.size(), count);
reader.close();
}
/** Story: Clustering algorithm must support arbitrary user defined distance measure */
+ @Test
public void testUserDefinedDistanceMeasure() throws Exception {
List<VectorWritable> points = getPointsWritable();
Configuration conf = new Configuration();
@@ -554,18 +573,18 @@ public class TestCanopyCreation extends
Path path = new Path(output, "clusters-0/part-r-00000");
FileSystem fs = FileSystem.get(path.toUri(), job);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
- Text key = new Text();
+ Writable key = new Text();
Canopy value = new Canopy();
assertTrue("more to come", reader.next(key, value));
assertEquals("1st key", "C-0", key.toString());
- assertEquals("1st x value", 1.5, value.getCenter().get(0));
- assertEquals("1st y value", 1.5, value.getCenter().get(1));
+ assertEquals("1st x value", 1.5, value.getCenter().get(0), EPSILON);
+ assertEquals("1st y value", 1.5, value.getCenter().get(1), EPSILON);
assertTrue("more to come", reader.next(key, value));
assertEquals("2nd key", "C-1", key.toString());
- assertEquals("1st x value", 4.333333333333334, value.getCenter().get(0));
- assertEquals("1st y value", 4.333333333333334, value.getCenter().get(1));
+ assertEquals("1st x value", 4.333333333333334, value.getCenter().get(0), EPSILON);
+ assertEquals("1st y value", 4.333333333333334, value.getCenter().get(1), EPSILON);
assertFalse("more to come", reader.next(key, value));
reader.close();
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java Fri Sep 3 12:28:34 2010
@@ -29,13 +29,16 @@ import org.apache.mahout.clustering.diri
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
-public class TestDirichletClustering extends MahoutTestCase {
+public final class TestDirichletClustering extends MahoutTestCase {
private List<VectorWritable> sampleData;
@Override
- protected void setUp() throws Exception {
+ @Before
+ public void setUp() throws Exception {
super.setUp();
sampleData = new ArrayList<VectorWritable>();
}
@@ -71,7 +74,7 @@ public class TestDirichletClustering ext
generateSamples(num, mx, my, sd, 2);
}
- private static void printResults(List<Cluster[]> result, int significant) {
+ private static void printResults(Iterable<Cluster[]> result, int significant) {
int row = 0;
for (Cluster[] r : result) {
System.out.print("sample[" + row++ + "]= ");
@@ -85,6 +88,7 @@ public class TestDirichletClustering ext
System.out.println();
}
+ @Test
public void testDirichletCluster100() {
System.out.println("testDirichletCluster100");
generateSamples(40, 1, 1, 3);
@@ -102,6 +106,7 @@ public class TestDirichletClustering ext
assertNotNull(result);
}
+ @Test
public void testDirichletCluster100s() {
System.out.println("testDirichletCluster100s");
generateSamples(40, 1, 1, 3);
@@ -119,6 +124,7 @@ public class TestDirichletClustering ext
assertNotNull(result);
}
+ @Test
public void testDirichletCluster100as() {
System.out.println("testDirichletCluster100as");
generateSamples(40, 1, 1, 3);
@@ -136,6 +142,7 @@ public class TestDirichletClustering ext
assertNotNull(result);
}
+ @Test
public void testDirichletCluster100C3() {
System.out.println("testDirichletCluster100");
generateSamples(40, 1, 1, 3, 3);
@@ -153,6 +160,7 @@ public class TestDirichletClustering ext
assertNotNull(result);
}
+ @Test
public void testDirichletCluster100sC3() {
System.out.println("testDirichletCluster100s");
generateSamples(40, 1, 1, 3, 3);
@@ -170,6 +178,7 @@ public class TestDirichletClustering ext
assertNotNull(result);
}
+ @Test
public void testDirichletCluster100asC3() {
System.out.println("testDirichletCluster100as");
generateSamples(40, 1, 1, 3, 3);
@@ -187,6 +196,7 @@ public class TestDirichletClustering ext
assertNotNull(result);
}
+ @Test
public void testDirichletGaussianCluster100() {
System.out.println("testDirichletGaussianCluster100");
generateSamples(40, 1, 1, 3);
@@ -204,6 +214,7 @@ public class TestDirichletClustering ext
assertNotNull(result);
}
+ @Test
public void testDirichletDMCluster100() {
System.out.println("testDirichletDMCluster100");
generateSamples(40, 1, 1, 3);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDistributions.java Fri Sep 3 12:28:34 2010
@@ -20,9 +20,11 @@ package org.apache.mahout.clustering.dir
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
+import org.junit.Test;
-public class TestDistributions extends MahoutTestCase {
+public final class TestDistributions extends MahoutTestCase {
+ @Test
public void testRbeta() {
for (double i = 0.01; i < 20; i += 0.25) {
System.out.println("rBeta(6,1," + i + ")="
@@ -30,6 +32,7 @@ public class TestDistributions extends M
}
}
+ @Test
public void testRchisq() {
for (int i = 0; i < 50; i++) {
System.out
@@ -37,6 +40,7 @@ public class TestDistributions extends M
}
}
+ @Test
public void testRnorm() {
for (int i = 1; i < 50; i++) {
System.out.println("rNorm(6,1," + i + ")="
@@ -44,6 +48,7 @@ public class TestDistributions extends M
}
}
+ @Test
public void testDnorm() {
for (int i = -30; i < 30; i++) {
double d = (i * 0.1);
@@ -57,6 +62,7 @@ public class TestDistributions extends M
}
}
+ @Test
public void testDnorm2() {
for (int i = -30; i < 30; i++) {
double d = (i * 0.1);
@@ -70,6 +76,7 @@ public class TestDistributions extends M
}
}
+ @Test
public void testDnorm1() {
for (int i = -10; i < 10; i++) {
double d = (i * 0.1);
@@ -83,6 +90,7 @@ public class TestDistributions extends M
}
}
+ @Test
public void testRmultinom1() {
double[] b = {0.4, 0.6};
Vector v = new DenseVector(b);
@@ -96,6 +104,7 @@ public class TestDistributions extends M
}
+ @Test
public void testRmultinom2() {
double[] b = {0.1, 0.2, 0.7};
Vector v = new DenseVector(b);
@@ -109,6 +118,7 @@ public class TestDistributions extends M
}
+ @Test
public void testRmultinom() {
double[] b = {0.1, 0.2, 0.8};
Vector v = new DenseVector(b);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java?rev=992277&r1=992276&r2=992277&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java Fri Sep 3 12:28:34 2010
@@ -19,6 +19,7 @@ package org.apache.mahout.clustering.dir
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
@@ -29,6 +30,7 @@ import org.apache.hadoop.io.DataOutputBu
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusteringTestUtils;
@@ -41,18 +43,17 @@ import org.apache.mahout.clustering.diri
import org.apache.mahout.clustering.dirichlet.models.SampledNormalModel;
import org.apache.mahout.common.DummyRecordWriter;
import org.apache.mahout.common.MahoutTestCase;
-import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
-public class TestMapReduce extends MahoutTestCase {
-
- private List<VectorWritable> sampleData = new ArrayList<VectorWritable>();
+public final class TestMapReduce extends MahoutTestCase {
+ private Collection<VectorWritable> sampleData = new ArrayList<VectorWritable>();
private FileSystem fs;
-
private Configuration conf;
/**
@@ -104,28 +105,24 @@ public class TestMapReduce extends Mahou
}
@Override
- protected void setUp() throws Exception {
+ @Before
+ public void setUp() throws Exception {
super.setUp();
- RandomUtils.useTestSeed();
- ClusteringTestUtils.rmr("output");
- ClusteringTestUtils.rmr("input");
conf = new Configuration();
fs = FileSystem.get(conf);
- File f = new File("input");
- f.mkdir();
}
/** Test the basic Mapper */
+ @Test
public void testMapper() throws Exception {
generateSamples(10, 0, 0, 1);
DirichletState state = new DirichletState(new NormalModelDistribution(new VectorWritable(new DenseVector(2))), 5, 1);
DirichletMapper mapper = new DirichletMapper();
mapper.setup(state);
- DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<Text, VectorWritable>();
- Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter.build(mapper,
- conf,
- writer);
+ RecordWriter<Text,VectorWritable> writer = new DummyRecordWriter<Text, VectorWritable>();
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context =
+ DummyRecordWriter.build(mapper, conf, writer);
for (VectorWritable v : sampleData) {
mapper.map(null, v, context);
}
@@ -135,31 +132,29 @@ public class TestMapReduce extends Mahou
}
/** Test the basic Reducer */
+ @Test
public void testReducer() throws Exception {
generateSamples(100, 0, 0, 1);
generateSamples(100, 2, 0, 1);
generateSamples(100, 0, 2, 1);
generateSamples(100, 2, 2, 1);
- DirichletState state = new DirichletState(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1);
+ DirichletState state =
+ new DirichletState(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1);
DirichletMapper mapper = new DirichletMapper();
mapper.setup(state);
DummyRecordWriter<Text, VectorWritable> mapWriter = new DummyRecordWriter<Text, VectorWritable>();
- Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context mapContext = DummyRecordWriter.build(mapper,
- conf,
- mapWriter);
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context mapContext =
+ DummyRecordWriter.build(mapper, conf, mapWriter);
for (VectorWritable v : sampleData) {
mapper.map(null, v, mapContext);
}
DirichletReducer reducer = new DirichletReducer();
reducer.setup(state);
- DummyRecordWriter<Text, DirichletCluster> reduceWriter = new DummyRecordWriter<Text, DirichletCluster>();
- Reducer<Text, VectorWritable, Text, DirichletCluster>.Context reduceContext = DummyRecordWriter.build(reducer,
- conf,
- reduceWriter,
- Text.class,
- VectorWritable.class);
+ RecordWriter<Text, DirichletCluster> reduceWriter = new DummyRecordWriter<Text, DirichletCluster>();
+ Reducer<Text, VectorWritable, Text, DirichletCluster>.Context reduceContext =
+ DummyRecordWriter.build(reducer, conf, reduceWriter, Text.class, VectorWritable.class);
for (Text key : mapWriter.getKeys()) {
reducer.reduce(new Text(key), mapWriter.getValue(key), reduceContext);
}
@@ -169,34 +164,32 @@ public class TestMapReduce extends Mahou
}
/** Test the Mapper and Reducer in an iteration loop */
+ @Test
public void testMRIterations() throws Exception {
generateSamples(100, 0, 0, 1);
generateSamples(100, 2, 0, 1);
generateSamples(100, 0, 2, 1);
generateSamples(100, 2, 2, 1);
- DirichletState state = new DirichletState(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1.0);
+ DirichletState state =
+ new DirichletState(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1.0);
- List<Model<VectorWritable>[]> models = new ArrayList<Model<VectorWritable>[]>();
+ Collection<Model<VectorWritable>[]> models = new ArrayList<Model<VectorWritable>[]>();
for (int iteration = 0; iteration < 10; iteration++) {
DirichletMapper mapper = new DirichletMapper();
mapper.setup(state);
DummyRecordWriter<Text, VectorWritable> mapWriter = new DummyRecordWriter<Text, VectorWritable>();
- Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context mapContext = DummyRecordWriter.build(mapper,
- conf,
- mapWriter);
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context mapContext =
+ DummyRecordWriter.build(mapper, conf, mapWriter);
for (VectorWritable v : sampleData) {
mapper.map(null, v, mapContext);
}
DirichletReducer reducer = new DirichletReducer();
reducer.setup(state);
- DummyRecordWriter<Text, DirichletCluster> reduceWriter = new DummyRecordWriter<Text, DirichletCluster>();
- Reducer<Text, VectorWritable, Text, DirichletCluster>.Context reduceContext = DummyRecordWriter.build(reducer,
- conf,
- reduceWriter,
- Text.class,
- VectorWritable.class);
+ RecordWriter<Text, DirichletCluster> reduceWriter = new DummyRecordWriter<Text, DirichletCluster>();
+ Reducer<Text, VectorWritable, Text, DirichletCluster>.Context reduceContext =
+ DummyRecordWriter.build(reducer, conf, reduceWriter, Text.class, VectorWritable.class);
for (Text key : mapWriter.getKeys()) {
reducer.reduce(new Text(key), mapWriter.getValue(key), reduceContext);
}
@@ -223,7 +216,7 @@ public class TestMapReduce extends Mahou
System.out.println();
}
- private static void printResults(List<List<DirichletCluster>> clusters, int significant) {
+ private static void printResults(Iterable<List<DirichletCluster>> clusters, int significant) {
int row = 0;
for (List<DirichletCluster> r : clusters) {
System.out.print("sample[" + row++ + "]= ");
@@ -240,6 +233,7 @@ public class TestMapReduce extends Mahou
}
/** Test the Mapper and Reducer using the Driver in sequential execution mode */
+ @Test
public void testDriverIterationsSeq() throws Exception {
generateSamples(100, 0, 0, 0.5);
generateSamples(100, 2, 0, 0.2);
@@ -259,7 +253,7 @@ public class TestMapReduce extends Mahou
DefaultOptionCreator.SEQUENTIAL_METHOD };
new DirichletDriver().run(args);
// and inspect results
- List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
+ Collection<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
Configuration conf = new Configuration();
conf.set(DirichletDriver.MODEL_DISTRIBUTION_KEY, modelDistribution.asJsonString());
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, "20");
@@ -272,6 +266,7 @@ public class TestMapReduce extends Mahou
}
/** Test the Mapper and Reducer using the Driver in mapreduce mode */
+ @Test
public void testDriverIterationsMR() throws Exception {
generateSamples(100, 0, 0, 0.5);
generateSamples(100, 2, 0, 0.2);
@@ -290,7 +285,7 @@ public class TestMapReduce extends Mahou
optKey(DefaultOptionCreator.CLUSTERING_OPTION) };
new DirichletDriver().run(args);
// and inspect results
- List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
+ Collection<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
Configuration conf = new Configuration();
conf.set(DirichletDriver.MODEL_DISTRIBUTION_KEY, modelDistribution.asJsonString());
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, "20");
@@ -303,11 +298,13 @@ public class TestMapReduce extends Mahou
}
/** Test the Mapper and Reducer using the Driver */
+ @Test
public void testDriverMnRIterations() throws Exception {
generate4Datasets();
// Now run the driver
int maxIterations = 3;
- AbstractVectorModelDistribution modelDistribution = new SampledNormalDistribution(new VectorWritable(new DenseVector(2)));
+ AbstractVectorModelDistribution modelDistribution =
+ new SampledNormalDistribution(new VectorWritable(new DenseVector(2)));
DirichletDriver.runJob(getTestTempDirPath("input"),
getTestTempDirPath("output"),
modelDistribution,
@@ -347,6 +344,7 @@ public class TestMapReduce extends Mahou
}
/** Test the Mapper and Reducer using the Driver */
+ @Test
public void testDriverMnRnIterations() throws Exception {
generate4Datasets();
// Now run the driver
@@ -364,7 +362,7 @@ public class TestMapReduce extends Mahou
0,
false);
// and inspect results
- List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
+ Collection<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
Configuration conf = new Configuration();
conf.set(DirichletDriver.MODEL_DISTRIBUTION_KEY, modelDistribution.asJsonString());
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, "20");
@@ -377,11 +375,8 @@ public class TestMapReduce extends Mahou
}
/** Test the Mapper and Reducer using the Driver */
+ @Test
public void testDriverMnRnIterationsAsymmetric() throws Exception {
- File f = new File("input");
- for (File g : f.listFiles()) {
- g.delete();
- }
generateSamples(500, 0, 0, 0.5, 1.0);
ClusteringTestUtils.writePointsToFile(sampleData, getTestTempFilePath("input/data1.txt"), fs, conf);
sampleData = new ArrayList<VectorWritable>();
@@ -395,7 +390,8 @@ public class TestMapReduce extends Mahou
ClusteringTestUtils.writePointsToFile(sampleData, getTestTempFilePath("input/data4.txt"), fs, conf);
// Now run the driver
int maxIterations = 3;
- AbstractVectorModelDistribution modelDistribution = new SampledNormalDistribution(new VectorWritable(new DenseVector(2)));
+ AbstractVectorModelDistribution modelDistribution =
+ new SampledNormalDistribution(new VectorWritable(new DenseVector(2)));
DirichletDriver.runJob(getTestTempDirPath("input"),
getTestTempDirPath("output"),
modelDistribution,
@@ -408,7 +404,7 @@ public class TestMapReduce extends Mahou
0,
false);
// and inspect results
- List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
+ Collection<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>();
Configuration conf = new Configuration();
conf.set(DirichletDriver.MODEL_DISTRIBUTION_KEY, modelDistribution.asJsonString());
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, "20");
@@ -422,6 +418,7 @@ public class TestMapReduce extends Mahou
// =================== New Tests of Writable Implementations ====================
+ @Test
public void testNormalModelWritableSerialization() throws Exception {
double[] m = { 1.1, 2.2, 3.3 };
Model<?> model = new NormalModel(5, new DenseVector(m), 3.3);
@@ -434,6 +431,7 @@ public class TestMapReduce extends Mahou
assertEquals("models", model.toString(), model2.toString());
}
+ @Test
public void testSampledNormalModelWritableSerialization() throws Exception {
double[] m = { 1.1, 2.2, 3.3 };
Model<?> model = new SampledNormalModel(5, new DenseVector(m), 3.3);
@@ -446,6 +444,7 @@ public class TestMapReduce extends Mahou
assertEquals("models", model.toString(), model2.toString());
}
+ @Test
public void testAsymmetricSampledNormalModelWritableSerialization() throws Exception {
double[] m = { 1.1, 2.2, 3.3 };
double[] s = { 3.3, 4.4, 5.5 };
@@ -459,6 +458,7 @@ public class TestMapReduce extends Mahou
assertEquals("models", model.toString(), model2.toString());
}
+ @Test
public void testClusterWritableSerialization() throws Exception {
double[] m = { 1.1, 2.2, 3.3 };
DirichletCluster cluster = new DirichletCluster(new NormalModel(5, new DenseVector(m), 4), 10);
@@ -468,7 +468,7 @@ public class TestMapReduce extends Mahou
DataInputBuffer in = new DataInputBuffer();
in.reset(out.getData(), out.getLength());
cluster2.readFields(in);
- assertEquals("count", cluster.getTotalCount(), cluster2.getTotalCount());
+ assertEquals("count", cluster.getTotalCount(), cluster2.getTotalCount(), EPSILON);
assertNotNull("model null", cluster2.getModel());
assertEquals("model", cluster.getModel().toString(), cluster2.getModel().toString());
}