You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/07/08 20:56:58 UTC

svn commit: r961880 - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/meanshift/ core/src/test/java/org/apache/mahout/clustering/meanshift/ examples/src/main/java/org/apache/mahout/clustering/meanshift/

Author: jeastman
Date: Thu Jul  8 18:56:58 2010
New Revision: 961880

URL: http://svn.apache.org/viewvc?rev=961880&view=rev
Log:
Fixed a subtle refactoring bug in MeanShiftCanopyClusterer that was causing the DisplayMeanShift example to produce garbage. Added a unit test for the clusterer's reference implementation. All tests run, and the example produces a nice clustering again.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=961880&r1=961879&r2=961880&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Thu Jul  8 18:56:58 2010
@@ -150,13 +150,6 @@ public class MeanShiftCanopy extends Clu
     return (converged ? "V-" : "C-") + getId();
   }
   
-  void init(MeanShiftCanopy canopy) {
-    setId(canopy.getId());
-    setCenter(canopy.getCenter());
-    addPoints(getCenter(), 1);
-    boundPoints.addAllOf(canopy.getBoundPoints());
-  }
-  
   public boolean isConverged() {
     return converged;
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java?rev=961880&r1=961879&r2=961880&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java Thu Jul  8 18:56:58 2010
@@ -18,11 +18,12 @@
 package org.apache.mahout.clustering.meanshift;
 
 import java.util.ArrayList;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Set;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
 import org.apache.mahout.math.Vector;
 
 public class MeanShiftCanopyClusterer {
@@ -42,8 +43,8 @@ public class MeanShiftCanopyClusterer {
 
   public MeanShiftCanopyClusterer(Configuration configuration) {
     try {
-      measure = Class.forName(configuration.get(MeanShiftCanopyConfigKeys.DISTANCE_MEASURE_KEY))
-          .asSubclass(DistanceMeasure.class).newInstance();
+      measure = Class.forName(configuration.get(MeanShiftCanopyConfigKeys.DISTANCE_MEASURE_KEY)).asSubclass(DistanceMeasure.class)
+          .newInstance();
       measure.configure(configuration);
     } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
@@ -93,11 +94,9 @@ public class MeanShiftCanopyClusterer {
       if (norm < t1) {
         aCanopy.touch(canopy);
       }
-      if (norm < t2) {
-        if ((closestCoveringCanopy == null) || (norm < closestNorm)) {
-          closestNorm = norm;
-          closestCoveringCanopy = canopy;
-        }
+      if (norm < t2 && ((closestCoveringCanopy == null) || (norm < closestNorm))) {
+        closestNorm = norm;
+        closestCoveringCanopy = canopy;
       }
     }
     if (closestCoveringCanopy == null) {
@@ -150,29 +149,6 @@ public class MeanShiftCanopyClusterer {
   }
 
   /**
-   * Story: User can exercise the reference implementation to verify that the test datapoints are clustered in
-   * a reasonable manner.
-   */
-  public static void testReferenceImplementation() {
-    MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(new EuclideanDistanceMeasure(), 4.0, 1.0, 0.5);
-    List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
-    // add all points to the canopies
-
-    boolean done = false;
-    //int iter = 1;
-    while (!done) { // shift canopies to their centroids
-      done = true;
-      List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
-      for (MeanShiftCanopy canopy : canopies) {
-        done = clusterer.shiftToMean(canopy) && done;
-        clusterer.mergeCanopy(canopy, migratedCanopies);
-      }
-      canopies = migratedCanopies;
-      //System.out.println(iter++);
-    }
-  }
-
-  /**
    * This is the reference mean-shift implementation. Given its inputs it iterates over the points and
    * clusters until their centers converge or until the maximum number of iterations is exceeded.
    * 
@@ -186,7 +162,9 @@ public class MeanShiftCanopyClusterer {
   public static List<MeanShiftCanopy> clusterPoints(List<Vector> points,
                                                     DistanceMeasure measure,
                                                     double convergenceThreshold,
-                                                    double t1, double t2, int numIter) {
+                                                    double t1,
+                                                    double t2,
+                                                    int numIter) {
     MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(measure, t1, t2, convergenceThreshold);
 
     List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
@@ -197,29 +175,32 @@ public class MeanShiftCanopyClusterer {
 
     boolean converged = false;
     for (int iter = 0; !converged && iter < numIter; iter++) {
-      converged = runMeanShiftCanopyIteration(canopies, clusterer);
+      converged = true;
+      List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
+      for (MeanShiftCanopy canopy : canopies) {
+        converged = clusterer.shiftToMean(canopy) && converged;
+        clusterer.mergeCanopy(canopy, migratedCanopies);
+      }
+      canopies = migratedCanopies;
+
+      //verifyNonOverlap(canopies); useful for debugging
     }
     return canopies;
   }
 
-  /**
-   * Perform a single iteration over the points and clusters, assigning points to clusters and returning if
-   * the iterations are completed.
-   * 
-   * @param canopies
-   *          the List<MeanShiftCanopy> clusters
-   */
-  public static boolean runMeanShiftCanopyIteration(List<MeanShiftCanopy> canopies,
-                                                    MeanShiftCanopyClusterer clusterer) {
-    boolean converged = true;
-    List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
+   @SuppressWarnings("unused")
+  private static void verifyNonOverlap(List<MeanShiftCanopy> canopies) {
+    Set<Integer> coveredPoints = new HashSet<Integer>();
+    // verify no overlap
     for (MeanShiftCanopy canopy : canopies) {
-      converged = clusterer.shiftToMean(canopy) && converged;
-      clusterer.mergeCanopy(canopy, migratedCanopies);
+      for (int v : canopy.getBoundPoints().toList())
+        if (coveredPoints.contains(v))
+          System.out.println("Duplicate bound point: " + v + " in Canopy: " + canopy.asFormatString(null));
+        else {
+          coveredPoints.add(v);
+          //System.out.println("Added bound point: " + v + " to Canopy: " + canopy.asFormatString(null));
+        }
     }
-    //canopies = migratedCanopies;
-    return converged;
-
   }
 
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java?rev=961880&r1=961879&r2=961880&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java Thu Jul  8 18:56:58 2010
@@ -69,7 +69,7 @@ public class TestMeanShift extends Mahou
     }
     for (MeanShiftCanopy canopy : canopies) {
       int ch = 'A' + canopy.getCanopyId();
-      for (int pid : canopy.getBoundPoints().elements()) {
+      for (int pid : canopy.getBoundPoints().toList()) {
         Vector pt = raw[pid];
         out[(int) pt.getQuick(0)][(int) pt.getQuick(1)] = (char) ch;
       }
@@ -136,6 +136,18 @@ public class TestMeanShift extends Mahou
     }
     assertTrue(true);
   }
+  
+  /**
+   * Test the MeanShiftCanopyClusterer's reference implementation. Should produce the same final output as above.
+   */
+  public void testClustererReferenceImplementation() {
+    List<Vector> points = new ArrayList<Vector>();
+    for (Vector v: raw)
+      points.add(v);
+    List<MeanShiftCanopy> canopies = MeanShiftCanopyClusterer.clusterPoints(points, new EuclideanDistanceMeasure(), 0.5, 4, 1, 10);
+    printCanopies(canopies);
+    printImage(canopies);
+  }
 
   /**
    * Story: User can produce initial canopy centers using a EuclideanDistanceMeasure and a
@@ -193,7 +205,7 @@ public class TestMeanShift extends Mahou
       MeanShiftCanopy canopy = canopyMap.get((ref.isConverged() ? "V-" : "C-") + ref.getCanopyId());
       assertEquals("ids", ref.getCanopyId(), canopy.getCanopyId());
       assertEquals("centers(" + ref.getIdentifier() + ')', ref.getCenter().asFormatString(), canopy.getCenter().asFormatString());
-      assertEquals("bound points", ref.getBoundPoints().size(), canopy.getBoundPoints().size());
+      assertEquals("bound points", ref.getBoundPoints().toList().size(), canopy.getBoundPoints().toList().size());
     }
   }
 
@@ -278,7 +290,7 @@ public class TestMeanShift extends Mahou
       String refCenter = refCanopy.getCenter().asFormatString();
       String reducerCenter = reducerCanopy.getCenter().asFormatString();
       assertEquals("centers(" + mapEntry.getKey() + ')', refCenter, reducerCenter);
-      assertEquals("bound points", refCanopy.getBoundPoints().size(), reducerCanopy.getBoundPoints().size());
+      assertEquals("bound points", refCanopy.getBoundPoints().toList().size(), reducerCanopy.getBoundPoints().toList().size());
     }
   }
 

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java?rev=961880&r1=961879&r2=961880&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java Thu Jul  8 18:56:58 2010
@@ -41,11 +41,12 @@ final class DisplayMeanShift extends Dis
   private static List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
 
   private static double t1;
+
   private static double t2;
 
   private DisplayMeanShift() {
     initialize();
-    this.setTitle("Canopy Clusters (> 1.5% of population)");
+    this.setTitle("MeanShiftCanopy Clusters (> 1.5% of population)");
   }
 
   @Override
@@ -72,11 +73,17 @@ final class DisplayMeanShift extends Dis
     for (MeanShiftCanopy canopy : canopies) {
       if (canopy.getBoundPoints().toList().size() > 0.015 * DisplayDirichlet.SAMPLE_DATA.size()) {
         g2.setColor(COLORS[Math.min(i++, DisplayDirichlet.COLORS.length - 1)]);
-        for (int v : canopy.getBoundPoints().elements()) {
-          DisplayDirichlet.plotRectangle(g2, SAMPLE_DATA.get(v).get(), dv);
+        int count = 0;
+        Vector center = new DenseVector(2);
+        for (int vix : canopy.getBoundPoints().toList()) {
+          Vector v = SAMPLE_DATA.get(vix).get();
+          count++;
+          v.addTo(center);
+          DisplayDirichlet.plotRectangle(g2, v, dv);
         }
-        DisplayDirichlet.plotEllipse(g2, canopy.getCenter(), dv1);
-        DisplayDirichlet.plotEllipse(g2, canopy.getCenter(), dv2);
+        center = center.divide(count);
+        DisplayDirichlet.plotEllipse(g2, center, dv1);
+        DisplayDirichlet.plotEllipse(g2, center, dv2);
       }
     }
   }