You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/04/29 18:16:11 UTC

svn commit: r939360 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/clustering/ core/src/main/java/org/apache/mahout/clustering/canopy/ core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clus...

Author: jeastman
Date: Thu Apr 29 16:16:10 2010
New Revision: 939360

URL: http://svn.apache.org/viewvc?rev=939360&view=rev
Log:
MAHOUT-236: Modified the point clustering jobs to output WeightedVectorWritables (renamed from WeightedPointWritable in this commit) containing the clustered point and its probability of membership. Still only outputting the top cluster for each point.

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java
      - copied, changed from r939144, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedPointWritable.java
Removed:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedPointWritable.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java
    lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java
    lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
    lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java

Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java (from r939144, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedPointWritable.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedPointWritable.java&r1=939144&r2=939360&rev=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedPointWritable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/WeightedVectorWritable.java Thu Apr 29 16:16:10 2010
@@ -24,7 +24,11 @@ import java.io.IOException;
 import org.apache.hadoop.io.Writable;
 import org.apache.mahout.math.VectorWritable;
 
-public class WeightedPointWritable implements Writable {
+public class WeightedVectorWritable implements Writable {
+
+  private double weight;
+
+  private VectorWritable vector;
 
   /**
    * @return the weight
@@ -36,39 +40,35 @@ public class WeightedPointWritable imple
   /**
    * @return the point
    */
-  public VectorWritable getPoint() {
-    return point;
+  public VectorWritable getVector() {
+    return vector;
   }
 
-  public WeightedPointWritable(double weight, VectorWritable point) {
+  public WeightedVectorWritable(double weight, VectorWritable vector) {
     super();
     this.weight = weight;
-    this.point = point;
+    this.vector = vector;
   }
 
-  public WeightedPointWritable() {
+  public WeightedVectorWritable() {
     super();
   }
 
-  private double weight;
-
-  private VectorWritable point;
-
   @Override
   public void readFields(DataInput in) throws IOException {
     weight = in.readDouble();
-    point = new VectorWritable();
-    point.readFields(in);
+    vector = new VectorWritable();
+    vector.readFields(in);
   }
 
   @Override
   public void write(DataOutput out) throws IOException {
     out.writeDouble(weight);
-    point.write(out);
+    vector.write(out);
   }
 
   public String toString() {
-    return String.valueOf(weight) + ": " + (point == null ? "null" : ClusterBase.formatVector(point.get(), null));
+    return String.valueOf(weight) + ": " + (vector == null ? "null" : ClusterBase.formatVector(vector.get(), null));
   }
 
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java Thu Apr 29 16:16:10 2010
@@ -27,37 +27,38 @@ import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
 public class CanopyClusterer {
-  
+
   private int nextCanopyId;
-  
+
   private int numVectors;
-  
+
   // the T1 distance threshold
   private double t1;
-  
+
   // the T2 distance threshold
   private double t2;
-  
+
   // the distance measure
   private DistanceMeasure measure;
-  
+
   // private int nextClusterId = 0;
-  
+
   public CanopyClusterer(DistanceMeasure measure, double t1, double t2) {
     this.t1 = t1;
     this.t2 = t2;
     this.measure = measure;
   }
-  
+
   public CanopyClusterer(JobConf job) {
     this.configure(job);
   }
-  
+
   /**
    * Configure the Canopy and its distance measure
    * 
@@ -81,14 +82,14 @@ public class CanopyClusterer {
     t2 = Double.parseDouble(job.get(CanopyConfigKeys.T2_KEY));
     nextCanopyId = 0;
   }
-  
+
   /** Configure the Canopy for unit tests */
   public void config(DistanceMeasure aMeasure, double aT1, double aT2) {
     measure = aMeasure;
     t1 = aT1;
     t2 = aT2;
   }
-  
+
   /**
    * This is the same algorithm as the reference but inverted to iterate over existing canopies instead of the
    * points. Because of this it does not need to actually store the points, instead storing a total points
@@ -118,7 +119,7 @@ public class CanopyClusterer {
     }
     numVectors++;
   }
-  
+
   /**
    * This method is used by the CanopyMapper to perform canopy inclusion tests and to emit the point and its
    * covering canopies to the output. The CanopyCombiner will then sum the canopy points and produce the
@@ -131,9 +132,8 @@ public class CanopyClusterer {
    * @param collector
    *          an OutputCollector in which to emit the point
    */
-  public void emitPointToNewCanopies(Vector point,
-                                     List<Canopy> canopies,
-                                     OutputCollector<Text,Vector> collector) throws IOException {
+  public void emitPointToNewCanopies(Vector point, List<Canopy> canopies, OutputCollector<Text, Vector> collector)
+      throws IOException {
     boolean pointStronglyBound = false;
     for (Canopy canopy : canopies) {
       double dist = measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point);
@@ -148,7 +148,7 @@ public class CanopyClusterer {
       canopy.emitPoint(point, collector);
     }
   }
-  
+
   /**
    * This method is used by the CanopyMapper to perform canopy inclusion tests and to emit the point keyed by
    * its covering canopies to the output. if the point is not covered by any canopies (due to canopy centroid
@@ -163,10 +163,8 @@ public class CanopyClusterer {
    * @param reporter
    *          to report status of the job
    */
-  public void emitPointToExistingCanopies(Vector point,
-                                          List<Canopy> canopies,
-                                          OutputCollector<IntWritable,VectorWritable> collector,
-                                          Reporter reporter) throws IOException {
+  public void emitPointToExistingCanopies(Vector point, List<Canopy> canopies,
+      OutputCollector<IntWritable, WeightedVectorWritable> collector, Reporter reporter) throws IOException {
     double minDist = Double.MAX_VALUE;
     Canopy closest = null;
     boolean isCovered = false;
@@ -175,7 +173,7 @@ public class CanopyClusterer {
       if (dist < t1) {
         isCovered = true;
         VectorWritable vw = new VectorWritable(point);
-        collector.collect(new IntWritable(canopy.getId()), vw);
+        collector.collect(new IntWritable(canopy.getId()), new WeightedVectorWritable(1, vw));
         reporter.setStatus("Emit Canopy ID:" + canopy.getIdentifier());
       } else if (dist < minDist) {
         minDist = dist;
@@ -185,11 +183,11 @@ public class CanopyClusterer {
     // if the point is not contained in any canopies (due to canopy centroid
     // clustering), emit the point to the closest covering canopy.
     if (!isCovered) {
-      collector.collect(new IntWritable(closest.getId()), new VectorWritable(point));
+      collector.collect(new IntWritable(closest.getId()), new WeightedVectorWritable(1, new VectorWritable(point)));
       reporter.setStatus("Emit Closest Canopy ID:" + closest.getIdentifier());
     }
   }
-  
+
   /**
    * Return if the point is covered by the canopy
    * 
@@ -200,7 +198,7 @@ public class CanopyClusterer {
   public boolean canopyCovers(Canopy canopy, Vector point) {
     return measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point) < t1;
   }
-  
+
   /**
    * Iterate through the points, adding new canopies. Return the canopies.
    * 
@@ -246,7 +244,7 @@ public class CanopyClusterer {
     }
     return canopies;
   }
-  
+
   /**
    * Iterate through the canopies, adding their centroids to a list
    * 
@@ -261,7 +259,7 @@ public class CanopyClusterer {
     }
     return result;
   }
-  
+
   /**
    * Iterate through the canopies, resetting their center to their centroids
    * 
@@ -273,5 +271,5 @@ public class CanopyClusterer {
       canopy.setCenter(canopy.computeCentroid());
     }
   }
-  
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterDriver.java Thu Apr 29 16:16:10 2010
@@ -38,9 +38,9 @@ import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.SequenceFileInputFormat;
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.hadoop.mapred.lib.IdentityReducer;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
-import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -155,7 +155,7 @@ public final class ClusterDriver {
     
     conf.setInputFormat(SequenceFileInputFormat.class);
     conf.setOutputKeyClass(IntWritable.class);
-    conf.setOutputValueClass(VectorWritable.class);
+    conf.setOutputValueClass(WeightedVectorWritable.class);
     conf.setOutputFormat(SequenceFileOutputFormat.class);
     
     FileInputFormat.setInputPaths(conf, new Path(points));

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/ClusterMapper.java Thu Apr 29 16:16:10 2010
@@ -32,11 +32,12 @@ import org.apache.hadoop.mapred.MapReduc
 import org.apache.hadoop.mapred.Mapper;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
 public class ClusterMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>,VectorWritable,IntWritable,VectorWritable> {
+    Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable> {
   
   private CanopyClusterer canopyClusterer;
   private final List<Canopy> canopies = new ArrayList<Canopy>();
@@ -44,7 +45,7 @@ public class ClusterMapper extends MapRe
   @Override
   public void map(WritableComparable<?> key,
                   VectorWritable point,
-                  OutputCollector<IntWritable,VectorWritable> output,
+                  OutputCollector<IntWritable,WeightedVectorWritable> output,
                   Reporter reporter) throws IOException {
     canopyClusterer.emitPointToExistingCanopies(point.get(), canopies, output, reporter);
   }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterMapper.java Thu Apr 29 16:16:10 2010
@@ -36,10 +36,11 @@ import org.apache.hadoop.mapred.OutputCo
 import org.apache.hadoop.mapred.OutputLogFilter;
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.math.VectorWritable;
 
 public class DirichletClusterMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>, VectorWritable, IntWritable, VectorWritable> {
+    Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable> {
 
   private OutputCollector<IntWritable, VectorWritable> output;
 
@@ -47,7 +48,7 @@ public class DirichletClusterMapper exte
 
   @SuppressWarnings("unchecked")
   @Override
-  public void map(WritableComparable<?> key, VectorWritable vector, OutputCollector<IntWritable, VectorWritable> output,
+  public void map(WritableComparable<?> key, VectorWritable vector, OutputCollector<IntWritable, WeightedVectorWritable> output,
       Reporter reporter) throws IOException {
     int clusterId = -1;
     double clusterPdf = 0;
@@ -59,7 +60,7 @@ public class DirichletClusterMapper exte
       }
     }
     System.out.println(clusterId + ": " + ClusterBase.formatVector(vector.get(), null));
-    output.collect(new IntWritable(clusterId), vector);
+    output.collect(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, vector));
   }
 
   @Override

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Thu Apr 29 16:16:10 2010
@@ -42,6 +42,7 @@ import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.SequenceFileInputFormat;
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.dirichlet.models.VectorModelDistribution;
 import org.apache.mahout.clustering.kmeans.KMeansDriver;
 import org.apache.mahout.common.CommandLineUtil;
@@ -376,7 +377,7 @@ public class DirichletDriver {
     conf.setJobName("Dirichlet Clustering");
     
     conf.setOutputKeyClass(IntWritable.class);
-    conf.setOutputValueClass(VectorWritable.class);
+    conf.setOutputValueClass(WeightedVectorWritable.class);
     
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(output);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterMapper.java Thu Apr 29 16:16:10 2010
@@ -22,25 +22,24 @@ import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.hadoop.io.IntWritable;
-import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.MapReduceBase;
 import org.apache.hadoop.mapred.Mapper;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.math.VectorWritable;
 
 public class FuzzyKMeansClusterMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>, VectorWritable, IntWritable, VectorWritable> {
+    Mapper<WritableComparable<?>, VectorWritable, IntWritable, WeightedVectorWritable> {
 
   private final List<SoftCluster> clusters = new ArrayList<SoftCluster>();
 
   private FuzzyKMeansClusterer clusterer;
 
   @Override
-  public void map(WritableComparable<?> key, VectorWritable point, OutputCollector<IntWritable, VectorWritable> output,
+  public void map(WritableComparable<?> key, VectorWritable point, OutputCollector<IntWritable, WeightedVectorWritable> output,
       Reporter reporter) throws IOException {
     clusterer.outputPointWithClusterProbabilities(key.toString(), point.get(), clusters, output);
   }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Thu Apr 29 16:16:10 2010
@@ -25,9 +25,8 @@ import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
@@ -126,35 +125,32 @@ public class FuzzyKMeansClusterer {
    *          the OutputCollector to emit into
    */
   public void outputPointWithClusterProbabilities(String key, Vector point, List<SoftCluster> clusters,
-      OutputCollector<IntWritable, VectorWritable> output) throws IOException {
+      OutputCollector<IntWritable, WeightedVectorWritable> output) throws IOException {
 
-    // TODO: remove this later
-    //    List<Double> clusterDistanceList = new ArrayList<Double>();
-    //    
-    //    for (SoftCluster cluster : clusters) {
-    //      clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
-    //    }
-    //    FuzzyKMeansOutput fOutput = new FuzzyKMeansOutput(clusters.size());
-    //    for (int i = 0; i < clusters.size(); i++) {
-    //      double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
-    //      fOutput.add(i, clusters.get(i), probWeight);
-    //    }
-    //    String name = point.getName();
-
-    // for now just emit the closest cluster
-    int clusterId = -1;
-    double distance = Double.MAX_VALUE;
+    // calculate point distances for all clusters    
+    List<Double> clusterDistanceList = new ArrayList<Double>();
     for (SoftCluster cluster : clusters) {
-      Vector center = cluster.getCenter();
-      // System.out.println("cluster-" + cluster.getId() + "@ " + ClusterBase.formatVector(center, null));
-      double d = measure.distance(center, point);
-      if (d < distance) {
-        clusterId = cluster.getId();
-        distance = d;
+      clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
+    }
+    // calculate point pdf for all clusters
+    List<Double> clusterPdfList = new ArrayList<Double>();
+    for (int i = 0; i < clusters.size(); i++) {
+      double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
+      clusterPdfList.add(probWeight);
+    }
+    // for now just emit the most likely cluster
+    int clusterId = -1;
+    double clusterPdf = 0;
+    for (int i = 0; i < clusters.size(); i++) {
+      // System.out.println("cluster-" + clusters.get(i).getId() + "@ " + ClusterBase.formatVector(center, null));
+      double pdf = clusterPdfList.get(i);
+      if (pdf > clusterPdf) {
+        clusterId = clusters.get(i).getId();
+        clusterPdf = pdf;
       }
     }
     // System.out.println("cluster-" + clusterId + ": " + ClusterBase.formatVector(point, null));
-    output.collect(new IntWritable(clusterId), new VectorWritable(point));
+    output.collect(new IntWritable(clusterId), new WeightedVectorWritable(clusterPdf, new VectorWritable(point)));
   }
 
   /** Computes the probability of a point belonging to a cluster */
@@ -210,12 +206,8 @@ public class FuzzyKMeansClusterer {
    * @return
    *          a List<List<SoftCluster>> of clusters produced per iteration
    */
-  public static List<List<SoftCluster>> clusterPoints(List<Vector> points,
-                                                      List<SoftCluster> clusters,
-                                                      DistanceMeasure measure,
-                                                      double threshold,
-                                                      double m,
-                                                      int numIter) {
+  public static List<List<SoftCluster>> clusterPoints(List<Vector> points, List<SoftCluster> clusters, DistanceMeasure measure,
+      double threshold, double m, int numIter) {
     List<List<SoftCluster>> clustersList = new ArrayList<List<SoftCluster>>();
     clustersList.add(clusters);
     FuzzyKMeansClusterer clusterer = new FuzzyKMeansClusterer(measure, threshold, m);
@@ -243,9 +235,7 @@ public class FuzzyKMeansClusterer {
    *          the List<Cluster> clusters
    * @return
    */
-  public static boolean runFuzzyKMeansIteration(List<Vector> points,
-                                                List<SoftCluster> clusterList,
-                                                FuzzyKMeansClusterer clusterer) {
+  public static boolean runFuzzyKMeansIteration(List<Vector> points, List<SoftCluster> clusterList, FuzzyKMeansClusterer clusterer) {
     for (Vector point : points) {
       List<Double> clusterDistanceList = new ArrayList<Double>();
       for (SoftCluster cluster : clusterList) {

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java Thu Apr 29 16:16:10 2010
@@ -17,7 +17,6 @@
 
 package org.apache.mahout.clustering.fuzzykmeans;
 
-import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
@@ -46,11 +45,11 @@ import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.SequenceFileInputFormat;
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
-import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -348,7 +347,7 @@ public final class FuzzyKMeansDriver {
     conf.setJobName("Fuzzy K Means Clustering");
     
     conf.setOutputKeyClass(IntWritable.class);
-    conf.setOutputValueClass(VectorWritable.class);
+    conf.setOutputValueClass(WeightedVectorWritable.class);
     
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(output);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterMapper.java Thu Apr 29 16:16:10 2010
@@ -28,12 +28,12 @@ import org.apache.hadoop.mapred.MapReduc
 import org.apache.hadoop.mapred.Mapper;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.VectorWritable;
 
 public class KMeansClusterMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>,VectorWritable,IntWritable,VectorWritable> {
+    Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable> {
   
   private final List<Cluster> clusters = new ArrayList<Cluster>();
   private KMeansClusterer clusterer;
@@ -41,7 +41,7 @@ public class KMeansClusterMapper extends
   @Override
   public void map(WritableComparable<?> key,
                   VectorWritable point,
-                  OutputCollector<IntWritable,VectorWritable> output,
+                  OutputCollector<IntWritable,WeightedVectorWritable> output,
                   Reporter reporter) throws IOException {
     clusterer.outputPointWithClusterInfo(point.get(), clusters, output);
   }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Thu Apr 29 16:16:10 2010
@@ -23,8 +23,8 @@ import java.util.List;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
@@ -35,12 +35,12 @@ import org.slf4j.LoggerFactory;
  * representation. The class can be used as part of a clustering job to be started as map/reduce job.
  * */
 public class KMeansClusterer {
-  
+
   private static final Logger log = LoggerFactory.getLogger(KMeansClusterer.class);
-  
+
   /** Distance to use for point to cluster comparison. */
   private final DistanceMeasure measure;
-  
+
   /**
    * Init the k-means clusterer with the distance measure to use for comparison.
    * 
@@ -51,7 +51,7 @@ public class KMeansClusterer {
   public KMeansClusterer(DistanceMeasure measure) {
     this.measure = measure;
   }
-  
+
   /**
    * Iterates over all clusters and identifies the one closest to the given point. The distance measure used is
    * configured at creation time of this clusterer.
@@ -61,9 +61,8 @@ public class KMeansClusterer {
    * @param clusters
    *          a List<Cluster> to test.
    */
-  public void emitPointToNearestCluster(Vector point,
-                                        List<Cluster> clusters,
-                                        OutputCollector<Text,KMeansInfo> output) throws IOException {
+  public void emitPointToNearestCluster(Vector point, List<Cluster> clusters, OutputCollector<Text, KMeansInfo> output)
+      throws IOException {
     Cluster nearestCluster = null;
     double nearestDistance = Double.MAX_VALUE;
     for (Cluster cluster : clusters) {
@@ -80,10 +79,9 @@ public class KMeansClusterer {
     // emit only clusterID
     output.collect(new Text(nearestCluster.getIdentifier()), new KMeansInfo(1, point));
   }
-  
-  public void outputPointWithClusterInfo(Vector vector,
-                                         List<Cluster> clusters,
-                                         OutputCollector<IntWritable,VectorWritable> output) throws IOException {
+
+  public void outputPointWithClusterInfo(Vector vector, List<Cluster> clusters,
+      OutputCollector<IntWritable, WeightedVectorWritable> output) throws IOException {
     Cluster nearestCluster = null;
     double nearestDistance = Double.MAX_VALUE;
     for (Cluster cluster : clusters) {
@@ -94,10 +92,10 @@ public class KMeansClusterer {
         nearestDistance = distance;
       }
     }
-    
-    output.collect(new IntWritable(nearestCluster.getId()), new VectorWritable(vector));
+
+    output.collect(new IntWritable(nearestCluster.getId()), new WeightedVectorWritable(1, new VectorWritable(vector)));
   }
-  
+
   /**
    * This is the reference k-means implementation. Given its inputs it iterates over the points and clusters
    * until their centers converge or until the maximum number of iterations is exceeded.
@@ -111,14 +109,11 @@ public class KMeansClusterer {
    * @param maxIter
    *          the maximum number of iterations
    */
-  public static List<List<Cluster>> clusterPoints(List<Vector> points,
-                                                  List<Cluster> clusters,
-                                                  DistanceMeasure measure,
-                                                  int maxIter,
-                                                  double distanceThreshold) {
+  public static List<List<Cluster>> clusterPoints(List<Vector> points, List<Cluster> clusters, DistanceMeasure measure,
+      int maxIter, double distanceThreshold) {
     List<List<Cluster>> clustersList = new ArrayList<List<Cluster>>();
     clustersList.add(clusters);
-    
+
     boolean converged = false;
     int iteration = 0;
     while (!converged && iteration < maxIter) {
@@ -133,7 +128,7 @@ public class KMeansClusterer {
     }
     return clustersList;
   }
-  
+
   /**
    * Perform a single iteration over the points and clusters, assigning points to clusters and returning if
    * the iterations are completed.
@@ -146,10 +141,8 @@ public class KMeansClusterer {
    *          a DistanceMeasure to use
    * @return
    */
-  public static boolean runKMeansIteration(List<Vector> points,
-                                           List<Cluster> clusters,
-                                           DistanceMeasure measure,
-                                           double distanceThreshold) {
+  public static boolean runKMeansIteration(List<Vector> points, List<Cluster> clusters, DistanceMeasure measure,
+      double distanceThreshold) {
     // iterate through all points, assigning each to the nearest cluster
     for (Vector point : points) {
       Cluster closestCluster = null;
@@ -178,5 +171,5 @@ public class KMeansClusterer {
     }
     return converged;
   }
-  
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Thu Apr 29 16:16:10 2010
@@ -40,6 +40,7 @@ import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.SequenceFileInputFormat;
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
@@ -303,11 +304,8 @@ public final class KMeansDriver {
     conf.setInputFormat(SequenceFileInputFormat.class);
     conf.setOutputFormat(SequenceFileOutputFormat.class);
     
-    conf.setMapOutputKeyClass(IntWritable.class);
-    conf.setMapOutputValueClass(VectorWritable.class);
     conf.setOutputKeyClass(IntWritable.class);
-    // the output is the cluster id
-    conf.setOutputValueClass(VectorWritable.class);
+    conf.setOutputValueClass(WeightedVectorWritable.class);
     
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(output);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterMapper.java Thu Apr 29 16:16:10 2010
@@ -35,27 +35,27 @@ import org.apache.hadoop.mapred.Mapper;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.OutputLogFilter;
 import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.clustering.ClusterBase;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.math.VectorWritable;
 
 public class MeanShiftCanopyClusterMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>, MeanShiftCanopy, IntWritable, VectorWritable> {
+    Mapper<WritableComparable<?>, MeanShiftCanopy, IntWritable, WeightedVectorWritable> {
 
   private MeanShiftCanopyClusterer clusterer;
 
-  private OutputCollector<IntWritable, VectorWritable> output;
+  private OutputCollector<IntWritable, WeightedVectorWritable> output;
 
   private List<MeanShiftCanopy> canopies;
 
   @Override
-  public void map(WritableComparable<?> key, MeanShiftCanopy vector, OutputCollector<IntWritable, VectorWritable> output,
+  public void map(WritableComparable<?> key, MeanShiftCanopy vector, OutputCollector<IntWritable, WeightedVectorWritable> output,
       Reporter reporter) throws IOException {
     int vectorId = vector.getId();
     for (MeanShiftCanopy msc : canopies) {
       for (int containedId : msc.getBoundPoints().toList()) {
         if (vectorId == containedId) {
           // System.out.println(msc.getId() + ": v" + vectorId + "=" + ClusterBase.formatVector(vector.getCenter(), null));
-          output.collect(new IntWritable(msc.getId()), new VectorWritable(vector.getCenter()));
+          output.collect(new IntWritable(msc.getId()), new WeightedVectorWritable(1, new VectorWritable(vector.getCenter())));
         }
       }
     }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java Thu Apr 29 16:16:10 2010
@@ -37,12 +37,10 @@ import org.apache.hadoop.mapred.JobClien
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.mapred.SequenceFileInputFormat;
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
-import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterMapper;
-import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansConfigKeys;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -205,7 +203,7 @@ public final class MeanShiftCanopyDriver
     conf.setJobName("Mean Shift Clustering");
     
     conf.setOutputKeyClass(IntWritable.class);
-    conf.setOutputValueClass(VectorWritable.class);
+    conf.setOutputValueClass(WeightedVectorWritable.class);
     
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(output);

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java Thu Apr 29 16:16:10 2010
@@ -31,9 +31,8 @@ import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.Reducer;
-import org.apache.hadoop.mapred.lib.IdentityReducer;
 import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.DummyOutputCollector;
 import org.apache.mahout.common.DummyReporter;
 import org.apache.mahout.common.MahoutTestCase;
@@ -410,7 +409,7 @@ public class TestCanopyCreation extends 
     mapper.configure(conf);
     
     List<Canopy> canopies = new ArrayList<Canopy>();
-    DummyOutputCollector<IntWritable,VectorWritable> collector = new DummyOutputCollector<IntWritable,VectorWritable>();
+    DummyOutputCollector<IntWritable,WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable,WeightedVectorWritable>();
     int nextCanopyId = 0;
     for (Vector centroid : manhattanCentroids) {
       canopies.add(new Canopy(centroid, nextCanopyId++));
@@ -421,14 +420,14 @@ public class TestCanopyCreation extends 
     for (VectorWritable point : points) {
       mapper.map(new Text(), point, collector, new DummyReporter());
     }
-    Map<IntWritable, List<VectorWritable>> data = collector.getData();
+    Map<IntWritable, List<WeightedVectorWritable>> data = collector.getData();
     assertEquals("Number of map results", canopies.size(), data.size());
-    for (Entry<IntWritable, List<VectorWritable>> stringListEntry : data.entrySet()) {
+    for (Entry<IntWritable, List<WeightedVectorWritable>> stringListEntry : data.entrySet()) {
       IntWritable key = stringListEntry.getKey();
       Canopy canopy = findCanopy(key.get(), canopies);
-      List<VectorWritable> pts = stringListEntry.getValue();
-      for (VectorWritable ptDef : pts) {
-        assertTrue("Point not in canopy", mapper.canopyCovers(canopy, ptDef.get()));
+      List<WeightedVectorWritable> pts = stringListEntry.getValue();
+      for (WeightedVectorWritable ptDef : pts) {
+        assertTrue("Point not in canopy", mapper.canopyCovers(canopy, ptDef.getVector().get()));
       }
     }
   }
@@ -453,7 +452,7 @@ public class TestCanopyCreation extends 
     mapper.configure(conf);
     
     List<Canopy> canopies = new ArrayList<Canopy>();
-    DummyOutputCollector<IntWritable,VectorWritable> collector = new DummyOutputCollector<IntWritable,VectorWritable>();
+    DummyOutputCollector<IntWritable,WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable,WeightedVectorWritable>();
     int nextCanopyId = 0;
     for (Vector centroid : euclideanCentroids) {
       canopies.add(new Canopy(centroid, nextCanopyId++));
@@ -464,14 +463,14 @@ public class TestCanopyCreation extends 
     for (VectorWritable point : points) {
       mapper.map(new Text(), point, collector, new DummyReporter());
     }
-    Map<IntWritable,List<VectorWritable>> data = collector.getData();
+    Map<IntWritable, List<WeightedVectorWritable>> data = collector.getData();
     assertEquals("Number of map results", canopies.size(), data.size());
-    for (Entry<IntWritable, List<VectorWritable>> stringListEntry : data.entrySet()) {
+    for (Entry<IntWritable, List<WeightedVectorWritable>> stringListEntry : data.entrySet()) {
       IntWritable key = stringListEntry.getKey();
       Canopy canopy = findCanopy(key.get(), canopies);
-      List<VectorWritable> pts = stringListEntry.getValue();
-      for (VectorWritable ptDef : pts) {
-        assertTrue("Point not in canopy", mapper.canopyCovers(canopy, ptDef.get()));
+      List<WeightedVectorWritable> pts = stringListEntry.getValue();
+      for (WeightedVectorWritable ptDef : pts) {
+        assertTrue("Point not in canopy", mapper.canopyCovers(canopy, ptDef.getVector().get()));
       }
     }
   }
@@ -500,10 +499,10 @@ public class TestCanopyCreation extends 
      * while (reader.ready()) { System.out.println(reader.readLine()); count++; }
      */
     IntWritable clusterId = new IntWritable(0);
-    VectorWritable vector = new VectorWritable();
+    WeightedVectorWritable vector = new WeightedVectorWritable();
     while (reader.next(clusterId, vector)) {
       count++;
-      System.out.println("Txt: " + clusterId + " Vec: " + vector.get().asFormatString());
+      System.out.println("Txt: " + clusterId + " Vec: " + vector.getVector().get().asFormatString());
     }
     // the point [3.0,3.0] is covered by both canopies
     assertEquals("number of points", 1 + points.size(), count);
@@ -532,7 +531,7 @@ public class TestCanopyCreation extends 
      * while (reader.ready()) { System.out.println(reader.readLine()); count++; }
      */
     IntWritable canopyId = new IntWritable(0);
-    VectorWritable can = new VectorWritable();
+    WeightedVectorWritable can = new WeightedVectorWritable();
     while (reader.next(canopyId, can)) {
       count++;
     }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java Thu Apr 29 16:16:10 2010
@@ -31,6 +31,7 @@ import org.apache.hadoop.io.SequenceFile
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
 import org.apache.mahout.common.DummyOutputCollector;
 import org.apache.mahout.common.DummyReporter;
@@ -87,26 +88,37 @@ public class TestFuzzyKmeansClustering e
   }
 
   private static void computeCluster(List<Vector> points, List<SoftCluster> clusterList, FuzzyKMeansClusterer clusterer,
-      Map<Integer, List<Vector>> pointClusterInfo) {
+      Map<Integer, List<WeightedVectorWritable>> pointClusterInfo) {
 
     for (Vector point : points) {
+      // calculate point distances for all clusters    
       List<Double> clusterDistanceList = new ArrayList<Double>();
-      SoftCluster closestCluster = null;
-      double closestDistance = Double.MAX_VALUE;
       for (SoftCluster cluster : clusterList) {
-        double distance = clusterer.getMeasure().distance(point, cluster.getCenter());
-        if (distance < closestDistance) {
-          closestDistance = distance;
-          closestCluster = cluster;
+        clusterDistanceList.add(clusterer.getMeasure().distance(cluster.getCenter(), point));
+      }
+      // calculate point pdf for all clusters
+      List<Double> clusterPdfList = new ArrayList<Double>();
+      for (int i = 0; i < clusterList.size(); i++) {
+        double probWeight = clusterer.computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
+        clusterPdfList.add(probWeight);
+      }
+      // for now just emit the most likely cluster
+      int clusterId = -1;
+      double clusterPdf = 0;
+      for (int i = 0; i < clusterList.size(); i++) {
+        // System.out.println("cluster-" + clusters.get(i).getId() + "@ " + ClusterBase.formatVector(center, null));
+        double pdf = clusterPdfList.get(i);
+        if (pdf > clusterPdf) {
+          clusterId = clusterList.get(i).getId();
+          clusterPdf = pdf;
         }
-        clusterDistanceList.add(distance);
       }
-      List<Vector> list = pointClusterInfo.get(closestCluster.getId());
+      List<WeightedVectorWritable> list = pointClusterInfo.get(clusterId);
       if (list == null) {
-        list = new ArrayList<Vector>();
-        pointClusterInfo.put(closestCluster.getId(), list);
+        list = new ArrayList<WeightedVectorWritable>();
+        pointClusterInfo.put(clusterId, list);
       }
-      list.add(point);
+      list.add(new WeightedVectorWritable(clusterPdf, new VectorWritable(point)));
       double totalProb = 0;
       for (int i = 0; i < clusterList.size(); i++) {
         SoftCluster cluster = clusterList.get(i);
@@ -118,9 +130,9 @@ public class TestFuzzyKmeansClustering e
 
     for (SoftCluster cluster : clusterList) {
       System.out.println(cluster.asFormatString(null));
-      List<Vector> list = pointClusterInfo.get(cluster.getId());
+      List<WeightedVectorWritable> list = pointClusterInfo.get(cluster.getId());
       if (list != null)
-        for (Vector vector : list) {
+        for (WeightedVectorWritable vector : list) {
           System.out.println("\t" + vector);
         }
     }
@@ -140,7 +152,7 @@ public class TestFuzzyKmeansClustering e
         //cluster.addPoint(cluster.getCenter(), 1);
         clusterList.add(cluster);
       }
-      Map<Integer, List<Vector>> pointClusterInfo = new HashMap<Integer, List<Vector>>();
+      Map<Integer, List<WeightedVectorWritable>> pointClusterInfo = new HashMap<Integer, List<WeightedVectorWritable>>();
       // run reference FuzzyKmeans algorithm
       List<List<SoftCluster>> clusters = FuzzyKMeansClusterer.clusterPoints(points, clusterList, new EuclideanDistanceMeasure(),
           0.001, 2, 2);
@@ -150,7 +162,7 @@ public class TestFuzzyKmeansClustering e
       // iterate for each cluster
       int size = 0;
       for (int cId : pointClusterInfo.keySet()) {
-        List<Vector> pts = pointClusterInfo.get(cId);
+        List<WeightedVectorWritable> pts = pointClusterInfo.get(cId);
         size += pts.size();
       }
       assertEquals("total size", size, points.size());
@@ -222,7 +234,7 @@ public class TestFuzzyKmeansClustering e
       // assertEquals("output dir files?", 4, outFiles.length);
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path("output/clusteredPoints/part-00000"), conf);
       IntWritable key = new IntWritable();
-      VectorWritable out = new VectorWritable();
+      WeightedVectorWritable out = new WeightedVectorWritable();
       while (reader.next(key, out)) {
         // make sure we can read all the clusters
       }
@@ -493,7 +505,7 @@ public class TestFuzzyKmeansClustering e
         softCluster.recomputeCenter();
       }
 
-      DummyOutputCollector<IntWritable, VectorWritable> clusterMapperCollector = new DummyOutputCollector<IntWritable, VectorWritable>();
+      DummyOutputCollector<IntWritable, WeightedVectorWritable> clusterMapperCollector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
 
       FuzzyKMeansClusterMapper clusterMapper = new FuzzyKMeansClusterMapper();
       clusterMapper.config(reducerCluster);
@@ -511,7 +523,7 @@ public class TestFuzzyKmeansClustering e
         Vector vec = tweakValue(points.get(i).get());
         reference.add(new SoftCluster(vec, i));
       }
-      Map<Integer, List<Vector>> refClusters = new HashMap<Integer, List<Vector>>();
+      Map<Integer, List<WeightedVectorWritable>> refClusters = new HashMap<Integer, List<WeightedVectorWritable>>();
       List<Vector> pointsVectors = new ArrayList<Vector>();
       for (VectorWritable point : points) {
         pointsVectors.add((Vector) point.get());
@@ -532,7 +544,7 @@ public class TestFuzzyKmeansClustering e
       // make sure all points are allocated to a cluster
       int size = 0;
       for (int cId : refClusters.keySet()) {
-        List<Vector> pts = refClusters.get(cId);
+        List<WeightedVectorWritable> pts = refClusters.get(cId);
         size += pts.size();
       }
       assertEquals("total size", size, points.size());

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Thu Apr 29 16:16:10 2010
@@ -31,6 +31,7 @@ import org.apache.hadoop.io.SequenceFile
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.canopy.CanopyDriver;
 import org.apache.mahout.common.DummyOutputCollector;
 import org.apache.mahout.common.DummyReporter;
@@ -370,15 +371,15 @@ public class TestKmeansClustering extend
       // assertEquals("output dir files?", 4, outFiles.length);
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path("output/clusteredPoints/part-00000"), conf);
       int[] expect = expectedNumPoints[k];
-      DummyOutputCollector<IntWritable, VectorWritable> collector = new DummyOutputCollector<IntWritable, VectorWritable>();
+      DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
       // The key is the clusterId
       IntWritable clusterId = new IntWritable(0);
-      // The value is the vector
-      VectorWritable value = new VectorWritable();
+      // The value is the weighted vector
+      WeightedVectorWritable value = new WeightedVectorWritable();
       while (reader.next(clusterId, value)) {
         collector.collect(clusterId, value);
         clusterId = new IntWritable(0);
-        value = new VectorWritable();
+        value = new WeightedVectorWritable();
 
       }
       reader.close();
@@ -418,17 +419,17 @@ public class TestKmeansClustering extend
     assertTrue("output dir exists?", outDir.exists());
     String[] outFiles = outDir.list();
     assertEquals("output dir files?", 4, outFiles.length);
-    DummyOutputCollector<IntWritable, VectorWritable> collector = new DummyOutputCollector<IntWritable, VectorWritable>();
+    DummyOutputCollector<IntWritable, WeightedVectorWritable> collector = new DummyOutputCollector<IntWritable, WeightedVectorWritable>();
     SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path("output/clusteredPoints/part-00000"), conf);
 
     // The key is the clusterId
     IntWritable clusterId = new IntWritable(0);
     // The value is the weighted vector
-    VectorWritable value = new VectorWritable();
+    WeightedVectorWritable value = new WeightedVectorWritable();
     while (reader.next(clusterId, value)) {
       collector.collect(clusterId, value);
       clusterId = new IntWritable(0);
-      value = new VectorWritable();
+      value = new WeightedVectorWritable();
 
     }
     reader.close();

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java Thu Apr 29 16:16:10 2010
@@ -35,35 +35,32 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 class DisplayMeanShift extends DisplayDirichlet {
-  
+
   private static final Logger log = LoggerFactory.getLogger(DisplayMeanShift.class);
-  
-  private static final MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(
-      new EuclideanDistanceMeasure(), 1.0, 0.05, 0.5);
+
   private static List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
-  
+
+  private static double t1, t2;
+
   private DisplayMeanShift() {
     initialize();
     this.setTitle("Canopy Clusters (> 1.5% of population)");
   }
-  
-  // TODO this is never queried?
-  // private static final List<List<Vector>> iterationCenters = new ArrayList<List<Vector>>();
-  
+
   @Override
   public void paint(Graphics g) {
     Graphics2D g2 = (Graphics2D) g;
     double sx = (double) res / ds;
     g2.setTransform(AffineTransform.getScaleInstance(sx, sx));
-    
+
     // plot the axes
     g2.setColor(Color.BLACK);
     Vector dv = new DenseVector(2).assign(size / 2.0);
-    Vector dv1 = new DenseVector(2).assign(DisplayMeanShift.clusterer.getT1());
-    Vector dv2 = new DenseVector(2).assign(DisplayMeanShift.clusterer.getT2());
+    Vector dv1 = new DenseVector(2).assign(t1);
+    Vector dv2 = new DenseVector(2).assign(t2);
     DisplayDirichlet.plotRectangle(g2, new DenseVector(2).assign(2), dv);
     DisplayDirichlet.plotRectangle(g2, new DenseVector(2).assign(-2), dv);
-    
+
     // plot the sample data
     g2.setColor(Color.DARK_GRAY);
     dv.assign(0.03);
@@ -72,7 +69,7 @@ class DisplayMeanShift extends DisplayDi
     }
     int i = 0;
     for (MeanShiftCanopy canopy : canopies) {
-      if (canopy.getBoundPoints().size() > 0.015 * DisplayDirichlet.sampleData.size()) {
+      if (canopy.getBoundPoints().toList().size() > 0.015 * DisplayDirichlet.sampleData.size()) {
         g2.setColor(colors[Math.min(i++, DisplayDirichlet.colors.length - 1)]);
         for (int v : canopy.getBoundPoints().elements()) {
           DisplayDirichlet.plotRectangle(g2, sampleData.get(v).get(), dv);
@@ -82,7 +79,7 @@ class DisplayMeanShift extends DisplayDi
       }
     }
   }
-  
+
   public static void main(String[] args) {
     RandomUtils.useTestSeed();
     DisplayDirichlet.generateSamples();
@@ -90,14 +87,15 @@ class DisplayMeanShift extends DisplayDi
     for (VectorWritable sample : sampleData) {
       points.add(sample.get());
     }
-    canopies = MeanShiftCanopyClusterer.clusterPoints(points, new EuclideanDistanceMeasure(), 0.5, 1.0, 0.05,
-      10);
+    t1 = 1.5;
+    t2 = 0.5;
+    canopies = MeanShiftCanopyClusterer.clusterPoints(points, new EuclideanDistanceMeasure(), 0.005, t1, t2, 20);
     for (MeanShiftCanopy canopy : canopies) {
       log.info(canopy.toString());
     }
     new DisplayMeanShift();
   }
-  
+
   static void generateResults() {
     DisplayDirichlet.generateResults(new NormalModelDistribution());
   }

Modified: lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java (original)
+++ lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java Thu Apr 29 16:16:10 2010
@@ -43,7 +43,7 @@ import org.apache.hadoop.mapred.Sequence
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.ClusterBase;
-import org.apache.mahout.clustering.WeightedPointWritable;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.clustering.dirichlet.DirichletCluster;
 import org.apache.mahout.clustering.kmeans.KMeansDriver;
 import org.apache.mahout.common.CommandLineUtil;
@@ -198,7 +198,7 @@ public class CDbwDriver {
     conf.setOutputKeyClass(IntWritable.class);
     conf.setOutputValueClass(VectorWritable.class);
     conf.setMapOutputKeyClass(IntWritable.class);
-    conf.setMapOutputValueClass(WeightedPointWritable.class);
+    conf.setMapOutputValueClass(WeightedVectorWritable.class);
 
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(stateOut);

Modified: lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java (original)
+++ lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java Thu Apr 29 16:16:10 2010
@@ -35,37 +35,38 @@ import org.apache.hadoop.mapred.Mapper;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.OutputLogFilter;
 import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.clustering.WeightedPointWritable;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
 import org.apache.mahout.math.VectorWritable;
 
-public class CDbwMapper extends MapReduceBase implements Mapper<IntWritable, VectorWritable, IntWritable, WeightedPointWritable> {
+public class CDbwMapper extends MapReduceBase implements
+    Mapper<IntWritable, WeightedVectorWritable, IntWritable, WeightedVectorWritable> {
 
   private Map<Integer, List<VectorWritable>> representativePoints;
 
-  private Map<Integer, WeightedPointWritable> mostDistantPoints = new HashMap<Integer, WeightedPointWritable>();
+  private Map<Integer, WeightedVectorWritable> mostDistantPoints = new HashMap<Integer, WeightedVectorWritable>();
 
   private DistanceMeasure measure = new EuclideanDistanceMeasure();
 
-  private OutputCollector<IntWritable, WeightedPointWritable> output = null;
+  private OutputCollector<IntWritable, WeightedVectorWritable> output = null;
 
   @Override
-  public void map(IntWritable clusterId, VectorWritable point, OutputCollector<IntWritable, WeightedPointWritable> output,
+  public void map(IntWritable clusterId, WeightedVectorWritable point, OutputCollector<IntWritable, WeightedVectorWritable> output,
       Reporter reporter) throws IOException {
 
     this.output = output;
 
     int key = clusterId.get();
-    WeightedPointWritable currentMDP = mostDistantPoints.get(key);
+    WeightedVectorWritable currentMDP = mostDistantPoints.get(key);
 
     List<VectorWritable> refPoints = representativePoints.get(key);
     double totalDistance = 0.0;
     for (VectorWritable refPoint : refPoints) {
-      totalDistance += measure.distance(refPoint.get(), point.get());
+      totalDistance += measure.distance(refPoint.get(), point.getVector().get());
     }
     if (currentMDP == null || currentMDP.getWeight() < totalDistance) {
-      mostDistantPoints.put(key, new WeightedPointWritable(totalDistance, new VectorWritable(point.get().clone())));
+      mostDistantPoints.put(key, new WeightedVectorWritable(totalDistance, new VectorWritable(point.getVector().get().clone())));
     }
   }
 

Modified: lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java?rev=939360&r1=939359&r2=939360&view=diff
==============================================================================
--- lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java (original)
+++ lucene/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwReducer.java Thu Apr 29 16:16:10 2010
@@ -29,28 +29,28 @@ import org.apache.hadoop.mapred.MapReduc
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reducer;
 import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.clustering.WeightedPointWritable;
+import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.math.VectorWritable;
 
-public class CDbwReducer extends MapReduceBase implements Reducer<IntWritable, WeightedPointWritable, IntWritable, VectorWritable> {
+public class CDbwReducer extends MapReduceBase implements Reducer<IntWritable, WeightedVectorWritable, IntWritable, VectorWritable> {
 
   private Map<Integer, List<VectorWritable>> referencePoints;
 
   private OutputCollector<IntWritable, VectorWritable> output;
 
   @Override
-  public void reduce(IntWritable key, Iterator<WeightedPointWritable> values, OutputCollector<IntWritable, VectorWritable> output,
+  public void reduce(IntWritable key, Iterator<WeightedVectorWritable> values, OutputCollector<IntWritable, VectorWritable> output,
       Reporter reporter) throws IOException {
     this.output = output;
     // find the most distant point
-    WeightedPointWritable mdp = null;
+    WeightedVectorWritable mdp = null;
     while (values.hasNext()) {
-      WeightedPointWritable dpw = values.next();
+      WeightedVectorWritable dpw = values.next();
       if (mdp == null || mdp.getWeight() < dpw.getWeight()) {
-        mdp = new WeightedPointWritable(dpw.getWeight(), dpw.getPoint());
+        mdp = new WeightedVectorWritable(dpw.getWeight(), dpw.getVector());
       }
     }
-    output.collect(new IntWritable(key.get()), mdp.getPoint());
+    output.collect(new IntWritable(key.get()), mdp.getVector());
   }
 
   public void configure(Map<Integer, List<VectorWritable>> referencePoints) {