You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/03/01 06:42:36 UTC

svn commit: r917396 [1/3] - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/clustering/canopy/ core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/ core/src/main/java/org/...

Author: robinanil
Date: Mon Mar  1 05:42:35 2010
New Revision: 917396

URL: http://svn.apache.org/viewvc?rev=917396&view=rev
Log:
MAHOUT-307 MAHOUT-295 MAHOUT-304 clustering improvements

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/DisplayCanopy.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/DisplayKMeans.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/meanshift/OutputMapper.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java Mon Mar  1 05:42:35 2010
@@ -18,6 +18,8 @@
 package org.apache.mahout.clustering.canopy;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
 import java.util.List;
 
 import org.apache.hadoop.io.Text;
@@ -174,7 +176,7 @@
         isCovered = true;
         vw.set(point);
         collector.collect(new Text(canopy.getIdentifier()), vw);
-        reporter.setStatus("Emit Canopy ID:" +canopy.getIdentifier());
+        reporter.setStatus("Emit Canopy ID:" + canopy.getIdentifier());
       } else if (dist < minDist) {
         minDist = dist;
         closest = canopy;
@@ -199,4 +201,82 @@
   public boolean canopyCovers(Canopy canopy, Vector point) {
     return measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point) < t1;
   }
+  
+  /**
+   * Iterate through the points, adding new canopies. Return the canopies.
+   * <p>
+   * This method is destructive: every point is removed from <code>points</code>, so the list is empty on
+   * return and must support removal. Pass a copy if the points are still needed afterwards.
+   * 
+   * @param points
+   *          a list<Vector> defining the points to be clustered; emptied by this call
+   * @param measure
+   *          a DistanceMeasure to use
+   * @param t1
+   *          the T1 distance threshold: points closer than this to a canopy's seed point are added to it
+   * @param t2
+   *          the T2 distance threshold: points closer than this to a seed point leave the list
+   * @return the List<Canopy> created
+   */
+  public static List<Canopy> createCanopies(List<Vector> points, DistanceMeasure measure, double t1, double t2) {
+    List<Canopy> canopies = new ArrayList<Canopy>();
+    /*
+     * Reference Implementation: Given a distance metric, one can create canopies as follows: Start with a
+     * list of the data points in any order, and with two distance thresholds, T1 and T2, where T1 > T2.
+     * (These thresholds can be set by the user, or selected by cross-validation.) Pick a point on the list
+     * and measure its distance to all other points. Put all points that are within distance threshold T1 into
+     * a canopy. Remove from the list all points that are within distance threshold T2. Repeat until the list
+     * is empty.
+     */
+    int nextCanopyId = 0;
+    while (!points.isEmpty()) {
+      Iterator<Vector> ptIter = points.iterator();
+      Vector p1 = ptIter.next();
+      // the seed point always leaves the list, so the outer loop makes progress
+      ptIter.remove();
+      Canopy canopy = new Canopy(p1, nextCanopyId++);
+      canopies.add(canopy);
+      while (ptIter.hasNext()) {
+        Vector p2 = ptIter.next();
+        double dist = measure.distance(p1, p2);
+        // Put all points that are within distance threshold T1 into the canopy
+        if (dist < t1) {
+          canopy.addPoint(p2);
+        }
+        // Remove from the list all points that are within distance threshold T2
+        if (dist < t2) {
+          ptIter.remove();
+        }
+      }
+    }
+    return canopies;
+  }
+  
+  /**
+   * Iterate through the canopies, adding their centroids to a list
+   * 
+   * @param canopies
+   *          a List<Canopy>
+   * @return the List<Vector>
+   */
+  public static List<Vector> calculateCentroids(List<Canopy> canopies) {
+    List<Vector> result = new ArrayList<Vector>();
+    for (Canopy canopy : canopies) {
+      result.add(canopy.computeCentroid());
+    }
+    return result;
+  }
+  
+  /**
+   * Iterate through the canopies, resetting their center to their centroids
+   * 
+   * @param canopies
+   *          a List<Canopy>
+   */
+  public static void updateCentroids(List<Canopy> canopies) {
+    for (Canopy canopy : canopies) {
+      canopy.setCenter(canopy.computeCentroid());
+    }
+  }
+  
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java Mon Mar  1 05:42:35 2010
@@ -125,7 +125,7 @@
     this.thin = thin;
     this.burnin = burnin;
     this.numClusters = numClusters;
