You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text rendering.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/01/10 18:30:20 UTC
svn commit: r1057289 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
Author: srowen
Date: Mon Jan 10 17:30:20 2011
New Revision: 1057289
URL: http://svn.apache.org/viewvc?rev=1057289&view=rev
Log:
MAHOUT-564
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=1057289&r1=1057288&r2=1057289&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Mon Jan 10 17:30:20 2011
@@ -20,7 +20,6 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
-import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.SequenceFile.Writer;
@@ -44,8 +43,6 @@ public class KMeansClusterer {
/** Distance to use for point to cluster comparison. */
private final DistanceMeasure measure;
- private final double convergenceDelta;
-
/**
* Init the k-means clusterer with the distance measure to use for comparison.
*
@@ -55,17 +52,6 @@ public class KMeansClusterer {
*/
public KMeansClusterer(DistanceMeasure measure) {
this.measure = measure;
- this.convergenceDelta = 0;
- }
-
- public KMeansClusterer(Configuration conf)
- throws ClassNotFoundException, InstantiationException, IllegalAccessException {
- ClassLoader ccl = Thread.currentThread().getContextClassLoader();
- this.measure = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY))
- .asSubclass(DistanceMeasure.class).newInstance();
- this.measure.configure(conf);
-
- this.convergenceDelta = Double.parseDouble(conf.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
}
/**
@@ -121,7 +107,7 @@ public class KMeansClusterer {
protected boolean testConvergence(Iterable<Cluster> clusters, double distanceThreshold) {
boolean converged = true;
for (Cluster cluster : clusters) {
- if (!computeConvergence(cluster)) {
+ if (!computeConvergence(cluster, distanceThreshold)) {
converged = false;
}
cluster.computeParameters();
@@ -232,8 +218,8 @@ public class KMeansClusterer {
return clusterer.testConvergence(clusters, distanceThreshold);
}
- public boolean computeConvergence(Cluster cluster) {
- return cluster.computeConvergence(measure, convergenceDelta);
+ public boolean computeConvergence(Cluster cluster, double distanceThreshold) {
+ return cluster.computeConvergence(measure, distanceThreshold);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=1057289&r1=1057288&r2=1057289&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Mon Jan 10 17:30:20 2011
@@ -32,17 +32,18 @@ import org.apache.mahout.common.distance
public class KMeansReducer extends Reducer<Text, ClusterObservations, Text, Cluster> {
private Map<String, Cluster> clusterMap;
-
+ private double convergenceDelta;
private KMeansClusterer clusterer;
@Override
- protected void reduce(Text key, Iterable<ClusterObservations> values, Context context) throws IOException, InterruptedException {
+ protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
+ throws IOException, InterruptedException {
Cluster cluster = clusterMap.get(key.toString());
for (ClusterObservations delta : values) {
cluster.observe(delta);
}
// force convergence calculation
- boolean converged = clusterer.computeConvergence(cluster);
+ boolean converged = clusterer.computeConvergence(cluster, convergenceDelta);
if (converged) {
context.getCounter("Clustering", "Converged Clusters").increment(1);
}
@@ -55,7 +56,13 @@ public class KMeansReducer extends Reduc
super.setup(context);
Configuration conf = context.getConfiguration();
try {
- this.clusterer = new KMeansClusterer(conf);
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+ DistanceMeasure measure = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY))
+ .asSubclass(DistanceMeasure.class).newInstance();
+ measure.configure(conf);
+
+ this.convergenceDelta = Double.parseDouble(conf.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
+ this.clusterer = new KMeansClusterer(measure);
this.clusterMap = new HashMap<String, Cluster>();
String path = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=1057289&r1=1057288&r2=1057289&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Mon Jan 10 17:30:20 2011
@@ -18,6 +18,7 @@
package org.apache.mahout.clustering.kmeans;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
@@ -55,8 +56,9 @@ import org.junit.Test;
public final class TestKmeansClustering extends MahoutTestCase {
- public static final double[][] REFERENCE = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 },
- { 5, 5 } };
+ public static final double[][] REFERENCE = {
+ { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 }
+ };
private static final int[][] EXPECTED_NUM_POINTS = { { 9 }, { 4, 5 }, { 4, 4, 1 }, { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 },
{ 1, 1, 1, 1, 1, 4 }, { 1, 1, 1, 1, 1, 2, 2 }, { 1, 1, 1, 1, 1, 1, 2, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
@@ -91,6 +93,40 @@ public final class TestKmeansClustering
return points;
}
+ /**
+ * Tests {@link KMeansClusterer#runKMeansIteration) single run convergence with a given distance threshold.
+ */
+ @Test
+ public void testRunKMeansIteration_convergesInOneRunWithGivenDistanceThreshold() {
+ double[][] rawPoints = { {0,0}, {0,0.25}, {0,0.75}, {0, 1}};
+ List<Vector> points = getPoints(rawPoints);
+
+ ManhattanDistanceMeasure distanceMeasure = new ManhattanDistanceMeasure();
+ List<Cluster> clusters = Arrays.asList(
+ new Cluster(points.get(0), 0, distanceMeasure),
+ new Cluster(points.get(3), 3, distanceMeasure));
+
+ // To converge in a single run, the given distance threshold should be greater than or equal to 0.125,
+ // since 0.125 will be the distance between center and centroid for the initial two clusters after one run.
+ double distanceThreshold = 0.25;
+
+ boolean converged = KMeansClusterer.runKMeansIteration(
+ points,
+ clusters,
+ distanceMeasure,
+ distanceThreshold);
+
+ Vector cluster1Center = clusters.get(0).getCenter();
+ assertEquals(0, cluster1Center.get(0), EPSILON);
+ assertEquals(0.125, cluster1Center.get(1), EPSILON);
+
+ Vector cluster2Center = clusters.get(1).getCenter();
+ assertEquals(0, cluster2Center.get(0), EPSILON);
+ assertEquals(0.875, cluster2Center.get(1), EPSILON);
+
+ assertTrue("KMeans iteration should be converged after a single run", converged);
+ }
+
/** Story: Test the reference implementation */
@Test
public void testReferenceImplementation() throws Exception {
@@ -274,11 +310,8 @@ public final class TestKmeansClustering
KMeansReducer reducer = new KMeansReducer();
reducer.setup(clusters, measure);
DummyRecordWriter<Text, Cluster> reducerWriter = new DummyRecordWriter<Text, Cluster>();
- Reducer<Text, ClusterObservations, Text, Cluster>.Context reducerContext = DummyRecordWriter.build(reducer,
- conf,
- reducerWriter,
- Text.class,
- ClusterObservations.class);
+ Reducer<Text, ClusterObservations, Text, Cluster>.Context reducerContext =
+ DummyRecordWriter.build(reducer, conf, reducerWriter, Text.class, ClusterObservations.class);
for (Text key : combinerWriter.getKeys()) {
reducer.reduce(new Text(key), combinerWriter.getValue(key), reducerContext);
}
@@ -364,7 +397,8 @@ public final class TestKmeansClustering
Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-0"), conf);
int[] expect = EXPECTED_NUM_POINTS[k];
- DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
+ DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
+ new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
// The key is the clusterId
IntWritable clusterId = new IntWritable(0);
// The value is the weighted vector
@@ -421,7 +455,8 @@ public final class TestKmeansClustering
// assertEquals("output dir files?", 4, outFiles.length);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-00000"), conf);
int[] expect = EXPECTED_NUM_POINTS[k];
- DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
+ DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
+ new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
// The key is the clusterId
IntWritable clusterId = new IntWritable(0);
// The value is the weighted vector
@@ -467,7 +502,8 @@ public final class TestKmeansClustering
// now compare the expected clusters with actual
Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
- DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
+ DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
+ new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-00000"), conf);
// The key is the clusterId