You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/01/10 18:30:20 UTC

svn commit: r1057289 - in /mahout/trunk/core/src: main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java

Author: srowen
Date: Mon Jan 10 17:30:20 2011
New Revision: 1057289

URL: http://svn.apache.org/viewvc?rev=1057289&view=rev
Log:
MAHOUT-564

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=1057289&r1=1057288&r2=1057289&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Mon Jan 10 17:30:20 2011
@@ -20,7 +20,6 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.SequenceFile.Writer;
@@ -44,8 +43,6 @@ public class KMeansClusterer {
   /** Distance to use for point to cluster comparison. */
   private final DistanceMeasure measure;
 
-  private final double convergenceDelta;
-
   /**
    * Init the k-means clusterer with the distance measure to use for comparison.
    * 
@@ -55,17 +52,6 @@ public class KMeansClusterer {
    */
   public KMeansClusterer(DistanceMeasure measure) {
     this.measure = measure;
-    this.convergenceDelta = 0;
-  }
-
-  public KMeansClusterer(Configuration conf)
-      throws ClassNotFoundException, InstantiationException, IllegalAccessException {
-    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-    this.measure = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY))
-        .asSubclass(DistanceMeasure.class).newInstance();
-    this.measure.configure(conf);
-
-    this.convergenceDelta = Double.parseDouble(conf.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
   }
 
   /**
@@ -121,7 +107,7 @@ public class KMeansClusterer {
   protected boolean testConvergence(Iterable<Cluster> clusters, double distanceThreshold) {
     boolean converged = true;
     for (Cluster cluster : clusters) {
-      if (!computeConvergence(cluster)) {
+      if (!computeConvergence(cluster, distanceThreshold)) {
         converged = false;
       }
       cluster.computeParameters();
@@ -232,8 +218,8 @@ public class KMeansClusterer {
     return clusterer.testConvergence(clusters, distanceThreshold);
   }
 
-  public boolean computeConvergence(Cluster cluster) {
-    return cluster.computeConvergence(measure, convergenceDelta);
+  public boolean computeConvergence(Cluster cluster, double distanceThreshold) {
+    return cluster.computeConvergence(measure, distanceThreshold);
   }
 
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=1057289&r1=1057288&r2=1057289&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Mon Jan 10 17:30:20 2011
@@ -32,17 +32,18 @@ import org.apache.mahout.common.distance
 public class KMeansReducer extends Reducer<Text, ClusterObservations, Text, Cluster> {
 
   private Map<String, Cluster> clusterMap;
-
+  private double convergenceDelta;
   private KMeansClusterer clusterer;
 
   @Override
-  protected void reduce(Text key, Iterable<ClusterObservations> values, Context context) throws IOException, InterruptedException {
+  protected void reduce(Text key, Iterable<ClusterObservations> values, Context context)
+    throws IOException, InterruptedException {
     Cluster cluster = clusterMap.get(key.toString());
     for (ClusterObservations delta : values) {
       cluster.observe(delta);
     }
     // force convergence calculation
-    boolean converged = clusterer.computeConvergence(cluster);
+    boolean converged = clusterer.computeConvergence(cluster, convergenceDelta);
     if (converged) {
       context.getCounter("Clustering", "Converged Clusters").increment(1);
     }
@@ -55,7 +56,13 @@ public class KMeansReducer extends Reduc
     super.setup(context);
     Configuration conf = context.getConfiguration();
     try {
-      this.clusterer = new KMeansClusterer(conf);
+      ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+      DistanceMeasure measure = ccl.loadClass(conf.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY))
+          .asSubclass(DistanceMeasure.class).newInstance();
+      measure.configure(conf);
+
+      this.convergenceDelta = Double.parseDouble(conf.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
+      this.clusterer = new KMeansClusterer(measure);
       this.clusterMap = new HashMap<String, Cluster>();
 
       String path = conf.get(KMeansConfigKeys.CLUSTER_PATH_KEY);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=1057289&r1=1057288&r2=1057289&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Mon Jan 10 17:30:20 2011
@@ -18,6 +18,7 @@
 package org.apache.mahout.clustering.kmeans;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
@@ -55,8 +56,9 @@ import org.junit.Test;
 
 public final class TestKmeansClustering extends MahoutTestCase {
 
-  public static final double[][] REFERENCE = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 },
-      { 5, 5 } };
+  public static final double[][] REFERENCE = {
+      { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 }
+  };
 
   private static final int[][] EXPECTED_NUM_POINTS = { { 9 }, { 4, 5 }, { 4, 4, 1 }, { 1, 2, 1, 5 }, { 1, 1, 1, 2, 4 },
       { 1, 1, 1, 1, 1, 4 }, { 1, 1, 1, 1, 1, 2, 2 }, { 1, 1, 1, 1, 1, 1, 2, 1 }, { 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
@@ -91,6 +93,40 @@ public final class TestKmeansClustering 
     return points;
   }
 
+  /**
+   * Tests {@link KMeansClusterer#runKMeansIteration} single run convergence with a given distance threshold.
+   */
+  @Test
+  public void testRunKMeansIteration_convergesInOneRunWithGivenDistanceThreshold() {
+    double[][] rawPoints = { {0,0}, {0,0.25}, {0,0.75}, {0, 1}};
+    List<Vector> points = getPoints(rawPoints);
+
+    ManhattanDistanceMeasure distanceMeasure = new ManhattanDistanceMeasure();
+    List<Cluster> clusters = Arrays.asList(
+        new Cluster(points.get(0), 0, distanceMeasure),
+        new Cluster(points.get(3), 3, distanceMeasure));
+
+    // To converge in a single run, the given distance threshold should be greater than or equal to 0.125,
+    // since 0.125 will be the distance between center and centroid for the initial two clusters after one run.
+    double distanceThreshold = 0.25;
+
+    boolean converged = KMeansClusterer.runKMeansIteration(
+            points,
+            clusters,
+            distanceMeasure,
+            distanceThreshold);
+
+    Vector cluster1Center = clusters.get(0).getCenter();
+    assertEquals(0, cluster1Center.get(0), EPSILON);
+    assertEquals(0.125, cluster1Center.get(1), EPSILON);
+
+    Vector cluster2Center = clusters.get(1).getCenter();
+    assertEquals(0, cluster2Center.get(0), EPSILON);
+    assertEquals(0.875, cluster2Center.get(1), EPSILON);
+
+    assertTrue("KMeans iteration should be converged after a single run", converged);
+  }
+
   /** Story: Test the reference implementation */
   @Test
   public void testReferenceImplementation() throws Exception {
@@ -274,11 +310,8 @@ public final class TestKmeansClustering 
       KMeansReducer reducer = new KMeansReducer();
       reducer.setup(clusters, measure);
       DummyRecordWriter<Text, Cluster> reducerWriter = new DummyRecordWriter<Text, Cluster>();
-      Reducer<Text, ClusterObservations, Text, Cluster>.Context reducerContext = DummyRecordWriter.build(reducer,
-                                                                                                         conf,
-                                                                                                         reducerWriter,
-                                                                                                         Text.class,
-                                                                                                         ClusterObservations.class);
+      Reducer<Text, ClusterObservations, Text, Cluster>.Context reducerContext =
+          DummyRecordWriter.build(reducer, conf, reducerWriter, Text.class, ClusterObservations.class);
       for (Text key : combinerWriter.getKeys()) {
         reducer.reduce(new Text(key), combinerWriter.getValue(key), reducerContext);
       }
@@ -364,7 +397,8 @@ public final class TestKmeansClustering 
       Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-0"), conf);
       int[] expect = EXPECTED_NUM_POINTS[k];
-      DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
+      DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
+          new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
       // The key is the clusterId
       IntWritable clusterId = new IntWritable(0);
       // The value is the weighted vector
@@ -421,7 +455,8 @@ public final class TestKmeansClustering 
       // assertEquals("output dir files?", 4, outFiles.length);
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-00000"), conf);
       int[] expect = EXPECTED_NUM_POINTS[k];
-      DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
+      DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
+          new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
       // The key is the clusterId
       IntWritable clusterId = new IntWritable(0);
       // The value is the weighted vector
@@ -467,7 +502,8 @@ public final class TestKmeansClustering 
 
     // now compare the expected clusters with actual
     Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
-    DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
+    DummyOutputCollector<IntWritable, WeightedVectorWritable> collector =
+        new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
     SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(clusteredPointsPath, "part-m-00000"), conf);
 
     // The key is the clusterId