You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this rendering.
Posted to commits@mahout.apache.org by je...@apache.org on 2011/05/09 03:49:20 UTC
svn commit: r1100858 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/
core/src/test/java/org/apache/mahout/clustering/
examples/src/main/java/org/apache/mahout/clustering/display/
Author: jeastman
Date: Mon May 9 01:49:20 2011
New Revision: 1100858
URL: http://svn.apache.org/viewvc?rev=1100858&view=rev
Log:
MAHOUT-479: Added a new iterate method to ClusterIterator. The method accepts three
Hadoop Paths (input, prior and output) plus the number of desired iterations. All algorithm data is read from and written to SequenceFiles. Added a unit test and updated the examples DisplayKMeans, DisplayFuzzyKMeans and DisplayDirichlet to use the new file-based implementation. Check out the Dirichlet example in particular.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java Mon May 9 01:49:20 2011
@@ -16,10 +16,21 @@
*/
package org.apache.mahout.clustering;
+import java.io.IOException;
import java.util.Iterator;
import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
/**
* This is an experimental clustering iterator which works with a
@@ -50,8 +61,7 @@ public class ClusterIterator {
* @return the posterior ClusterClassifier
*/
public ClusterClassifier iterate(List<Vector> data,
- ClusterClassifier classifier,
- int numIterations) {
+ ClusterClassifier classifier, int numIterations) {
for (int iteration = 1; iteration <= numIterations; iteration++) {
for (Vector vector : data) {
// classification yields probabilities
@@ -59,7 +69,8 @@ public class ClusterIterator {
// policy selects weights for models given those probabilities
Vector weights = policy.select(probabilities);
// training causes all models to observe data
- for (Iterator<Vector.Element> it = weights.iterateNonZero(); it.hasNext();) {
+ for (Iterator<Vector.Element> it = weights.iterateNonZero(); it
+ .hasNext();) {
int index = it.next().index();
classifier.train(index, vector, weights.get(index));
}
@@ -71,4 +82,69 @@ public class ClusterIterator {
}
return classifier;
}
+
+  /**
+   * Iterate over data using a prior-trained ClusterClassifier, for a number of
+   * iterations. The prior is read from priorPath; after every iteration the
+   * posterior classifier is written to outPath/classifier-{iteration}
+   * (iterations are numbered from 1).
+   *
+   * @param inPath
+   *          a Path to input VectorWritables (CRC and log files are skipped)
+   * @param priorPath
+   *          a Path to a SequenceFile holding the prior classifier
+   * @param outPath
+   *          a Path of output directory
+   * @param numIterations
+   *          the int number of iterations to perform
+   * @throws IOException
+   *           if the prior cannot be read (or is empty), the input cannot be
+   *           iterated, or a posterior classifier cannot be written
+   */
+  public void iterate(Path inPath, Path priorPath, Path outPath,
+      int numIterations) throws IOException {
+    // a single Configuration is shared by every read and write of this run
+    Configuration conf = new Configuration();
+    ClusterClassifier classifier = readClassifier(conf, priorPath);
+    for (int iteration = 1; iteration <= numIterations; iteration++) {
+      for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
+          inPath, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
+        Vector vector = vw.get();
+        // classification yields probabilities
+        Vector probabilities = classifier.classify(vector);
+        // policy selects weights for models given those probabilities
+        Vector weights = policy.select(probabilities);
+        // training causes all models to observe data
+        for (Iterator<Vector.Element> it = weights.iterateNonZero(); it.hasNext();) {
+          int index = it.next().index();
+          classifier.train(index, vector, weights.get(index));
+        }
+      }
+      // compute the posterior models
+      classifier.close();
+      // update the policy
+      policy.update(classifier);
+      // output the classifier
+      writeClassifier(conf, classifier, new Path(outPath, "classifier-" + iteration),
+          String.valueOf(iteration));
+    }
+  }
+
+  /**
+   * Writes classifier as the single record of a new SequenceFile at outPath,
+   * keyed by Text(k). The writer is closed even if append fails.
+   */
+  private static void writeClassifier(Configuration conf, ClusterClassifier classifier,
+      Path outPath, String k) throws IOException {
+    FileSystem fs = FileSystem.get(outPath.toUri(), conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, outPath,
+        Text.class, ClusterClassifier.class);
+    try {
+      writer.append(new Text(k), classifier);
+    } finally {
+      writer.close();
+    }
+  }
+
+  /**
+   * Reads the first record of the SequenceFile at inPath as a ClusterClassifier.
+   *
+   * @throws IOException if the file cannot be read or contains no record
+   */
+  private static ClusterClassifier readClassifier(Configuration conf, Path inPath)
+      throws IOException {
+    FileSystem fs = FileSystem.get(inPath.toUri(), conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, inPath, conf);
+    try {
+      Writable key = new Text();
+      ClusterClassifier classifierOut = new ClusterClassifier();
+      // an empty prior file must not silently become an empty classifier
+      if (!reader.next(key, classifierOut)) {
+        throw new IOException("No classifier found in " + inPath);
+      }
+      return classifierOut;
+    } finally {
+      reader.close();
+    }
+  }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java Mon May 9 01:49:20 2011
@@ -38,6 +38,7 @@ import org.apache.mahout.common.distance
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
import org.junit.Test;
public final class TestClusterClassifier extends MahoutTestCase {
@@ -94,12 +95,22 @@ public final class TestClusterClassifier
Configuration config = new Configuration();
Path path = new Path(getTestTempDirPath(), "output");
FileSystem fs = FileSystem.get(path.toUri(), config);
+ writeClassifier(classifier, config, path, fs);
+ return readClassifier(config, path, fs);
+ }
+
+ /**
+ * Writes classifier to path as the single record of a new SequenceFile,
+ * keyed by Text "test".
+ * NOTE(review): the writer is not closed if append() throws; tolerable in a
+ * test helper, but not a pattern to copy into production code.
+ */
+ private void writeClassifier(ClusterClassifier classifier,
+ Configuration config, Path path, FileSystem fs) throws IOException {
SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, path,
Text.class, ClusterClassifier.class);
Writable key = new Text("test");
writer.append(key, classifier);
writer.close();
-
+ }
+
+ private ClusterClassifier readClassifier(Configuration config, Path path,
+ FileSystem fs) throws IOException {
+ Writable key;
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
key = new Text();
ClusterClassifier classifierOut = new ClusterClassifier();
@@ -232,11 +243,10 @@ public final class TestClusterClassifier
ClusterClassifier posterior = iterator.iterate(data, prior, 5);
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
- System.out
- .println(cluster.asFormatString(null));
+ System.out.println(cluster.asFormatString(null));
}
}
-
+
@Test
public void testClusterIteratorDirichlet() {
List<Vector> data = TestKmeansClustering
@@ -247,8 +257,42 @@ public final class TestClusterClassifier
ClusterClassifier posterior = iterator.iterate(data, prior, 5);
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
- System.out
- .println(cluster.asFormatString(null));
+ System.out.println(cluster.asFormatString(null));
+ }
+ }
+
+ /**
+ * Exercises the file-based ClusterIterator.iterate(Path, Path, Path, int):
+ * writes the REFERENCE points and a 3-model prior classifier to
+ * SequenceFiles, runs 5 k-means iterations, then checks that every
+ * posterior read back from outPath/classifier-1..5 still has 3 models.
+ */
+ @Test
+ public void testSeqFileClusterIteratorKMeans() throws IOException {
+ Path pointsPath = getTestTempDirPath("points");
+ Path priorPath = getTestTempDirPath("prior");
+ Path outPath = getTestTempDirPath("output");
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(conf);
+ List<VectorWritable> points = TestKmeansClustering
+ .getPointsWritable(TestKmeansClustering.REFERENCE);
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
+ Path path = new Path(priorPath, "priorClassifier");
+ ClusterClassifier prior = newClusterClassifier();
+ writeClassifier(prior, conf, path, fs);
+ assertEquals(3, prior.getModels().size());
+ System.out.println("Prior");
+ for (Cluster cluster : prior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+ ClusteringPolicy policy = new KMeansClusteringPolicy();
+ ClusterIterator iterator = new ClusterIterator(policy);
+ iterator.iterate(pointsPath, path, outPath, 5);
+
+ // the posterior of iteration i is read back from outPath/classifier-<i>
+ for (int i = 1; i <= 5; i++) {
+ System.out.println("Classifier-" + i);
+ ClusterClassifier posterior = readClassifier(conf, new Path(outPath,
+ "classifier-" + i), fs);
+ assertEquals(3, posterior.getModels().size());
+ for (Cluster cluster : posterior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+
}
}
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java Mon May 9 01:49:20 2011
@@ -39,8 +39,10 @@ import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusterClassifier;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
@@ -297,4 +299,26 @@ public class DisplayClustering extends F
protected static boolean isSignificant(Cluster cluster) {
+ // significant = this cluster's share of all SAMPLE_DATA points exceeds 'significance'
return (double) cluster.getNumPoints() / SAMPLE_DATA.size() > significance;
}
+
+  /**
+   * Reads the first record of the SequenceFile at path as a ClusterClassifier.
+   *
+   * @param config the Hadoop configuration
+   * @param path a SequenceFile whose values are ClusterClassifiers
+   * @return the deserialized classifier
+   * @throws IOException if the file cannot be read or contains no record
+   */
+  protected static ClusterClassifier readClassifier(Configuration config, Path path)
+      throws IOException {
+    // resolve the FileSystem from the path itself rather than the default FS
+    FileSystem fs = FileSystem.get(path.toUri(), config);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
+    try {
+      Writable key = new Text();
+      ClusterClassifier classifierOut = new ClusterClassifier();
+      // an empty file must not silently become an empty classifier
+      if (!reader.next(key, classifierOut)) {
+        throw new IOException("No classifier found in " + path);
+      }
+      return classifierOut;
+    } finally {
+      reader.close();
+    }
+  }
+
+  /**
+   * Writes classifier as the single record of a new SequenceFile at path.
+   * The key is an arbitrary Text("test"); readClassifier ignores it.
+   *
+   * @param classifier the classifier to persist
+   * @param config the Hadoop configuration
+   * @param path the SequenceFile to create
+   * @throws IOException if the file cannot be created or written
+   */
+  protected static void writeClassifier(ClusterClassifier classifier, Configuration config, Path path)
+      throws IOException {
+    FileSystem fs = FileSystem.get(path.toUri(), config);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, path,
+        Text.class, ClusterClassifier.class);
+    try {
+      Writable key = new Text("test");
+      writer.append(key, classifier);
+    } finally {
+      writer.close();
+    }
+  }
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java Mon May 9 01:49:20 2011
@@ -19,10 +19,12 @@ package org.apache.mahout.clustering.dis
import java.awt.Graphics;
import java.awt.Graphics2D;
+import java.io.IOException;
import java.util.ArrayList;
-import java.util.Iterator;
import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusterClassifier;
import org.apache.mahout.clustering.ClusterIterator;
@@ -34,7 +36,6 @@ import org.apache.mahout.clustering.diri
import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -77,44 +78,63 @@ public class DisplayDirichlet extends Di
protected static void generateResults(
ModelDistribution<VectorWritable> modelDist, int numClusters,
- int numIterations, double alpha0, int thin, int burnin) {
- boolean b = false;
- if (b) {
- DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist,
- alpha0, numClusters, thin, burnin);
- List<Cluster[]> result = dc.cluster(numIterations);
- printModels(result, burnin);
- for (Cluster[] models : result) {
- List<Cluster> clusters = new ArrayList<Cluster>();
- for (Cluster cluster : models) {
- if (isSignificant(cluster)) {
- clusters.add(cluster);
- }
- }
- CLUSTERS.add(clusters);
- }
+ int numIterations, double alpha0, int thin, int burnin)
+ throws IOException {
+ // runClusterer is hard-coded false, so by default the file-based
+ // ClusterIterator path (runSequentialDirichletClassifier) is taken; set it
+ // to true to run the in-memory DirichletClusterer instead.
+ boolean runClusterer = false;
+ if (runClusterer) {
+ runSequentialDirichletClusterer(modelDist, numClusters, numIterations, alpha0,
+ thin, burnin);
} else {
- List<Vector> points = new ArrayList<Vector>();
- for (VectorWritable sample : SAMPLE_DATA) {
- points.add(sample.get());
- }
- ClusteringPolicy policy = new DirichletClusteringPolicy(numClusters,
- numIterations);
- List<Cluster> models = new ArrayList<Cluster>();
- for (Model<VectorWritable> cluster : modelDist
- .sampleFromPrior(numClusters)) {
- models.add((Cluster) cluster);
+ runSequentialDirichletClassifier(modelDist, numClusters, numIterations);
+ }
+ }
+
+ /**
+ * Samples numClusters prior models from modelDist and writes them as the
+ * prior classifier to output/clusters-0. It then runs numIterations of
+ * ClusterIterator over the vectors under ./samples and adds the significant
+ * clusters of each posterior (output/classifier-1 .. classifier-numIterations)
+ * to CLUSTERS.
+ * NOTE(review): this method does not write ./samples itself; presumably the
+ * caller has already called writeSampleData(samples). Confirm this.
+ * NOTE(review): the prior is named clusters-0 while the iterator writes
+ * classifier-N; consider classifier-0 for consistency.
+ */
+ private static void runSequentialDirichletClassifier(
+ ModelDistribution<VectorWritable> modelDist, int numClusters,
+ int numIterations) throws IOException {
+ List<Cluster> models = new ArrayList<Cluster>();
+ for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(numClusters)) {
+ models.add((Cluster) cluster);
+ }
+ ClusterClassifier prior = new ClusterClassifier(models);
+ Path samples = new Path("samples");
+ Path output = new Path("output");
+ Path priorClassifier = new Path(output, "clusters-0");
+ Configuration conf = new Configuration();
+ writeClassifier(prior, conf, priorClassifier);
+
+ ClusteringPolicy policy = new DirichletClusteringPolicy(numClusters,
+ numIterations);
+ new ClusterIterator(policy).iterate(samples, priorClassifier, output,
+ numIterations);
+ // read back each iteration's posterior and keep only significant clusters
+ for (int i = 1; i <= numIterations; i++) {
+ ClusterClassifier posterior = readClassifier(conf, new Path(output,
+ "classifier-" + i));
+ List<Cluster> clusters = new ArrayList<Cluster>();
+ for (Cluster cluster : posterior.getModels()) {
+ if (isSignificant(cluster)) {
+ clusters.add(cluster);
+ }
}
- ClusterClassifier prior = new ClusterClassifier(models);
- ClusterIterator iterator = new ClusterIterator(policy);
- ClusterClassifier posterior = iterator.iterate(points, prior, 5);
- List<Cluster> models2 = posterior.getModels();
- for (Iterator<Cluster> it = models2.iterator(); it.hasNext();) {
- if (!isSignificant(it.next())) {
- it.remove();
+ CLUSTERS.add(clusters);
+ }
+ }
+
+ /**
+ * Runs the in-memory DirichletClusterer over SAMPLE_DATA for numIterations
+ * and passes the result to printModels together with burnin. For each
+ * returned Cluster[] it adds the significant clusters to CLUSTERS.
+ */
+ private static void runSequentialDirichletClusterer(
+ ModelDistribution<VectorWritable> modelDist, int numClusters,
+ int numIterations, double alpha0, int thin, int burnin) {
+ DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist,
+ alpha0, numClusters, thin, burnin);
+ List<Cluster[]> result = dc.cluster(numIterations);
+ printModels(result, burnin);
+ for (Cluster[] models : result) {
+ List<Cluster> clusters = new ArrayList<Cluster>();
+ for (Cluster cluster : models) {
+ if (isSignificant(cluster)) {
+ clusters.add(cluster);
}
}
- CLUSTERS.add(models2);
+ CLUSTERS.add(clusters);
}
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java Mon May 9 01:49:20 2011
@@ -19,6 +19,7 @@ package org.apache.mahout.clustering.dis
import java.awt.Graphics;
import java.awt.Graphics2D;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -37,7 +38,6 @@ import org.apache.mahout.common.RandomUt
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
class DisplayFuzzyKMeans extends DisplayClustering {
@@ -59,45 +59,61 @@ class DisplayFuzzyKMeans extends Display
Path samples = new Path("samples");
Path output = new Path("output");
+ int numClusters = 3;
+ int maxIterations = 10;
Configuration conf = new Configuration();
HadoopUtil.delete(conf, samples);
HadoopUtil.delete(conf, output);
RandomUtils.useTestSeed();
DisplayClustering.generateSamples();
- boolean b = false;
- if (b) {
- writeSampleData(samples);
- Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
- output, "clusters-0"), 3, measure);
- double threshold = 0.001;
- int numIterations = 10;
- int m = 3;
- FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold,
- numIterations, m, true, true, threshold, true);
-
- loadClusters(output);
+ writeSampleData(samples);
+ boolean runClusterer = false;
+ if (runClusterer) {
+ runSequentialFuzzyKClusterer(conf, samples, output, measure, numClusters,
+ maxIterations);
} else {
- List<Vector> points = new ArrayList<Vector>();
- for (VectorWritable sample : SAMPLE_DATA) {
- points.add(sample.get());
- }
- List<Cluster> initialClusters = new ArrayList<Cluster>();
- int id = 0;
- int numClusters = 4;
- for (Vector point : points) {
- if (initialClusters.size() < Math.min(numClusters, points.size())) {
- initialClusters.add(new SoftCluster(point, id++, measure));
- } else {
- break;
- }
- }
-
- ClusterClassifier prior = new ClusterClassifier(initialClusters);
- ClusteringPolicy policy = new FuzzyKMeansClusteringPolicy();
- ClusterClassifier posterior = new ClusterIterator(policy).iterate(points,
- prior, 10);
- CLUSTERS.add(posterior.getModels());
+ runSequentialFuzzyKClassifier(conf, samples, output, measure,
+ numClusters, maxIterations);
}
new DisplayFuzzyKMeans();
}
+
+  /**
+   * Runs fuzzy k-means with the file-based ClusterIterator.
+   *
+   * The first numClusters sample points seed SoftClusters; if SAMPLE_DATA is
+   * smaller, all of its points are used. The prior is written to
+   * output/classifier-0. The posterior of each of the maxIterations
+   * iterations is read back from output/classifier-{i} and appended to
+   * CLUSTERS.
+   *
+   * @throws IOException if the prior or a posterior cannot be written or read
+   */
+  private static void runSequentialFuzzyKClassifier(Configuration conf,
+      Path samples, Path output, DistanceMeasure measure, int numClusters,
+      int maxIterations) throws IOException {
+    // never request more seeds than there are samples (avoids IndexOutOfBounds)
+    int numSeeds = Math.min(numClusters, SAMPLE_DATA.size());
+    List<Vector> points = new ArrayList<Vector>(numSeeds);
+    for (int i = 0; i < numSeeds; i++) {
+      points.add(SAMPLE_DATA.get(i).get());
+    }
+    List<Cluster> initialClusters = new ArrayList<Cluster>(numSeeds);
+    int id = 0;
+    for (Vector point : points) {
+      initialClusters.add(new SoftCluster(point, id++, measure));
+    }
+    ClusterClassifier prior = new ClusterClassifier(initialClusters);
+    Path priorClassifier = new Path(output, "classifier-0");
+    writeClassifier(prior, conf, priorClassifier);
+
+    ClusteringPolicy policy = new FuzzyKMeansClusteringPolicy();
+    new ClusterIterator(policy).iterate(samples, priorClassifier, output,
+        maxIterations);
+    for (int i = 1; i <= maxIterations; i++) {
+      ClusterClassifier posterior = readClassifier(conf, new Path(output,
+          "classifier-" + i));
+      CLUSTERS.add(posterior.getModels());
+    }
+  }
+
+  /**
+   * Seeds numClusters random clusters via RandomSeedGenerator into
+   * output/clusters-0. It then runs FuzzyKMeansDriver for at most
+   * maxIterations (threshold 0.001, m = 3) and loads the results with
+   * loadClusters(output).
+   */
+  private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples,
+      Path output, DistanceMeasure measure, int numClusters, int maxIterations)
+      throws IOException, ClassNotFoundException, InterruptedException {
+    // seed count follows numClusters instead of a hard-coded 3
+    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
+        output, "clusters-0"), numClusters, measure);
+    double threshold = 0.001;
+    int m = 3;
+    FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold,
+        maxIterations, m, true, true, threshold, true);
+
+    loadClusters(output);
+  }
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java?rev=1100858&r1=1100857&r2=1100858&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java Mon May 9 01:49:20 2011
@@ -19,16 +19,17 @@ package org.apache.mahout.clustering.dis
import java.awt.Graphics;
import java.awt.Graphics2D;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusterClassifier;
import org.apache.mahout.clustering.ClusterIterator;
import org.apache.mahout.clustering.ClusteringPolicy;
import org.apache.mahout.clustering.KMeansClusteringPolicy;
-import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
import org.apache.mahout.common.HadoopUtil;
@@ -36,7 +37,6 @@ import org.apache.mahout.common.RandomUt
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
class DisplayKMeans extends DisplayClustering {
@@ -53,47 +53,64 @@ class DisplayKMeans extends DisplayClust
Path samples = new Path("samples");
Path output = new Path("output");
Configuration conf = new Configuration();
+ int numClusters = 3;
+ int maxIterations = 10;
HadoopUtil.delete(conf, samples);
HadoopUtil.delete(conf, output);
RandomUtils.useTestSeed();
DisplayClustering.generateSamples();
writeSampleData(samples);
- boolean b = false;
- if (b) {
- Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
- output, "clusters-0"), 3, measure);
- int maxIter = 10;
- double distanceThreshold = 0.001;
- KMeansDriver.run(samples, clusters, output, measure, distanceThreshold,
- maxIter, true, true);
- loadClusters(output);
+ boolean runClusterer = false;
+ if (runClusterer) {
+ runSequentialKMeansClusterer(conf, samples, output, measure, numClusters,
+ maxIterations);
} else {
- List<Vector> points = new ArrayList<Vector>();
- for (VectorWritable sample : SAMPLE_DATA) {
- points.add(sample.get());
- }
- List<Cluster> initialClusters = new ArrayList<Cluster>();
- int id = 0;
- int numClusters = 4;
- for (Vector point : points) {
- if (initialClusters.size() < Math.min(numClusters, points.size())) {
- initialClusters.add(new org.apache.mahout.clustering.kmeans.Cluster(
- point, id++, measure));
- } else {
- break;
- }
- }
-
- ClusterClassifier prior = new ClusterClassifier(initialClusters);
- ClusteringPolicy policy = new KMeansClusteringPolicy();
- ClusterClassifier posterior = new ClusterIterator(policy).iterate(points,
- prior, 10);
- CLUSTERS.add(posterior.getModels());
+ runSequentialKMeansClassifier(conf, samples, output, measure,
+ numClusters, maxIterations);
}
new DisplayKMeans();
}
+
+  /**
+   * Runs k-means with the file-based ClusterIterator.
+   *
+   * The first numClusters sample points seed k-means Clusters; if SAMPLE_DATA
+   * is smaller, all of its points are used. The prior is written to
+   * output/classifier-0. The posterior of each of the maxIterations
+   * iterations is read back from output/classifier-{i} and appended to
+   * CLUSTERS.
+   *
+   * @throws IOException if the prior or a posterior cannot be written or read
+   */
+  private static void runSequentialKMeansClassifier(Configuration conf,
+      Path samples, Path output, DistanceMeasure measure, int numClusters,
+      int maxIterations) throws IOException {
+    // never request more seeds than there are samples (avoids IndexOutOfBounds)
+    int numSeeds = Math.min(numClusters, SAMPLE_DATA.size());
+    List<Vector> points = new ArrayList<Vector>(numSeeds);
+    for (int i = 0; i < numSeeds; i++) {
+      points.add(SAMPLE_DATA.get(i).get());
+    }
+    List<Cluster> initialClusters = new ArrayList<Cluster>(numSeeds);
+    int id = 0;
+    for (Vector point : points) {
+      initialClusters.add(new org.apache.mahout.clustering.kmeans.Cluster(
+          point, id++, measure));
+    }
+    ClusterClassifier prior = new ClusterClassifier(initialClusters);
+    // same naming scheme as the iterator's outputs classifier-1 .. classifier-N
+    Path priorClassifier = new Path(output, "classifier-0");
+    writeClassifier(prior, conf, priorClassifier);
+
+    ClusteringPolicy policy = new KMeansClusteringPolicy();
+    // honor the caller's iteration count (previously a hard-coded local of 10)
+    new ClusterIterator(policy).iterate(samples, priorClassifier, output,
+        maxIterations);
+    for (int i = 1; i <= maxIterations; i++) {
+      ClusterClassifier posterior = readClassifier(conf, new Path(output,
+          "classifier-" + i));
+      CLUSTERS.add(posterior.getModels());
+    }
+  }
+
+  /**
+   * Seeds numClusters random clusters via RandomSeedGenerator into
+   * output/clusters-0. It then runs KMeansDriver for at most maxIterations
+   * (distanceThreshold 0.001) and loads the results with loadClusters(output).
+   */
+  private static void runSequentialKMeansClusterer(Configuration conf, Path samples,
+      Path output, DistanceMeasure measure, int numClusters, int maxIterations)
+      throws IOException, InterruptedException, ClassNotFoundException {
+    // seed count follows numClusters instead of a hard-coded 3
+    Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
+        output, "clusters-0"), numClusters, measure);
+    double distanceThreshold = 0.001;
+    KMeansDriver.run(samples, clusters, output, measure, distanceThreshold,
+        maxIterations, true, true);
+    loadClusters(output);
+  }
+
// Override the paint() method
@Override
public void paint(Graphics g) {