You are viewing a plain-text version of this content. The canonical version is the original mailing-list archive page; its hyperlink did not survive the plain-text conversion.
Posted to commits@commons.apache.org by lu...@apache.org on 2011/04/04 20:32:52 UTC

svn commit: r1088702 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java site/xdoc/changes.xml

Author: luc
Date: Mon Apr  4 18:32:52 2011
New Revision: 1088702

URL: http://svn.apache.org/viewvc?rev=1088702&view=rev
Log:
Improved robustness of the k-means++ algorithm by tracking changes in point-to-cluster assignments.

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java
    commons/proper/math/trunk/src/site/xdoc/changes.xml

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java?rev=1088702&r1=1088701&r2=1088702&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/clustering/KMeansPlusPlusClusterer.java Mon Apr  4 18:32:52 2011
@@ -108,12 +108,16 @@ public class KMeansPlusPlusClusterer<T e
 
         // create the initial clusters
         List<Cluster<T>> clusters = chooseInitialCenters(points, k, random);
-        assignPointsToClusters(clusters, points);
+        
+        // create an array containing the latest assignment of a point to a cluster
+        // no need to initialize the array, as it will be filled with the first assignment
+        int[] assignments = new int[points.size()];
+        assignPointsToClusters(clusters, points, assignments);
 
         // iterate through updating the centers until we're done
         final int max = (maxIterations < 0) ? Integer.MAX_VALUE : maxIterations;
         for (int count = 0; count < max; count++) {
-            boolean clusteringChanged = false;
+            boolean emptyCluster = false;
             List<Cluster<T>> newClusters = new ArrayList<Cluster<T>>();
             for (final Cluster<T> cluster : clusters) {
                 final T newCenter;
@@ -131,20 +135,20 @@ public class KMeansPlusPlusClusterer<T e
                         default :
                             throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
                     }
-                    clusteringChanged = true;
+                    emptyCluster = true;
                 } else {
                     newCenter = cluster.getCenter().centroidOf(cluster.getPoints());
-                    if (!newCenter.equals(cluster.getCenter())) {
-                        clusteringChanged = true;
-                    }
                 }
                 newClusters.add(new Cluster<T>(newCenter));
             }
-            if (!clusteringChanged) {
+            int changes = assignPointsToClusters(newClusters, points, assignments);
+            clusters = newClusters;
+            
+            // if there were no more changes in the point-to-cluster assignment
+            // and there are no empty clusters left, return the current clusters
+            if (changes == 0 && !emptyCluster) {
                 return clusters;
             }
-            assignPointsToClusters(newClusters, points);
-            clusters = newClusters;
         }
         return clusters;
     }
@@ -155,13 +159,25 @@ public class KMeansPlusPlusClusterer<T e
      * @param <T> type of the points to cluster
      * @param clusters the {@link Cluster}s to add the points to
      * @param points the points to add to the given {@link Cluster}s
+     * @return the number of points assigned to a different cluster than in the previous iteration
      */
-    private static <T extends Clusterable<T>> void
-        assignPointsToClusters(final Collection<Cluster<T>> clusters, final Collection<T> points) {
+    private static <T extends Clusterable<T>> int
+        assignPointsToClusters(final List<Cluster<T>> clusters, final Collection<T> points, 
+                               final int[] assignments) {
+        int assignedDifferently = 0;
+        int pointIndex = 0;
         for (final T p : points) {
-            Cluster<T> cluster = getNearestCluster(clusters, p);
+            int clusterIndex = getNearestCluster(clusters, p);
+            if (clusterIndex != assignments[pointIndex]) {
+                assignedDifferently++;
+            }
+            
+            Cluster<T> cluster = clusters.get(clusterIndex);
             cluster.addPoint(p);
+            assignments[pointIndex++] = clusterIndex;
         }
+        
+        return assignedDifferently;
     }
 
     /**
@@ -190,7 +206,8 @@ public class KMeansPlusPlusClusterer<T e
             double sum = 0;
             for (int i = 0; i < pointSet.size(); i++) {
                 final T p = pointSet.get(i);
-                final Cluster<T> nearest = getNearestCluster(resultSet, p);
+                int nearestClusterIndex = getNearestCluster(resultSet, p);
+                final Cluster<T> nearest = resultSet.get(nearestClusterIndex);
                 final double d = p.distanceFrom(nearest.getCenter());
                 sum += d * d;
                 dx2[i] = sum;
@@ -329,18 +346,20 @@ public class KMeansPlusPlusClusterer<T e
      * @param <T> type of the points to cluster
      * @param clusters the {@link Cluster}s to search
      * @param point the point to find the nearest {@link Cluster} for
-     * @return the nearest {@link Cluster} to the given point
+     * @return the index of the nearest {@link Cluster} to the given point
      */
-    private static <T extends Clusterable<T>> Cluster<T>
+    private static <T extends Clusterable<T>> int
         getNearestCluster(final Collection<Cluster<T>> clusters, final T point) {
         double minDistance = Double.MAX_VALUE;
-        Cluster<T> minCluster = null;
+        int clusterIndex = 0;
+        int minCluster = 0;
         for (final Cluster<T> c : clusters) {
             final double distance = point.distanceFrom(c.getCenter());
             if (distance < minDistance) {
                 minDistance = distance;
-                minCluster = c;
+                minCluster = clusterIndex;
             }
+            clusterIndex++;
         }
         return minCluster;
     }

Modified: commons/proper/math/trunk/src/site/xdoc/changes.xml
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/site/xdoc/changes.xml?rev=1088702&r1=1088701&r2=1088702&view=diff
==============================================================================
--- commons/proper/math/trunk/src/site/xdoc/changes.xml (original)
+++ commons/proper/math/trunk/src/site/xdoc/changes.xml Mon Apr  4 18:32:52 2011
@@ -52,6 +52,10 @@ The <action> type attribute can be add,u
     If the output is not quite correct, check for invisible trailing spaces!
      -->
     <release version="3.0" date="TBD" description="TBD">
+      <action dev="luc" type="fix" issue="MATH-547" due-to="Thomas Neidhart">
+        Improved robustness of the k-means++ algorithm by tracking changes in point-to-cluster
+        assignments.
+      </action>
       <action dev="psteitz" type="update" issue="MATH-555">
         Changed MathUtils.round(double,int,int) to propagate rather than
         wrap runtime exceptions.  Instead of MathRuntimeException, this method