You are viewing a plain-text version of this content; the hyperlink to the canonical version was lost in this plain-text rendering.
Posted to commits@mahout.apache.org by df...@apache.org on 2013/06/16 18:21:09 UTC
svn commit: r1493527 - in /mahout/trunk: ./
core/src/main/java/org/apache/mahout/clustering/
core/src/main/java/org/apache/mahout/clustering/streaming/cluster/
core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/
Author: dfilimon
Date: Sun Jun 16 16:21:09 2013
New Revision: 1493527
URL: http://svn.apache.org/r1493527
Log:
MAHOUT-1254: Final round of cleanup for StreamingKMeans
Modified:
mahout/trunk/CHANGELOG
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Sun Jun 16 16:21:09 2013
@@ -2,6 +2,8 @@ Mahout Change Log
Release 0.8 - unreleased
+ MAHOUT-1254: Final round of cleanup for StreamingKMeans (dfilimon)
+
MAHOUT-1263: Serialise/Deserialise Lambda value for OnlineLogisticRegression (Mike Davy via smarthi)
MAHOUT-1258: Another shot at findbugs and checkstyle (ssc)
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java Sun Jun 16 16:21:09 2013
@@ -91,15 +91,13 @@ public final class ClusteringUtils {
* @return the minimum distance between the first sampleLimit points
* @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans#clusterInternal(Iterable, boolean)
*/
- public static double estimateDistanceCutoff(Iterable<? extends Vector> data,
- DistanceMeasure distanceMeasure, int sampleLimit) {
- Iterable<? extends Vector> limitedData = Iterables.limit(data, sampleLimit);
- ProjectionSearch searcher = new ProjectionSearch(distanceMeasure, 3, 1);
- searcher.add(limitedData.iterator().next());
+ public static double estimateDistanceCutoff(List<? extends Vector> data, DistanceMeasure distanceMeasure) {
+ BruteSearch searcher = new BruteSearch(distanceMeasure);
+ searcher.addAll(data);
double minDistance = Double.POSITIVE_INFINITY;
- for (Vector vector : Iterables.skip(limitedData, 1)) {
- double closest = searcher.searchFirst(vector, false).getWeight();
- if (closest < minDistance) {
+ for (Vector vector : data) {
+ double closest = searcher.searchFirst(vector, true).getWeight();
+ if (minDistance > 0 && closest < minDistance) {
minDistance = closest;
}
searcher.add(vector);
@@ -107,6 +105,11 @@ public final class ClusteringUtils {
return minDistance;
}
+ public static double estimateDistanceCutoff(Iterable<? extends Vector> data, DistanceMeasure distanceMeasure,
+ int sampleLimit) {
+ return estimateDistanceCutoff(Lists.newArrayList(Iterables.limit(data, sampleLimit)), distanceMeasure);
+ }
+
/**
* Computes the Davies-Bouldin Index for a given clustering.
* See http://en.wikipedia.org/wiki/Clustering_algorithm#Internal_evaluation
@@ -241,6 +244,7 @@ public final class ClusteringUtils {
int numRows = confusionMatrix.numRows();
int numCols = confusionMatrix.numCols();
double rowChoiceSum = 0;
+ double columnChoiceSum = 0;
double totalChoiceSum = 0;
double total = 0;
for (int i = 0; i < numRows; ++i) {
@@ -252,7 +256,6 @@ public final class ClusteringUtils {
total += rowSum;
rowChoiceSum += choose2(rowSum);
}
- double columnChoiceSum = 0;
for (int j = 0; j < numCols; ++j) {
double columnSum = 0;
for (int i = 0; i < numRows; ++i) {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java Sun Jun 16 16:21:09 2013
@@ -20,7 +20,6 @@ package org.apache.mahout.clustering.str
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
-import java.util.NoSuchElementException;
import java.util.Random;
import com.google.common.base.Function;
@@ -28,6 +27,7 @@ import com.google.common.collect.Iterabl
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.MatrixSlice;
@@ -128,7 +128,7 @@ public class StreamingKMeans implements
/**
* Random object to sample values from.
*/
- private final Random random = RandomUtils.getRandom();
+ private Random random = RandomUtils.getRandom();
/**
* Calls StreamingKMeans(searcher, numClusters, 1.3, 10, 2).
@@ -232,9 +232,6 @@ public class StreamingKMeans implements
@Override
public Centroid next() {
- if (!hasNext()) {
- throw new NoSuchElementException();
- }
accessed = true;
return datapoint;
}
@@ -360,4 +357,12 @@ public class StreamingKMeans implements
public double getDistanceCutoff() {
return distanceCutoff;
}
+
+ public void setDistanceCutoff(double distanceCutoff) {
+ this.distanceCutoff = distanceCutoff;
+ }
+
+ public DistanceMeasure getDistanceMeasure() {
+ return centroids.getDistanceMeasure();
+ }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java Sun Jun 16 16:21:09 2013
@@ -141,13 +141,14 @@ public final class StreamingKMeansDriver
/**
* Whether to run another pass of StreamingKMeans on the reducer's points before BallKMeans. On some data sets
- * with a large number of mappers, the intermediate number of
+ * with a large number of mappers, the intermediate number of clusters passed to the reducer is too large to
+ * fit into memory directly, hence the option to collapse the clusters further with StreamingKMeans.
*/
public static final String REDUCE_STREAMING_KMEANS = "reduceStreamingKMeans";
private static final Logger log = LoggerFactory.getLogger(StreamingKMeansDriver.class);
- private static final double INVALID_DISTANCE_CUTOFF = -1;
+ public static final float INVALID_DISTANCE_CUTOFF = -1;
@Override
public int run(String[] args) throws Exception {
@@ -405,7 +406,8 @@ public final class StreamingKMeansDriver
* @return 0 on success, -1 on failure.
*/
@SuppressWarnings("unchecked")
- public static int run(Configuration conf, Path input, Path output) throws Exception {
+ public static int run(Configuration conf, Path input, Path output)
+ throws IOException, InterruptedException, ClassNotFoundException, ExecutionException {
log.info("Starting StreamingKMeans clustering for vectors in {}; results are output to {}",
input.toString(), output.toString());
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java Sun Jun 16 16:21:09 2013
@@ -18,17 +18,22 @@
package org.apache.mahout.clustering.streaming.mapreduce;
import java.io.IOException;
+import java.util.List;
+import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
public class StreamingKMeansMapper extends Mapper<Writable, VectorWritable, IntWritable, CentroidWritable> {
+ private static final int NUM_ESTIMATE_POINTS = 1000;
+
/**
* The clusterer object used to cluster the points received by this mapper online.
*/
@@ -39,6 +44,10 @@ public class StreamingKMeansMapper exten
*/
private int numPoints = 0;
+ private boolean estimateDistanceCutoff = false;
+
+ private List<Centroid> estimatePoints;
+
@Override
public void setup(Context context) {
// At this point the configuration received from the Driver is assumed to be valid.
@@ -46,18 +55,43 @@ public class StreamingKMeansMapper exten
Configuration conf = context.getConfiguration();
UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+ double estimatedDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+ StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+ if (estimatedDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+ estimateDistanceCutoff = true;
+ estimatePoints = Lists.newArrayList();
+ }
// There is no way of estimating the distance cutoff unless we have some data.
- clusterer = new StreamingKMeans(searcher, numClusters,
- conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, 1.0e-4f));
+ clusterer = new StreamingKMeans(searcher, numClusters, estimatedDistanceCutoff);
+ }
+
+ private void clusterEstimatePoints() {
+ clusterer.setDistanceCutoff(ClusteringUtils.estimateDistanceCutoff(
+ estimatePoints, clusterer.getDistanceMeasure()));
+ clusterer.cluster(estimatePoints);
+ estimateDistanceCutoff = false;
}
@Override
public void map(Writable key, VectorWritable point, Context context) {
- clusterer.cluster(new Centroid(numPoints++, point.get().clone(), 1));
+ Centroid centroid = new Centroid(numPoints++, point.get(), 1);
+ if (estimateDistanceCutoff) {
+ if (numPoints < NUM_ESTIMATE_POINTS) {
+ estimatePoints.add(centroid);
+ } else if (numPoints == NUM_ESTIMATE_POINTS) {
+ clusterEstimatePoints();
+ }
+ } else {
+ clusterer.cluster(centroid);
+ }
}
@Override
public void cleanup(Context context) throws IOException, InterruptedException {
+ // We should cluster the points at the end if they haven't yet been clustered.
+ if (estimateDistanceCutoff) {
+ clusterEstimatePoints();
+ }
// Reindex the centroids before passing them to the reducer.
clusterer.reindexCentroids();
// All outputs have the same key to go to the same final reducer.
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java Sun Jun 16 16:21:09 2013
@@ -28,11 +28,9 @@ import org.apache.hadoop.conf.Configurat
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.mahout.clustering.streaming.cluster.BallKMeans;
-import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.neighborhood.UpdatableSearcher;
public class StreamingKMeansReducer extends Reducer<IntWritable, CentroidWritable, IntWritable, CentroidWritable> {
/**
@@ -47,15 +45,6 @@ public class StreamingKMeansReducer exte
conf = context.getConfiguration();
}
- private StreamingKMeans getStreamingKMeans(int numClusters) {
- // At this point the configuration received from the Driver is assumed to be valid.
- // No other checks are made.
- UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
- // There is no way of estimating the distance cutoff unless we have some data.
- return new StreamingKMeans(searcher, numClusters,
- conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, 1.0e-4f));
- }
-
@Override
public void reduce(IntWritable key, Iterable<CentroidWritable> centroids,
Context context) throws IOException, InterruptedException {
@@ -63,16 +52,14 @@ public class StreamingKMeansReducer exte
// There might be too many intermediate centroids to fit into memory, in which case, we run another pass
// of StreamingKMeans to collapse the clusters further.
if (conf.getBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, false)) {
- StreamingKMeans clusterer = getStreamingKMeans(conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1));
- clusterer.cluster(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() {
+ intermediateCentroids = Lists.newArrayList(
+ new StreamingKMeansThread(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() {
@Override
- public Centroid apply(org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable input) {
+ public Centroid apply(CentroidWritable input) {
Preconditions.checkNotNull(input);
return input.getCentroid();
}
- }));
- clusterer.reindexCentroids();
- intermediateCentroids = Lists.newArrayList(clusterer);
+ }), conf).call());
} else {
intermediateCentroids = centroidWritablesToList(centroids);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java Sun Jun 16 16:21:09 2013
@@ -1,7 +1,10 @@
package org.apache.mahout.clustering.streaming.mapreduce;
+import java.util.Iterator;
+import java.util.List;
import java.util.concurrent.Callable;
+import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.ClusteringUtils;
@@ -12,6 +15,8 @@ import org.apache.mahout.math.VectorWrit
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
public class StreamingKMeansThread implements Callable<Iterable<Centroid>> {
+ private static final int NUM_ESTIMATE_POINTS = 1000;
+
private Configuration conf;
private Iterable<Centroid> datapoints;
@@ -22,15 +27,25 @@ public class StreamingKMeansThread imple
}
@Override
- public Iterable<Centroid> call() throws Exception {
+ public Iterable<Centroid> call() {
UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
-
double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
- (float) ClusteringUtils.estimateDistanceCutoff(datapoints, searcher.getDistanceMeasure(), 100));
+ StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+
+ Iterator<Centroid> datapointsIterator = datapoints.iterator();
+ if (estimateDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+ List<Centroid> estimatePoints = Lists.newArrayListWithExpectedSize(NUM_ESTIMATE_POINTS);
+ while (datapointsIterator.hasNext() && estimatePoints.size() < NUM_ESTIMATE_POINTS) {
+ estimatePoints.add(datapointsIterator.next());
+ }
+ estimateDistanceCutoff = ClusteringUtils.estimateDistanceCutoff(estimatePoints, searcher.getDistanceMeasure());
+ }
StreamingKMeans clusterer = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff);
- clusterer.cluster(datapoints);
+ while (datapointsIterator.hasNext()) {
+ clusterer.cluster(datapointsIterator.next());
+ }
clusterer.reindexCentroids();
return clusterer;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java?rev=1493527&r1=1493526&r2=1493527&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java Sun Jun 16 16:21:09 2013
@@ -24,10 +24,7 @@ import org.apache.mahout.math.neighborho
import org.apache.mahout.math.neighborhood.ProjectionSearch;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;
-public final class StreamingKMeansUtilsMR {
-
- private StreamingKMeansUtilsMR() {
- }
+public class StreamingKMeansUtilsMR {
/**
* Instantiates a searcher from a given configuration.
@@ -40,7 +37,7 @@ public final class StreamingKMeansUtilsM
DistanceMeasure distanceMeasure;
String distanceMeasureClass = conf.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
try {
- distanceMeasure = (DistanceMeasure)Class.forName(distanceMeasureClass).getConstructor().newInstance();
+ distanceMeasure = (DistanceMeasure)Class.forName(distanceMeasureClass).newInstance();
} catch (Exception e) {
throw new RuntimeException("Failed to instantiate distanceMeasure", e);
}
@@ -93,7 +90,7 @@ public final class StreamingKMeansUtilsM
* @param input Iterable of Vectors to cast
* @return the new Centroids
*/
- public static Iterable<Centroid> castVectorsToCentroids(Iterable<Vector> input) {
+ public static Iterable<Centroid> castVectorsToCentroids(final Iterable<Vector> input) {
return Iterables.transform(input, new Function<Vector, Centroid>() {
private int numVectors = 0;
@Override
@@ -126,7 +123,7 @@ public final class StreamingKMeansUtilsM
writer.append(new IntWritable(i++), new CentroidWritable(centroid));
}
} finally {
- Closeables.close(writer, false);
+ Closeables.close(writer, true);
}
}
@@ -141,7 +138,7 @@ public final class StreamingKMeansUtilsM
writer.append(new IntWritable(i++), new VectorWritable(vector));
}
} finally {
- Closeables.close(writer, false);
+ Closeables.close(writer, true);
}
}
}