-    state = new DirichletState<O>(modelFactory, numClusters, alpha_0, thin, burnin);
+    state = new DirichletState<O>(modelFactory, numClusters, alpha_0);
   }
   
   /**
@@ -198,4 +198,34 @@
     return pi;
   }
   
+  /**
+   * Create a new DirichletClusterer on the sample data with the given parameters and run it.
+   * 
+   * @param points
+   *          the observed data to be clustered
+   * @param modelFactory
+   *          the ModelDistribution to use
+   * @param alpha_0
+   *          the double value for the beta distributions
+   * @param numClusters
+   *          the int number of clusters
+   * @param thin
+   *          the int thinning interval, used to report every n iterations
+   * @param burnin
+   *          the int burnin interval, used to suppress early iterations
+   * @param numIterations
+   *          number of iterations to be performed
+   * @return the List<Model<Vector>[]> returned by cluster(numIterations)
+   */
+  public static List<Model<Vector>[]> clusterPoints(List<Vector> points,
+                                                    ModelDistribution<Vector> modelFactory,
+                                                    double alpha_0,
+                                                    int numClusters,
+                                                    int thin,
+                                                    int burnin,
+                                                    int numIterations) {
+    DirichletClusterer<Vector> clusterer = new DirichletClusterer<Vector>(points, modelFactory, alpha_0,
+        numClusters, thin, burnin);
+    return clusterer.cluster(numIterations);
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Mon Mar  1 05:42:35 2010
@@ -65,7 +65,7 @@
   
   private static final Logger log = LoggerFactory.getLogger(DirichletDriver.class);
   
-  private DirichletDriver() { }
+  private DirichletDriver() {}
   
   public static void main(String[] args) throws Exception {
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
@@ -84,15 +84,18 @@
     
     Option modelOpt = obuilder.withLongName("modelClass").withRequired(true).withShortName("d").withArgument(
       abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create()).withDescription(
-      "The ModelDistribution class name.").create();
+      "The ModelDistribution class name. "
+          + "Defaults to org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution").create();
     
-    Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(true).withShortName("p")
+    Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(false).withShortName("p")
         .withArgument(abuilder.withName("prototypeClass").withMinimum(1).withMaximum(1).create())
-        .withDescription("The ModelDistribution prototype Vector class name.").create();
+        .withDescription(
+          "The ModelDistribution prototype Vector class name. "
+              + "Defaults to org.apache.mahout.math.RandomAccessSparseVector").create();
     
     Option sizeOpt = obuilder.withLongName("prototypeSize").withRequired(true).withShortName("s")
         .withArgument(abuilder.withName("prototypeSize").withMinimum(1).withMaximum(1).create())
-        .withDescription("The ModelDistribution prototype Vector size.").create();
+        .withDescription("The ModelDistribution prototype Vector size. ").create();
     
     Option numRedOpt = obuilder.withLongName("maxRed").withRequired(true).withShortName("r").withArgument(
       abuilder.withName("maxRed").withMinimum(1).withMaximum(1).create()).withDescription(
@@ -113,15 +116,17 @@
       
       String input = cmdLine.getValue(inputOpt).toString();
       String output = cmdLine.getValue(outputOpt).toString();
-      String modelFactory = cmdLine.getValue(modelOpt).toString();
-      String modelPrototype = cmdLine.getValue(prototypeOpt).toString();
+      String modelFactory = "org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution";
+      if (cmdLine.hasOption(modelOpt)) modelFactory = cmdLine.getValue(modelOpt).toString();
+      String modelPrototype = "org.apache.mahout.math.RandomAccessSparseVector";
+      if (cmdLine.hasOption(prototypeOpt)) modelPrototype = cmdLine.getValue(prototypeOpt).toString();
       int prototypeSize = Integer.parseInt(cmdLine.getValue(sizeOpt).toString());
       int numReducers = Integer.parseInt(cmdLine.getValue(numRedOpt).toString());
       int numModels = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
       int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
       double alpha_0 = Double.parseDouble(cmdLine.getValue(mOpt).toString());
-      runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels,
-        maxIterations, alpha_0, numReducers);
+      runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels, maxIterations, alpha_0,
+        numReducers);
     } catch (OptionException e) {
       log.error("Exception parsing command line: ", e);
       CommandLineUtil.printHelp(group);
@@ -161,8 +166,8 @@
                                             SecurityException,
                                             NoSuchMethodException,
                                             InvocationTargetException {
-    runJob(input, output, modelFactory, "org.apache.mahout.math.DenseVector", 2, numClusters,
-      maxIterations, alpha_0, numReducers);
+    runJob(input, output, modelFactory, "org.apache.mahout.math.DenseVector", 2, numClusters, maxIterations,
+      alpha_0, numReducers);
   }
   
   /**
@@ -200,15 +205,14 @@
                                             InvocationTargetException {
     
     String stateIn = output + "/state-0";
-    writeInitialState(output, stateIn, modelFactory, modelPrototype, prototypeSize,
-      numClusters, alpha_0);
+    writeInitialState(output, stateIn, modelFactory, modelPrototype, prototypeSize, numClusters, alpha_0);
     
     for (int iteration = 0; iteration < maxIterations; iteration++) {
       log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration
       String stateOut = output + "/state-" + (iteration + 1);
-      runIteration(input, stateIn, stateOut, modelFactory, modelPrototype, prototypeSize,
-        numClusters, alpha_0, numReducers);
+      runIteration(input, stateIn, stateOut, modelFactory, modelPrototype, prototypeSize, numClusters,
+        alpha_0, numReducers);
       // now point the input to the old output directory
       stateIn = stateOut;
     }
@@ -228,8 +232,8 @@
                                                        NoSuchMethodException,
                                                        InvocationTargetException {
     
-    DirichletState<VectorWritable> state = createState(modelFactory, modelPrototype,
-      prototypeSize, numModels, alpha_0);
+    DirichletState<VectorWritable> state = createState(modelFactory, modelPrototype, prototypeSize,
+      numModels, alpha_0);
     JobConf job = new JobConf(KMeansDriver.class);
     Path outPath = new Path(output);
     FileSystem fs = FileSystem.get(outPath.toUri(), job);
@@ -278,7 +282,7 @@
     Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
     Constructor<? extends Vector> v = vcl.getConstructor(int.class);
     factory.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
-    return new DirichletState<VectorWritable>(factory, numModels, alpha_0, 1, 1);
+    return new DirichletState<VectorWritable>(factory, numModels, alpha_0);
   }
   
   /**

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java Mon Mar  1 05:42:35 2010
@@ -41,7 +41,7 @@
   
   private static final Logger log = LoggerFactory.getLogger(DirichletJob.class);
   
-  private DirichletJob() { }
+  private DirichletJob() {}
   
   public static void main(String[] args) throws Exception {
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
@@ -58,13 +58,17 @@
       abuilder.withName("alpha").withMinimum(1).withMaximum(1).create()).withDescription(
       "The alpha0 value for the DirichletDistribution.").create();
     
-    Option modelOpt = obuilder.withLongName("modelClass").withRequired(true).withShortName("d").withArgument(
-      abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create()).withDescription(
-      "The ModelDistribution class name.").create();
+    Option modelOpt = obuilder.withLongName("modelClass").withRequired(false).withShortName("d")
+        .withArgument(abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create())
+        .withDescription(
+          "The ModelDistribution class name. "
+              + "Defaults to org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution").create();
     
-    Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(true).withShortName("p")
+    Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(false).withShortName("p")
         .withArgument(abuilder.withName("prototypeClass").withMinimum(1).withMaximum(1).create())
-        .withDescription("The ModelDistribution prototype Vector class name.").create();
+        .withDescription(
+          "The ModelDistribution prototype Vector class name. "
+              + "Defaults to org.apache.mahout.math.RandomAccessSparseVector").create();
     
     Option sizeOpt = obuilder.withLongName("prototypeSize").withRequired(true).withShortName("s")
         .withArgument(abuilder.withName("prototypeSize").withMinimum(1).withMaximum(1).create())
@@ -85,14 +89,15 @@
       
       String input = cmdLine.getValue(inputOpt).toString();
       String output = cmdLine.getValue(outputOpt).toString();
-      String modelFactory = cmdLine.getValue(modelOpt).toString();
-      String modelPrototype = cmdLine.getValue(prototypeOpt).toString();
+      String modelFactory = "org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution";
+      if (cmdLine.hasOption(modelOpt)) modelFactory = cmdLine.getValue(modelOpt).toString();
+      String modelPrototype = "org.apache.mahout.math.RandomAccessSparseVector";
+      if (cmdLine.hasOption(prototypeOpt)) modelPrototype = cmdLine.getValue(prototypeOpt).toString();
       int prototypeSize = Integer.parseInt(cmdLine.getValue(sizeOpt).toString());
       int numModels = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
       int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
       double alpha_0 = Double.parseDouble(cmdLine.getValue(mOpt).toString());
-      runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels,
-        maxIterations, alpha_0);
+      runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels, maxIterations, alpha_0);
     } catch (OptionException e) {
       log.error("Exception parsing command line: ", e);
       CommandLineUtil.printHelp(group);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java Mon Mar  1 05:42:35 2010
@@ -39,9 +39,7 @@
   
   public DirichletState(ModelDistribution<O> modelFactory,
                         int numClusters,
-                        double alpha_0,
-                        int thin,
-                        int burnin) {
+                        double alpha_0) {
     this.numClusters = numClusters;
     this.modelFactory = modelFactory;
     this.alpha_0 = alpha_0;

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Mon Mar  1 05:42:35 2010
@@ -29,19 +29,11 @@
 
 public class FuzzyKMeansClusterer {
   
-  private static final double MINIMAL_VALUE = 0.0000000001; // using it for
-  // adding
-  // exception
-  // this value to any
-  // zero valued
-  // variable to avoid
-  // divide by Zero
-  
-  // private int nextClusterId = 0;
+  private static final double MINIMAL_VALUE = 0.0000000001;
   
   private DistanceMeasure measure;
   
-  private double convergenceDelta = 0;
+  private double convergenceDelta;
   
   private double m = 2.0; // default value
   
@@ -65,20 +57,6 @@
   }
   
   /**
-   * Configure the distance measure directly. Used by unit tests.
-   * 
-   * @param aMeasure
-   *          the DistanceMeasure
-   * @param aConvergenceDelta
-   *          the delta value used to define convergence
-   */
-  private void config(DistanceMeasure aMeasure, double aConvergenceDelta) {
-    measure = aMeasure;
-    convergenceDelta = aConvergenceDelta;
-    // nextClusterId = 0;
-  }
-  
-  /**
    * Configure the distance measure from the job
    * 
    * @param job
@@ -123,10 +101,8 @@
     
     for (int i = 0; i < clusters.size(); i++) {
       double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
-      Text key = new Text(clusters.get(i).getIdentifier()); // just output the
-      // identifier,avoids
-      // too much data
-      // traffic
+      Text key = new Text(clusters.get(i).getIdentifier());
+      // just output the identifier, avoids too much data traffic
       /*
        * Text value = new Text(Double.toString(probWeight) + FuzzyKMeansDriver.MAPPER_VALUE_SEPARATOR +
        * values.toString());
@@ -154,7 +130,7 @@
     List<Double> clusterDistanceList = new ArrayList<Double>();
     
     for (SoftCluster cluster : clusters) {
-      clusterDistanceList.add(measure.distance(point, cluster.getCenter()));
+      clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
     }
     FuzzyKMeansOutput fOutput = new FuzzyKMeansOutput(clusters.size());
     for (int i = 0; i < clusters.size(); i++) {
@@ -175,9 +151,7 @@
       if (eachCDist == 0.0) {
         eachCDist = MINIMAL_VALUE;
       }
-      
       denom += Math.pow(clusterDistance / eachCDist, 2.0 / (m - 1));
-      
     }
     return 1.0 / denom;
   }
@@ -200,4 +174,89 @@
   public DistanceMeasure getMeasure() {
     return this.measure;
   }
+  
+  /**
+   * This is the reference fuzzy k-means implementation. Given its inputs it iterates over the points and
+   * clusters until their centers converge or until the maximum number of iterations is exceeded.
+   * 
+   * @param points
+   *          the input List<Vector> of points
+   * @param clusters
+   *          the initial List<SoftCluster> of clusters
+   * @param measure
+   *          the DistanceMeasure to use
+   * @param threshold
+   *          the convergence threshold, passed to the FuzzyKMeansClusterer constructor
+   * @param m
+   *          the fuzziness exponent m, passed to the FuzzyKMeansClusterer constructor
+   * @param numIter
+   *          the maximum number of iterations
+   * @return one List<SoftCluster> per performed iteration, preceded by the initial clusters
+   */
+  public static List<List<SoftCluster>> clusterPoints(List<Vector> points,
+                                                      List<SoftCluster> clusters,
+                                                      DistanceMeasure measure,
+                                                      double threshold,
+                                                      double m,
+                                                      int numIter) {
+    List<List<SoftCluster>> clustersList = new ArrayList<List<SoftCluster>>();
+    clustersList.add(clusters);
+    FuzzyKMeansClusterer clusterer = new FuzzyKMeansClusterer(measure, threshold, m);
+    boolean converged = false;
+    for (int iter = 0; !converged && iter < numIter; iter++) {
+      // seed fresh SoftClusters from the previous iteration's centers so that each iteration's
+      // clusters are kept as a separate entry in clustersList
+      List<SoftCluster> next = new ArrayList<SoftCluster>();
+      for (SoftCluster c : clustersList.get(iter)) {
+        next.add(new SoftCluster(c.getCenter(), c.getId()));
+      }
+      clustersList.add(next);
+      converged = runFuzzyKMeansIteration(points, next, clusterer);
+    }
+    return clustersList;
+  }
+  
+  /**
+   * Perform a single iteration over the points and clusters, assigning points to clusters and returning
+   * whether the iterations are completed.
+   * 
+   * @param points
+   *          the List<Vector> having the input points
+   * @param clusterList
+   *          the List<SoftCluster> clusters, updated in place
+   * @param clusterer
+   *          the FuzzyKMeansClusterer supplying the distance measure, m and the convergence test
+   * @return true if every cluster has converged; the cluster centers are recomputed only when false
+   */
+  public static boolean runFuzzyKMeansIteration(List<Vector> points,
+                                                List<SoftCluster> clusterList,
+                                                FuzzyKMeansClusterer clusterer) {
+    // add every point to every cluster, weighted by its membership probability raised to the power m
+    for (Vector point : points) {
+      List<Double> clusterDistanceList = new ArrayList<Double>();
+      for (SoftCluster cluster : clusterList) {
+        clusterDistanceList.add(clusterer.getMeasure().distance(point, cluster.getCenter()));
+      }
+      
+      for (int i = 0; i < clusterList.size(); i++) {
+        double probWeight = clusterer.computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
+        clusterList.get(i).addPoint(point, Math.pow(probWeight, clusterer.getM()));
+      }
+    }
+    // evaluate computeConvergence for every cluster rather than stopping at the first unconverged one
+    // NOTE(review): presumably it records each cluster's converged flag - confirm
+    boolean converged = true;
+    for (SoftCluster cluster : clusterList) {
+      if (!clusterer.computeConvergence(cluster)) {
+        converged = false;
+      }
+    }
+    // update the cluster centers for the next iteration
+    if (!converged) {
+      for (SoftCluster cluster : clusterList) {
+        cluster.recomputeCenter();
+      }
+    }
+    return converged;
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java Mon Mar  1 05:42:35 2010
@@ -21,35 +21,30 @@
 import java.io.DataOutput;
 import java.io.IOException;
 
-import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.ClusterBase;
 import org.apache.mahout.math.AbstractVector;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
 import org.apache.mahout.math.function.SquareRootFunction;
 
-public class SoftCluster implements Writable {
-  
-  // this cluster's clusterId
-  private int clusterId;
-  
-  // the current center
-  private Vector center = new RandomAccessSparseVector(0);
-  
+public class SoftCluster extends ClusterBase {
+
   // the current centroid is lazy evaluated and may be null
-  private Vector centroid = null;
+  private Vector centroid;
   
   // The Probability of belongingness sum
-  private double pointProbSum = 0.0;
+  private double pointProbSum;
   
   // the total of all points added to the cluster
-  private Vector weightedPointTotal = null;
+  private Vector weightedPointTotal;
   
   // has the centroid converged with the center?
-  private boolean converged = false;
+  private boolean converged;
   
   // track membership parameters
-  private double s0 = 0;
+  private double s0;
   
   private Vector s1;
   
@@ -87,10 +82,25 @@
     }
     return null;
   }
+
+  // For Writable
+  public SoftCluster() { }
+  
+  /**
+   * Construct a new SoftCluster with the given point as its center
+   * 
+   * @param center
+   *          the center point
+   */
+  public SoftCluster(Vector center) {
+    setCenter(new RandomAccessSparseVector(center));
+    this.pointProbSum = 0;
+    this.weightedPointTotal = getCenter().like();
+  }
   
   @Override
   public void write(DataOutput out) throws IOException {
-    out.writeInt(clusterId);
+    out.writeInt(this.getId());
     out.writeBoolean(converged);
     Vector vector = computeCentroid();
     VectorWritable.writeVector(out, vector);
@@ -98,13 +108,13 @@
   
   @Override
   public void readFields(DataInput in) throws IOException {
-    clusterId = in.readInt();
+    this.setId(in.readInt());
     converged = in.readBoolean();
     VectorWritable temp = new VectorWritable();
     temp.readFields(in);
-    center = temp.get();
+    this.setCenter(new RandomAccessSparseVector(temp.get()));
     this.pointProbSum = 0;
-    this.weightedPointTotal = center.like();
+    this.weightedPointTotal = getCenter().like();
   }
   
   /**
@@ -112,6 +122,7 @@
    * 
    * @return the new centroid
    */
+  @Override
   public Vector computeCentroid() {
     if (pointProbSum == 0) {
       return weightedPointTotal;
@@ -122,22 +133,6 @@
     return centroid;
   }
   
-  // For Writable
-  public SoftCluster() { }
-  
-  /**
-   * Construct a new SoftCluster with the given point as its center
-   * 
-   * @param center
-   *          the center point
-   */
-  public SoftCluster(Vector center) {
-    this.center = center;
-    this.pointProbSum = 0;
-    
-    this.weightedPointTotal = center.like();
-  }
-  
   /**
    * Construct a new SoftCluster with the given point as its center
    * 
@@ -145,8 +140,8 @@
    *          the center point
    */
   public SoftCluster(Vector center, int clusterId) {
-    this.clusterId = clusterId;
-    this.center = center;
+    this.setId(clusterId);
+    this.setCenter(center);
     this.pointProbSum = 0;
     this.weightedPointTotal = center.like();
   }
@@ -154,7 +149,7 @@
   /** Construct a new softcluster with the given clusterID */
   public SoftCluster(String clusterId) {
     
-    this.clusterId = Integer.parseInt(clusterId.substring(1));
+    this.setId(Integer.parseInt(clusterId.substring(1)));
     this.pointProbSum = 0;
     // this.weightedPointTotal = center.like();
     this.converged = clusterId.charAt(0) == 'V';
@@ -162,14 +157,15 @@
   
   @Override
   public String toString() {
-    return getIdentifier() + " - " + center.asFormatString();
+    return getIdentifier() + " - " + getCenter().asFormatString();
   }
   
+  @Override
   public String getIdentifier() {
     if (converged) {
-      return "V" + clusterId;
+      return "V" + this.getId();
     } else {
-      return "C" + clusterId;
+      return "C" + this.getId();
     }
   }
   
@@ -212,9 +208,9 @@
     centroid = null;
     pointProbSum += ptProb;
     if (weightedPointTotal == null) {
-      weightedPointTotal = point.clone().times(ptProb);
+      weightedPointTotal = point.clone().assign(Functions.mult, ptProb);
     } else {
-      weightedPointTotal = weightedPointTotal.plus(point.times(ptProb));
+      point.clone().assign(Functions.mult, ptProb).addTo(weightedPointTotal);
     }
   }
   
@@ -230,33 +226,25 @@
     if (weightedPointTotal == null) {
       weightedPointTotal = delta.clone();
     } else {
-      weightedPointTotal = weightedPointTotal.plus(delta);
+      delta.addTo(weightedPointTotal);
     }
   }
   
-  public Vector getCenter() {
-    return center;
-  }
-  
   public double getPointProbSum() {
     return pointProbSum;
   }
   
   /** Compute the centroid and set the center to it. */
   public void recomputeCenter() {
-    center = computeCentroid();
+    this.setCenter(computeCentroid());
     pointProbSum = 0;
-    weightedPointTotal = center.like();
+    weightedPointTotal = getCenter().like();
   }
   
   public Vector getWeightedPointTotal() {
     return weightedPointTotal;
   }
-  
-  public void setWeightedPointTotal(Vector v) {
-    this.weightedPointTotal = v;
-  }
-  
+
   public boolean isConverged() {
     return converged;
   }
@@ -265,8 +253,9 @@
     this.converged = converged;
   }
   
-  public int getClusterId() {
-    return clusterId;
+  @Override
+  public String asFormatString() {
+    return formatCluster(this);
   }
   
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Mon Mar  1 05:42:35 2010
@@ -43,6 +43,8 @@
   /** Has the centroid converged with the center? */
   private boolean converged;
   
+  private double std = 0.00000001;
+  
   /**
    * Format the cluster for output
    * 
@@ -134,8 +136,8 @@
     super();
     this.setCenter(new RandomAccessSparseVector(center));
     this.setNumPoints(0);
-    this.setPointTotal(center.like());
-    this.pointSquaredTotal = center.like();
+    this.setPointTotal(getCenter().like());
+    this.pointSquaredTotal = getCenter().like();
   }
   
   /** For (de)serialization as a Writable */
@@ -152,7 +154,7 @@
     this.setId(clusterId);
     this.setCenter(new RandomAccessSparseVector(center));
     this.setNumPoints(0);
-    this.setPointTotal(center.like());
+    this.setPointTotal(getCenter().like());
     this.pointSquaredTotal = getCenter().like();
   }
   
@@ -194,21 +196,24 @@
    */
   public void addPoints(int count, Vector delta) {
     centroid = null;
-    setNumPoints(getNumPoints() + count);
-    if (getPointTotal() == null) {
-      setPointTotal(delta.clone());
+    if (getNumPoints() == 0) {
+      setPointTotal(new RandomAccessSparseVector(delta.clone()));
       pointSquaredTotal = new RandomAccessSparseVector(delta.clone().assign(Functions.square));
     } else {
       delta.addTo(getPointTotal());
       delta.clone().assign(Functions.square).addTo(pointSquaredTotal);
     }
+    setNumPoints(getNumPoints() + count);
   }
   
   /** Compute the centroid and set the center to it. */
   public void recomputeCenter() {
+    std = getStd();
     setCenter(computeCentroid());
+    centroid = null;
     setNumPoints(0);
-    setPointTotal(getCenter().like());
+    this.setPointTotal(getCenter().like());
+    this.pointSquaredTotal = getCenter().like();
   }
   
   /**
@@ -236,6 +241,7 @@
   
   /** @return the std */
   public double getStd() {
+    if (getNumPoints() == 0) return std;
     Vector stds = pointSquaredTotal.times(getNumPoints()).minus(getPointTotal().times(getPointTotal()))
         .assign(new SquareRootFunction()).divide(getNumPoints());
     return stds.zSum() / stds.size();

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Mon Mar  1 05:42:35 2010
@@ -17,6 +17,7 @@
 package org.apache.mahout.clustering.kmeans;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.hadoop.io.Text;
@@ -95,4 +96,86 @@
     String key = (name != null) && (name.length() != 0) ? name : point.asFormatString();
     output.collect(new Text(key), new Text(String.valueOf(nearestCluster.getId())));
   }
+  
+  /**
+   * This is the reference k-means implementation. Given its inputs it iterates over the points and clusters
+   * until their centers converge or until the maximum number of iterations is exceeded.
+   * 
+   * <p>Each iteration builds a new list of clusters seeded from the previous iteration's centers and runs
+   * {@link #runKMeansIteration(List, List, DistanceMeasure, double)} on that new list.</p>
+   * 
+   * @param points
+   *          the input List<Vector> of points
+   * @param clusters
+   *          the List<Cluster> of initial clusters
+   * @param measure
+   *          the DistanceMeasure to use
+   * @param maxIter
+   *          the maximum number of iterations
+   * @param distanceThreshold
+   *          the convergence threshold handed to each iteration
+   * @return a list whose element 0 is {@code clusters} and whose element i (i &gt; 0) holds the clusters
+   *         produced by iteration i
+   */
+  public static List<List<Cluster>> clusterPoints(List<Vector> points,
+                                                  List<Cluster> clusters,
+                                                  DistanceMeasure measure,
+                                                  int maxIter,
+                                                  double distanceThreshold) {
+    List<List<Cluster>> clustersList = new ArrayList<List<Cluster>>();
+    clustersList.add(clusters);
+    
+    boolean converged = false;
+    int iteration = 0;
+    while (!converged && iteration < maxIter) {
+      // seed this iteration's clusters from the centers of the previous iteration
+      List<Cluster> previous = clustersList.get(iteration++);
+      List<Cluster> next = new ArrayList<Cluster>(previous.size());
+      for (Cluster c : previous) {
+        next.add(new Cluster(c.getCenter()));
+      }
+      clustersList.add(next);
+      converged = runKMeansIteration(points, next, measure, distanceThreshold);
+    }
+    return clustersList;
+  }
+  
+  /**
+   * Perform a single iteration over the points and clusters: assign every point to its nearest cluster,
+   * test every cluster for convergence and, if any cluster has not converged, recompute all centers.
+   * 
+   * @param points
+   *          the List<Vector> having the input points
+   * @param clusters
+   *          the List<Cluster> clusters; points are added to them and, when not converged, their centers are
+   *          recomputed in place
+   * @param measure
+   *          a DistanceMeasure to use
+   * @param distanceThreshold
+   *          the threshold passed to {@link Cluster#computeConvergence(DistanceMeasure, double)}
+   * @return true if every cluster reported convergence, false otherwise
+   * @throws IllegalArgumentException
+   *           if there are points to assign but no clusters to assign them to
+   */
+  public static boolean runKMeansIteration(List<Vector> points,
+                                           List<Cluster> clusters,
+                                           DistanceMeasure measure,
+                                           double distanceThreshold) {
+    if (clusters.isEmpty() && !points.isEmpty()) {
+      // without this check the nearest-cluster search leaves closestCluster null and addPoint throws an NPE
+      throw new IllegalArgumentException("Cannot assign " + points.size() + " points to an empty cluster list");
+    }
+    // iterate through all points, assigning each to the nearest cluster
+    for (Vector point : points) {
+      Cluster closestCluster = null;
+      double closestDistance = Double.MAX_VALUE;
+      for (Cluster cluster : clusters) {
+        double distance = measure.distance(cluster.getCenter(), point);
+        if (closestCluster == null || distance < closestDistance) {
+          closestCluster = cluster;
+          closestDistance = distance;
+        }
+      }
+      closestCluster.addPoint(point);
+    }
+    // test for convergence; every cluster is evaluated (no short-circuit).
+    // NOTE(review): presumably computeConvergence records per-cluster state, so keep calling it on all of them.
+    boolean converged = true;
+    for (Cluster cluster : clusters) {
+      if (!cluster.computeConvergence(measure, distanceThreshold)) {
+        converged = false;
+      }
+    }
+    // update the cluster centers
+    if (!converged) {
+      for (Cluster cluster : clusters) {
+        cluster.recomputeCenter();
+      }
+    }
+    return converged;
+  }
+  
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Mon Mar  1 05:42:35 2010
@@ -30,6 +30,7 @@
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -46,7 +47,7 @@
   
   public static final String K = "k";
   
-  private RandomSeedGenerator() { }
+  private RandomSeedGenerator() {}
   
   public static Path buildRandom(String input, String output, int k) throws IOException,
                                                                     IllegalAccessException,
@@ -117,4 +118,20 @@
     
     return outFile;
   }
+  
+  /**
+   * Choose up to {@code k} points uniformly at random from {@code vectors} using reservoir sampling
+   * (Algorithm R), so every input point is equally likely to end up in the result.
+   * 
+   * @param vectors
+   *          the candidate points
+   * @param k
+   *          the number of points to choose
+   * @return a new list holding min(k, vectors.size()) of the input points
+   */
+  public static List<Vector> chooseRandomPoints(List<Vector> vectors, int k) {
+    List<Vector> chosenPoints = new ArrayList<Vector>(k);
+    Random random = RandomUtils.getRandom();
+    int seen = 0;
+    for (Vector value : vectors) {
+      if (seen < k) {
+        chosenPoints.add(value);
+      } else {
+        // keep the (seen+1)-th point with probability k/(seen+1) by overwriting a uniformly chosen slot;
+        // using the reservoir size instead of the running count would bias the sample toward late points
+        int slot = random.nextInt(seen + 1);
+        if (slot < k) {
+          chosenPoints.set(slot, value);
+        }
+      }
+      seen++;
+    }
+    return chosenPoints;
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Mon Mar  1 05:42:35 2010
@@ -21,16 +21,15 @@
 import java.io.DataOutput;
 import java.io.IOException;
 import java.lang.reflect.Type;
-import java.util.ArrayList;
-import java.util.List;
 
 import org.apache.mahout.clustering.ClusterBase;
 import org.apache.mahout.math.CardinalityException;
-import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.list.IntArrayList;
 
 import com.google.gson.Gson;
 import com.google.gson.GsonBuilder;
@@ -44,7 +43,7 @@
 public class MeanShiftCanopy extends ClusterBase {
   
   // TODO: this is problematic, but how else to encode membership?
-  private List<Vector> boundPoints = new ArrayList<Vector>();
+  private IntArrayList boundPoints = new IntArrayList();
   
   private boolean converged = false;
   
@@ -78,9 +77,9 @@
   public MeanShiftCanopy(Vector point, int id) {
     this.setId(id);
     this.setCenter(point);
-    this.setPointTotal(point.clone());
+    this.setPointTotal(new RandomAccessSparseVector(point.clone()));
     this.setNumPoints(1);
-    this.boundPoints.add(point);
+    this.boundPoints.add(id);
   }
   
   /**
@@ -91,14 +90,14 @@
    * @param id
    *          an int identifying the canopy local to this process only
    * @param boundPoints
-   *          a List<Vector> containing points bound to the canopy
+   *          a IntArrayList containing points ids bound to the canopy
    * @param converged
    *          true if the canopy has converged
    */
-  MeanShiftCanopy(Vector point, int id, List<Vector> boundPoints, boolean converged) {
+  MeanShiftCanopy(Vector point, int id, IntArrayList boundPoints, boolean converged) {
     this.setId(id);
     this.setCenter(point);
-    this.setPointTotal(point.clone());
+    this.setPointTotal(new RandomAccessSparseVector(point));
     this.setNumPoints(1);
     this.boundPoints = boundPoints;
     this.converged = converged;
@@ -117,22 +116,14 @@
   void addPoints(Vector point, int nPoints) {
     setNumPoints(getNumPoints() + nPoints);
     Vector subTotal = nPoints == 1 ? point.clone() : point.times(nPoints);
-    setPointTotal(getPointTotal() == null ? subTotal : getPointTotal().plus(subTotal));
-  }
-  
-  /**
-   * Compute the bound centroid by averaging the bound points
-   * 
-   * @return a Vector which is the new bound centroid
-   */
-  public Vector computeBoundCentroid() {
-    Vector result = new DenseVector(getCenter().size());
-    for (Vector v : boundPoints) {
-      result.assign(v, Functions.plus);
+    if (getPointTotal() == null) {
+      setPointTotal(new RandomAccessSparseVector(subTotal));
+    } else {
+      subTotal.addTo(getPointTotal()); 
     }
-    return result.divide(boundPoints.size());
   }
   
+  
   /**
    * Compute the centroid by normalizing the pointTotal
    * 
@@ -147,7 +138,7 @@
     }
   }
   
-  public List<Vector> getBoundPoints() {
+  public IntArrayList getBoundPoints() {
     return boundPoints;
   }
   
@@ -164,7 +155,7 @@
     setId(canopy.getId());
     setCenter(canopy.getCenter());
     addPoints(getCenter(), 1);
-    boundPoints.addAll(canopy.getBoundPoints());
+    boundPoints.addAllOf(canopy.getBoundPoints());
   }
   
   public boolean isConverged() {
@@ -178,7 +169,7 @@
    *          an existing MeanShiftCanopy
    */
   void merge(MeanShiftCanopy canopy) {
-    boundPoints.addAll(canopy.boundPoints);
+    boundPoints.addAllOf(canopy.boundPoints);
   }
   
   @Override
@@ -204,20 +195,19 @@
     temp.readFields(in);
     this.setCenter(temp.get());
     int numpoints = in.readInt();
-    this.boundPoints = new ArrayList<Vector>();
+    this.boundPoints = new IntArrayList();
     for (int i = 0; i < numpoints; i++) {
-      temp.readFields(in);
-      this.boundPoints.add(temp.get());
+      this.boundPoints.add(in.readInt());
     }
   }
   
   @Override
   public void write(DataOutput out) throws IOException {
     super.write(out);
-    VectorWritable.writeVector(out, computeCentroid());
+    VectorWritable.writeVector(out, new SequentialAccessSparseVector(computeCentroid()));
     out.writeInt(boundPoints.size());
-    for (Vector v : boundPoints) {
-      VectorWritable.writeVector(out, v);
+    for (int v : boundPoints.elements()) {
+      out.writeInt(v);
     }
   }
   
@@ -236,7 +226,7 @@
     return formatCanopy(this);
   }
   
-  public void setBoundPoints(List<Vector> boundPoints) {
+  public void setBoundPoints(IntArrayList boundPoints) {
     this.boundPoints = boundPoints;
   }
   

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java Mon Mar  1 05:42:35 2010
@@ -1,6 +1,7 @@
 package org.apache.mahout.clustering.meanshift;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.hadoop.io.Text;
@@ -15,7 +16,7 @@
   
   private double convergenceDelta = 0;
   // the next canopyId to be allocated
-  //private int nextCanopyId = 0;
+  // private int nextCanopyId = 0;
   // the T1 distance threshold
   private double t1;
   // the T2 distance threshold
@@ -57,7 +58,7 @@
     } catch (InstantiationException e) {
       throw new IllegalStateException(e);
     }
-    //nextCanopyId = 0; // never read?
+    // nextCanopyId = 0; // never read?
     t1 = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.T1_KEY));
     t2 = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.T2_KEY));
     convergenceDelta = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.CLUSTER_CONVERGENCE_KEY));
