You are viewing a plain text version of this content. The hyperlink to the canonical version was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by df...@apache.org on 2013/06/16 18:21:09 UTC

svn commit: r1493527 - in /mahout/trunk: ./ core/src/main/java/org/apache/mahout/clustering/ core/src/main/java/org/apache/mahout/clustering/streaming/cluster/ core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/

Author: dfilimon
Date: Sun Jun 16 16:21:09 2013
New Revision: 1493527

URL: http://svn.apache.org/r1493527
Log:
MAHOUT-1254: Final round of cleanup for StreamingKMeans


Modified:
    mahout/trunk/CHANGELOG
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Sun Jun 16 16:21:09 2013
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 0.8 - unreleased
 
+  MAHOUT-1254: Final round of cleanup for StreamingKMeans (dfilimon)
+
   MAHOUT-1263: Serialise/Deserialise Lambda value for OnlineLogisticRegression (Mike Davy via smarthi)  
 
   MAHOUT-1258: Another shot at findbugs and checkstyle (ssc)

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java Sun Jun 16 16:21:09 2013
@@ -91,15 +91,13 @@ public final class ClusteringUtils {
    * @return the minimum distance between the first sampleLimit points
    * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans#clusterInternal(Iterable, boolean)
    */
-  public static double estimateDistanceCutoff(Iterable<? extends Vector> data,
-                                              DistanceMeasure distanceMeasure, int sampleLimit) {
-    Iterable<? extends Vector> limitedData = Iterables.limit(data, sampleLimit);
-    ProjectionSearch searcher = new ProjectionSearch(distanceMeasure, 3, 1);
-    searcher.add(limitedData.iterator().next());
+  public static double estimateDistanceCutoff(List<? extends Vector> data, DistanceMeasure distanceMeasure) {
+    BruteSearch searcher = new BruteSearch(distanceMeasure);
+    searcher.addAll(data);
     double minDistance = Double.POSITIVE_INFINITY;
-    for (Vector vector : Iterables.skip(limitedData, 1)) {
-      double closest = searcher.searchFirst(vector, false).getWeight();
-      if (closest < minDistance) {
+    for (Vector vector : data) {
+      double closest = searcher.searchFirst(vector, true).getWeight();
+      if (minDistance > 0 && closest < minDistance) {
         minDistance = closest;
       }
       searcher.add(vector);
@@ -107,6 +105,11 @@ public final class ClusteringUtils {
     return minDistance;
   }
 
+  public static double estimateDistanceCutoff(Iterable<? extends Vector> data, DistanceMeasure distanceMeasure,
+                                              int sampleLimit) {
+    return estimateDistanceCutoff(Lists.newArrayList(Iterables.limit(data, sampleLimit)), distanceMeasure);
+  }
+
   /**
    * Computes the Davies-Bouldin Index for a given clustering.
    * See http://en.wikipedia.org/wiki/Clustering_algorithm#Internal_evaluation
@@ -241,6 +244,7 @@ public final class ClusteringUtils {
     int numRows = confusionMatrix.numRows();
     int numCols = confusionMatrix.numCols();
     double rowChoiceSum = 0;
+    double columnChoiceSum = 0;
     double totalChoiceSum = 0;
     double total = 0;
     for (int i = 0; i < numRows; ++i) {
@@ -252,7 +256,6 @@ public final class ClusteringUtils {
       total += rowSum;
       rowChoiceSum += choose2(rowSum);
     }
-    double columnChoiceSum = 0;
     for (int j = 0; j < numCols; ++j) {
       double columnSum = 0;
       for (int i = 0; i < numRows; ++i) {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java Sun Jun 16 16:21:09 2013
@@ -20,7 +20,6 @@ package org.apache.mahout.clustering.str
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
-import java.util.NoSuchElementException;
 import java.util.Random;
 
 import com.google.common.base.Function;
@@ -28,6 +27,7 @@ import com.google.common.collect.Iterabl
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
 import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.Centroid;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.MatrixSlice;
@@ -128,7 +128,7 @@ public class StreamingKMeans implements 
   /**
    * Random object to sample values from.
    */
-  private final Random random = RandomUtils.getRandom();
+  private Random random = RandomUtils.getRandom();
 
   /**
    * Calls StreamingKMeans(searcher, numClusters, 1.3, 10, 2).
@@ -232,9 +232,6 @@ public class StreamingKMeans implements 
 
           @Override
           public Centroid next() {
-            if (!hasNext()) {
-              throw new NoSuchElementException();
-            }
             accessed = true;
             return datapoint;
           }
@@ -360,4 +357,12 @@ public class StreamingKMeans implements 
   public double getDistanceCutoff() {
     return distanceCutoff;
   }
+
+  public void setDistanceCutoff(double distanceCutoff) {
+    this.distanceCutoff = distanceCutoff;
+  }
+
+  public DistanceMeasure getDistanceMeasure() {
+    return centroids.getDistanceMeasure();
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java Sun Jun 16 16:21:09 2013
@@ -141,13 +141,14 @@ public final class StreamingKMeansDriver
 
   /**
    * Whether to run another pass of StreamingKMeans on the reducer's points before BallKMeans. On some data sets
-   * with a large number of mappers, the intermediate number of
+   * with a large number of mappers, the intermediate number of clusters passed to the reducer is too large to
+   * fit into memory directly, hence the option to collapse the clusters further with StreamingKMeans.
    */
   public static final String REDUCE_STREAMING_KMEANS = "reduceStreamingKMeans";
 
   private static final Logger log = LoggerFactory.getLogger(StreamingKMeansDriver.class);
 
-  private static final double INVALID_DISTANCE_CUTOFF = -1;
+  public static final float INVALID_DISTANCE_CUTOFF = -1;
 
   @Override
   public int run(String[] args) throws Exception {
@@ -405,7 +406,8 @@ public final class StreamingKMeansDriver
    * @return 0 on success, -1 on failure.
    */
   @SuppressWarnings("unchecked")
-  public static int run(Configuration conf, Path input, Path output) throws Exception {
+  public static int run(Configuration conf, Path input, Path output)
+      throws IOException, InterruptedException, ClassNotFoundException, ExecutionException {
     log.info("Starting StreamingKMeans clustering for vectors in {}; results are output to {}",
         input.toString(), output.toString());
 

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java Sun Jun 16 16:21:09 2013
@@ -18,17 +18,22 @@
 package org.apache.mahout.clustering.streaming.mapreduce;
 
 import java.io.IOException;
+import java.util.List;
 
+import com.google.common.collect.Lists;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusteringUtils;
 import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
 import org.apache.mahout.math.Centroid;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.neighborhood.UpdatableSearcher;
 
 public class StreamingKMeansMapper extends Mapper<Writable, VectorWritable, IntWritable, CentroidWritable> {
+  private static final int NUM_ESTIMATE_POINTS = 1000;
+
   /**
    * The clusterer object used to cluster the points received by this mapper online.
    */
@@ -39,6 +44,10 @@ public class StreamingKMeansMapper exten
    */
   private int numPoints = 0;
 
+  private boolean estimateDistanceCutoff = false;
+
+  private List<Centroid> estimatePoints;
+
   @Override
   public void setup(Context context) {
     // At this point the configuration received from the Driver is assumed to be valid.
@@ -46,18 +55,43 @@ public class StreamingKMeansMapper exten
     Configuration conf = context.getConfiguration();
     UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
     int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+    double estimatedDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+        StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+    if (estimatedDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+      estimateDistanceCutoff = true;
+      estimatePoints = Lists.newArrayList();
+    }
     // There is no way of estimating the distance cutoff unless we have some data.
-    clusterer = new StreamingKMeans(searcher, numClusters,
-        conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, 1.0e-4f));
+    clusterer = new StreamingKMeans(searcher, numClusters, estimatedDistanceCutoff);
+  }
+
+  private void clusterEstimatePoints() {
+    clusterer.setDistanceCutoff(ClusteringUtils.estimateDistanceCutoff(
+        estimatePoints, clusterer.getDistanceMeasure()));
+    clusterer.cluster(estimatePoints);
+    estimateDistanceCutoff = false;
   }
 
   @Override
   public void map(Writable key, VectorWritable point, Context context) {
-    clusterer.cluster(new Centroid(numPoints++, point.get().clone(), 1));
+    Centroid centroid = new Centroid(numPoints++, point.get(), 1);
+    if (estimateDistanceCutoff) {
+      if (numPoints < NUM_ESTIMATE_POINTS) {
+        estimatePoints.add(centroid);
+      } else if (numPoints == NUM_ESTIMATE_POINTS) {
+        clusterEstimatePoints();
+  }
+    } else {
+      clusterer.cluster(centroid);
+    }
   }
 
   @Override
   public void cleanup(Context context) throws IOException, InterruptedException {
+    // We should cluster the points at the end if they haven't yet been clustered.
+    if (estimateDistanceCutoff) {
+      clusterEstimatePoints();
+    }
     // Reindex the centroids before passing them to the reducer.
     clusterer.reindexCentroids();
     // All outputs have the same key to go to the same final reducer.

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java Sun Jun 16 16:21:09 2013
@@ -28,11 +28,9 @@ import org.apache.hadoop.conf.Configurat
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.mapreduce.Reducer;
 import org.apache.mahout.clustering.streaming.cluster.BallKMeans;
-import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.math.Centroid;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.neighborhood.UpdatableSearcher;
 
 public class StreamingKMeansReducer extends Reducer<IntWritable, CentroidWritable, IntWritable, CentroidWritable> {
   /**
@@ -47,15 +45,6 @@ public class StreamingKMeansReducer exte
     conf = context.getConfiguration();
   }
 
-  private StreamingKMeans getStreamingKMeans(int numClusters) {
-    // At this point the configuration received from the Driver is assumed to be valid.
-    // No other checks are made.
-    UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
-    // There is no way of estimating the distance cutoff unless we have some data.
-    return new StreamingKMeans(searcher, numClusters,
-        conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, 1.0e-4f));
-  }
-
   @Override
   public void reduce(IntWritable key, Iterable<CentroidWritable> centroids,
                      Context context) throws IOException, InterruptedException {
@@ -63,16 +52,14 @@ public class StreamingKMeansReducer exte
     // There might be too many intermediate centroids to fit into memory, in which case, we run another pass
     // of StreamingKMeans to collapse the clusters further.
     if (conf.getBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, false)) {
-      StreamingKMeans clusterer = getStreamingKMeans(conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1));
-      clusterer.cluster(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() {
+      intermediateCentroids = Lists.newArrayList(
+          new StreamingKMeansThread(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() {
         @Override
-        public Centroid apply(org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable input) {
+            public Centroid apply(CentroidWritable input) {
           Preconditions.checkNotNull(input);
           return input.getCentroid();
         }
-      }));
-      clusterer.reindexCentroids();
-      intermediateCentroids = Lists.newArrayList(clusterer);
+          }), conf).call());
     } else {
       intermediateCentroids = centroidWritablesToList(centroids);
     }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java Sun Jun 16 16:21:09 2013
@@ -1,7 +1,10 @@
 package org.apache.mahout.clustering.streaming.mapreduce;
 
+import java.util.Iterator;
+import java.util.List;
 import java.util.concurrent.Callable;
 
+import com.google.common.collect.Lists;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.mahout.clustering.ClusteringUtils;
@@ -12,6 +15,8 @@ import org.apache.mahout.math.VectorWrit
 import org.apache.mahout.math.neighborhood.UpdatableSearcher;
 
 public class StreamingKMeansThread implements Callable<Iterable<Centroid>> {
+  private static final int NUM_ESTIMATE_POINTS = 1000;
+
   private Configuration conf;
   private Iterable<Centroid> datapoints;
 
@@ -22,15 +27,25 @@ public class StreamingKMeansThread imple
   }
 
   @Override
-  public Iterable<Centroid> call() throws Exception {
+  public Iterable<Centroid> call() {
     UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
     int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
-
     double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
-        (float) ClusteringUtils.estimateDistanceCutoff(datapoints, searcher.getDistanceMeasure(), 100));
+        StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+
+    Iterator<Centroid> datapointsIterator = datapoints.iterator();
+    if (estimateDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+      List<Centroid> estimatePoints = Lists.newArrayListWithExpectedSize(NUM_ESTIMATE_POINTS);
+      while (datapointsIterator.hasNext() && estimatePoints.size() < NUM_ESTIMATE_POINTS) {
+        estimatePoints.add(datapointsIterator.next());
+      }
+      estimateDistanceCutoff = ClusteringUtils.estimateDistanceCutoff(estimatePoints, searcher.getDistanceMeasure());
+    }
 
     StreamingKMeans clusterer = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff);
-    clusterer.cluster(datapoints);
+    while (datapointsIterator.hasNext()) {
+      clusterer.cluster(datapointsIterator.next());
+    }
     clusterer.reindexCentroids();
 
     return clusterer;

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java Sun Jun 16 16:21:09 2013
@@ -24,10 +24,7 @@ import org.apache.mahout.math.neighborho
 import org.apache.mahout.math.neighborhood.ProjectionSearch;
 import org.apache.mahout.math.neighborhood.UpdatableSearcher;
 
-public final class StreamingKMeansUtilsMR {
-
-  private StreamingKMeansUtilsMR() {
-  }
+public class StreamingKMeansUtilsMR {
 
   /**
    * Instantiates a searcher from a given configuration.
@@ -40,7 +37,7 @@ public final class StreamingKMeansUtilsM
     DistanceMeasure distanceMeasure;
     String distanceMeasureClass = conf.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
     try {
-      distanceMeasure = (DistanceMeasure)Class.forName(distanceMeasureClass).getConstructor().newInstance();
+      distanceMeasure = (DistanceMeasure)Class.forName(distanceMeasureClass).newInstance();
     } catch (Exception e) {
       throw new RuntimeException("Failed to instantiate distanceMeasure", e);
     }
@@ -93,7 +90,7 @@ public final class StreamingKMeansUtilsM
    * @param input Iterable of Vectors to cast
    * @return the new Centroids
    */
-  public static Iterable<Centroid> castVectorsToCentroids(Iterable<Vector> input) {
+  public static Iterable<Centroid> castVectorsToCentroids(final Iterable<Vector> input) {
     return Iterables.transform(input, new Function<Vector, Centroid>() {
       private int numVectors = 0;
       @Override
@@ -126,7 +123,7 @@ public final class StreamingKMeansUtilsM
         writer.append(new IntWritable(i++), new CentroidWritable(centroid));
       }
     } finally {
-      Closeables.close(writer, false);
+      Closeables.close(writer, true);
     }
   }
 
@@ -141,7 +138,7 @@ public final class StreamingKMeansUtilsM
         writer.append(new IntWritable(i++), new VectorWritable(vector));
       }
     } finally {
-      Closeables.close(writer, false);
+      Closeables.close(writer, true);
     }
   }
 }