@@ -70,7 +71,7 @@
    *          the convergence criteria
    */
   public void config(DistanceMeasure aMeasure, double aT1, double aT2, double aDelta) {
-    //nextCanopyId = 100; // so canopyIds will sort properly  // never read?
+    // nextCanopyId = 100; // so canopyIds will sort properly // never read?
     measure = aMeasure;
     t1 = aT1;
     t2 = aT2;
@@ -111,8 +112,7 @@
   }
   
   /** Emit the new canopy to the collector, keyed by the canopy's Id */
-  static void emitCanopy(MeanShiftCanopy canopy,
-                         OutputCollector<Text,WritableComparable<?>> collector) throws IOException {
+  static void emitCanopy(MeanShiftCanopy canopy, OutputCollector<Text,WritableComparable<?>> collector) throws IOException {
     String identifier = canopy.getIdentifier();
     collector.collect(new Text(identifier), new Text("new " + canopy.toString()));
   }
@@ -126,8 +126,7 @@
    */
   public boolean shiftToMean(MeanShiftCanopy canopy) {
     Vector centroid = canopy.computeCentroid();
-    canopy
-        .setConverged(new EuclideanDistanceMeasure().distance(centroid, canopy.getCenter()) < convergenceDelta);
+    canopy.setConverged(measure.distance(centroid, canopy.getCenter()) < convergenceDelta);
     canopy.setCenter(centroid);
     canopy.setNumPoints(1);
     canopy.setPointTotal(centroid.clone());
@@ -159,4 +158,82 @@
   public boolean closelyBound(MeanShiftCanopy canopy, Vector point) {
     return measure.distance(canopy.getCenter(), point) < t2;
   }
+  
+  /**
+   * Story: User can exercise the reference implementation to verify that the test datapoints are clustered in
+   * a reasonable manner.
+   * 
+   * <p>NOTE: no points are ever added to the canopy list in this method, so the shift loop makes a single
+   * pass over an empty list and returns without clustering anything.</p>
+   * 
+   * @deprecated test scaffolding that performs no clustering; use
+   *             {@link #clusterPoints(List, DistanceMeasure, double, double, double, int)} instead
+   */
+  @Deprecated
+  public void testReferenceImplementation() {
+    MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(new EuclideanDistanceMeasure(), 4.0,
+        1.0, 0.5);
+    // NOTE(review): this list is never populated with points here
+    List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
+    
+    boolean done = false;
+    while (!done) { // shift canopies to their centroids
+      done = true;
+      List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
+      for (MeanShiftCanopy canopy : canopies) {
+        done = clusterer.shiftToMean(canopy) && done;
+        clusterer.mergeCanopy(canopy, migratedCanopies);
+      }
+      canopies = migratedCanopies;
+    }
+  }
+  
+  /**
+   * This is the reference mean-shift implementation. Every input point is wrapped in its own canopy and
+   * merged into the canopy list, after which the canopies are shifted toward their means until an iteration
+   * reports convergence or {@code numIter} iterations have been run.
+   * 
+   * @param points
+   *          the input List<Vector> of points
+   * @param measure
+   *          the DistanceMeasure to use
+   * @param convergenceThreshold
+   *          the convergence delta handed to the clusterer
+   * @param t1
+   *          the T1 distance threshold
+   * @param t2
+   *          the T2 distance threshold
+   * @param numIter
+   *          the maximum number of iterations
+   * @return the canopy list built from the points and handed to each
+   *         {@link #runMeanShiftCanopyIteration(List, MeanShiftCanopyClusterer)} call
+   */
+  public static List<MeanShiftCanopy> clusterPoints(List<Vector> points,
+                                                    DistanceMeasure measure,
+                                                    double convergenceThreshold,
+                                                    double t1,
+                                                    double t2,
+                                                    int numIter) {
+    MeanShiftCanopyClusterer shifter = new MeanShiftCanopyClusterer(measure, t1, t2, convergenceThreshold);
+    
+    // seed one canopy per input point; ids are assigned in input order starting at 0
+    List<MeanShiftCanopy> result = new ArrayList<MeanShiftCanopy>();
+    int canopyId = 0;
+    for (Vector point : points) {
+      MeanShiftCanopy seed = new MeanShiftCanopy(point, canopyId);
+      canopyId++;
+      shifter.mergeCanopy(seed, result);
+    }
+    
+    // shift until an iteration reports convergence or the iteration budget is exhausted
+    int iter = 0;
+    boolean done = false;
+    while (!done && iter < numIter) {
+      done = runMeanShiftCanopyIteration(result, shifter);
+      iter++;
+    }
+    return result;
+  }
+  
+  /**
+   * Perform a single mean-shift iteration: shift every canopy toward its mean, merge the shifted canopies,
+   * and replace the contents of {@code canopies} with the merged result.
+   * 
+   * @param canopies
+   *          the List<MeanShiftCanopy> to iterate; it is cleared and refilled with the merged canopies, so it
+   *          must be modifiable
+   * @param clusterer
+   *          the MeanShiftCanopyClusterer supplying the shift and merge operations
+   * @return true if {@code shiftToMean} returned true for every canopy in this iteration
+   */
+  public static boolean runMeanShiftCanopyIteration(List<MeanShiftCanopy> canopies,
+                                                    MeanShiftCanopyClusterer clusterer) {
+    boolean converged = true;
+    List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
+    for (MeanShiftCanopy canopy : canopies) {
+      // call shiftToMean first so every canopy is shifted even after one has failed to converge
+      converged = clusterer.shiftToMean(canopy) && converged;
+      clusterer.mergeCanopy(canopy, migratedCanopies);
+    }
+    // publish the merged canopies to the caller; reassigning the parameter would be invisible outside
+    canopies.clear();
+    canopies.addAll(migratedCanopies);
+    return converged;
+  }
+  
 }

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java?rev=917396&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java Mon Mar  1 05:42:35 2010
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.meanshift;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.math.VectorWritable;
+
+public class MeanShiftCanopyCreatorMapper extends MapReduceBase implements
+    Mapper<WritableComparable<?>,VectorWritable,Text,MeanShiftCanopy> {
+  
+  /** Number of tasks the canopy id space is partitioned across. */
+  private static final int MAX_TASKS = 50000;
+  
+  /** Size of each task's id range: Integer.MAX_VALUE / 50000 = 42,949 ids. */
+  private static final int IDS_PER_TASK = Integer.MAX_VALUE / MAX_TASKS;
+  
+  /**
+   * The next canopy id to hand out. Held per mapper instance (not static) so that a JVM reused for several
+   * tasks starts each task at the beginning of that task's own range.
+   */
+  private int nextCanopyId;
+  
+  /**
+   * Wrap each input vector in a new MeanShiftCanopy with the next id from this task's range and emit it keyed
+   * by the input key's string form.
+   */
+  @Override
+  public void map(WritableComparable<?> key,
+                  VectorWritable vector,
+                  OutputCollector<Text,MeanShiftCanopy> output,
+                  Reporter reporter) throws IOException {
+    MeanShiftCanopy canopy = new MeanShiftCanopy(vector.get(), nextCanopyId++);
+    output.collect(new Text(key.toString()), canopy);
+  }
+  
+  /**
+   * Derive this task's starting canopy id from the partition number in the task attempt id (mapred.task.id,
+   * shaped attempt_*_*_&lt;m|r&gt;_&lt;partition&gt;_*), giving each task a disjoint range of IDS_PER_TASK ids.
+   * 
+   * @throws IllegalArgumentException
+   *           if the task attempt id is missing or malformed, or its partition is outside [0, MAX_TASKS)
+   */
+  @Override
+  public void configure(JobConf job) {
+    super.configure(job);
+    String taskId = job.get("mapred.task.id");
+    if (taskId == null) {
+      throw new IllegalArgumentException("TaskAttemptId string : " + taskId + " is not properly formed");
+    }
+    String[] parts = taskId.split("_");
+    if (parts.length != 6 || !"attempt".equals(parts[0])
+        || (!"m".equals(parts[3]) && !"r".equals(parts[3]))) {
+      throw new IllegalArgumentException("TaskAttemptId string : " + taskId + " is not properly formed");
+    }
+    int partition = Integer.parseInt(parts[4]);
+    if (partition < 0 || partition >= MAX_TASKS) {
+      throw new IllegalArgumentException("Task partition " + partition + " exceeds the " + MAX_TASKS
+          + " supported canopy id ranges");
+    }
+    // Integer.MAX_VALUE rather than (1 << 31): the latter overflows to Integer.MIN_VALUE and made ids negative
+    nextCanopyId = IDS_PER_TASK * partition;
+  }
+}

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java Mon Mar  1 05:42:35 2010
@@ -45,7 +45,7 @@
   
   private static final Logger log = LoggerFactory.getLogger(MeanShiftCanopyDriver.class);
   
-  private MeanShiftCanopyDriver() { }
+  private MeanShiftCanopyDriver() {}
   
   public static void main(String[] args) {
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
@@ -88,7 +88,8 @@
       double t1 = Double.parseDouble(cmdLine.getValue(threshold1Opt).toString());
       double t2 = Double.parseDouble(cmdLine.getValue(threshold2Opt).toString());
       double convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
-      runJob(input, output, output + MeanShiftCanopyConfigKeys.CONTROL_PATH_KEY,
+      createCanopyFromVectors(input, output + "/intial-canopies");
+      runJob(output + "/intial-canopies", output, output + MeanShiftCanopyConfigKeys.CONTROL_PATH_KEY,
         measureClassName, t1, t2, convergenceDelta);
     } catch (OptionException e) {
       log.error("Exception parsing command line: ", e);
@@ -150,4 +151,47 @@
       log.warn(e.toString(), e);
     }
   }
+  
+  /**
+   * Convert a SequenceFile of VectorWritable points into a SequenceFile of MeanShiftCanopy values, one per
+   * input point, using a map-only job built around MeanShiftCanopyCreatorMapper. An IOException from the job
+   * is logged as a warning rather than rethrown.
+   * 
+   * @param input
+   *          the input pathname String (SequenceFile of vectors)
+   * @param output
+   *          the output pathname String for the initial canopies
+   */
+  public static void createCanopyFromVectors(String input, String output) {
+    Configurable jobClient = new JobClient();
+    JobConf conf = new JobConf(MeanShiftCanopyDriver.class);
+    
+    // map-only job: SequenceFile in, SequenceFile of <Text, MeanShiftCanopy> out
+    conf.setInputFormat(SequenceFileInputFormat.class);
+    conf.setOutputFormat(SequenceFileOutputFormat.class);
+    conf.setMapperClass(MeanShiftCanopyCreatorMapper.class);
+    conf.setNumReduceTasks(0);
+    conf.setOutputKeyClass(Text.class);
+    conf.setOutputValueClass(MeanShiftCanopy.class);
+    
+    FileInputFormat.setInputPaths(conf, new Path(input));
+    FileOutputFormat.setOutputPath(conf, new Path(output));
+    
+    jobClient.setConf(conf);
+    try {
+      JobClient.runJob(conf);
+    } catch (IOException e) {
+      log.warn(e.toString(), e);
+    }
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java Mon Mar  1 05:42:35 2010
@@ -127,10 +127,13 @@
       fs.delete(outPath, true);
     }
     fs.mkdirs(outPath);
+    
+    MeanShiftCanopyDriver.createCanopyFromVectors(input, output+"/initial-canopies");
+    
     // iterate until the clusters converge
     boolean converged = false;
     int iteration = 0;
-    String clustersIn = input;
+    String clustersIn = output+"/initial-canopies";
     while (!converged && (iteration < maxIterations)) {
       log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration