You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:07:57 UTC
[26/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java
new file mode 100644
index 0000000..25a4022
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java
@@ -0,0 +1,456 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.cluster;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.Multinomial;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Implements a ball k-means algorithm for weighted vectors with probabilistic seeding similar to k-means++.
+ * The idea is that k-means++ gives good starting clusters and ball k-means can tune up the final result very nicely
+ * in only a few passes (or even in a single iteration for well-clusterable data).
+ *
+ * A good reference for this class of algorithms is "The Effectiveness of Lloyd-Type Methods for the k-Means Problem"
+ * by Rafail Ostrovsky, Yuval Rabani, Leonard J. Schulman and Chaitanya Swamy. The code here uses the seeding strategy
+ * as described in section 4.1.1 of that paper and the ball k-means step as described in section 4.2. We support
+ * multiple iterations in contrast to the algorithm described in the paper.
+ */
+public class BallKMeans implements Iterable<Centroid> {
+ /**
+ * The searcher containing the centroids.
+ */
+ private final UpdatableSearcher centroids;
+
+ /**
+ * The number of clusters to cluster the data into.
+ */
+ private final int numClusters;
+
+ /**
+ * The maximum number of iterations of the algorithm to run waiting for the cluster assignments
+ * to stabilize. If there are no changes in cluster assignment earlier, we can finish early.
+ */
+ private final int maxNumIterations;
+
+ /**
+ * When deciding which points to include in the new centroid calculation,
+ * it's preferable to exclude outliers since it increases the rate of convergence.
+ * So, we calculate the distance from each cluster to its closest neighboring cluster. When
+ * evaluating the points assigned to a cluster, we compare the distance between the centroid to
+ * the point with the distance between the centroid and its closest centroid neighbor
+ * multiplied by this trimFraction. If the distance between the centroid and the point is
+ * greater, we consider it an outlier and we don't use it.
+ */
+ private final double trimFraction;
+
+ /**
+ * Selecting the initial centroids is the most important part of the ball k-means clustering. Poor choices, like two
+ * centroids in the same actual cluster result in a low-quality final result.
+ * k-means++ initialization yields good quality clusters, especially when using BallKMeans after StreamingKMeans as
+ * the points have weights.
+ * Simple, random selection of the points based on their weights is faster but sometimes fails to produce the
+ * desired number of clusters.
+ * This field is true if the initialization should be done with k-means++.
+ */
+ private final boolean kMeansPlusPlusInit;
+
+ /**
+ * When using trimFraction, the weight of each centroid will not be the sum of the weights of
+ * the vectors assigned to that cluster because outliers are not used to compute the updated
+ * centroid.
+ * So, the total weight is probably wrong. This can be fixed by doing another pass over the
+ * data points and adjusting the weights of each centroid. This doesn't update the coordinates
+ * of the centroids, but is useful if the weights matter.
+ */
+ private final boolean correctWeights;
+
+ /**
+ * When running multiple ball k-means passes to get the one with the smallest total cost, we can either compute the
+ * overall cost using all the points, or reserve a fraction of them (testProbability) in a test set.
+ * The cost is the sum of the distances between each point and its corresponding centroid.
+ * We then use this set of points to compute the total cost on. We're therefore trying to select the clustering
+ * that best describes the underlying distribution of the clusters.
+ * This field is the probability of assigning a given point to the test set. If this is 0, the cost will be computed
+ * on the entire set of points.
+ */
+ private final double testProbability;
+
+ /**
+ * Whether or not testProbability > 0, i.e., there exists a non-empty 'test' set.
+ */
+ private final boolean splitTrainTest;
+
+ /**
+ * How many k-means runs to have. If there's more than one run, we compute the cost of each clustering as described
+ * above and select the clustering that minimizes the cost.
+ * Multiple runs are a lot more useful when using the random initialization. With kmeans++, 1-2 runs are enough and
+ * more runs are not likely to help quality much.
+ */
+ private final int numRuns;
+
+ /**
+ * Random object to sample values from.
+ */
+ private final Random random;
+
+ public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations) {
+ // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end,
+ // there will be 0 points in the test set and 1 run.
+ this(searcher, numClusters, maxNumIterations, 0.9, true, true, 0.0, 1);
+ }
+
+ public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations,
+ boolean kMeansPlusPlusInit, int numRuns) {
+ // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end,
+ // there will be 10% points of in the test set.
+ this(searcher, numClusters, maxNumIterations, 0.9, kMeansPlusPlusInit, true, 0.1, numRuns);
+ }
+
+ public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations,
+ double trimFraction, boolean kMeansPlusPlusInit, boolean correctWeights,
+ double testProbability, int numRuns) {
+ Preconditions.checkArgument(searcher.size() == 0, "Searcher must be empty initially to populate with centroids");
+ Preconditions.checkArgument(numClusters > 0, "The requested number of clusters must be positive");
+ Preconditions.checkArgument(maxNumIterations > 0, "The maximum number of iterations must be positive");
+ Preconditions.checkArgument(trimFraction > 0, "The trim fraction must be positive");
+ Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "The testProbability must be in [0, 1)");
+ Preconditions.checkArgument(numRuns > 0, "There has to be at least one run");
+
+ this.centroids = searcher;
+ this.numClusters = numClusters;
+ this.maxNumIterations = maxNumIterations;
+
+ this.trimFraction = trimFraction;
+ this.kMeansPlusPlusInit = kMeansPlusPlusInit;
+ this.correctWeights = correctWeights;
+
+ this.testProbability = testProbability;
+ this.splitTrainTest = testProbability > 0;
+ this.numRuns = numRuns;
+
+ this.random = RandomUtils.getRandom();
+ }
+
+ /**
+  * Splits the datapoints into a training set and a test set according to testProbability.
+  *
+  * When testProbability is 0 the list is returned untouched as the training set together with an empty test set.
+  * Otherwise the list is shuffled IN PLACE and the first (int) (testProbability * size) points become the test
+  * set; the remaining points form the training set. Both returned lists are views of the given list.
+  *
+  * @param datapoints the points to split; reordered by this call when a test set is requested.
+  * @return a pair whose first element is the training set and whose second element is the test set.
+  * @throws IllegalArgumentException if the split would leave the training or the test set empty.
+  */
+ public Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> splitTrainTest(
+     List<? extends WeightedVector> datapoints) {
+   // If there will be no points assigned to the test set, return now.
+   if (testProbability == 0) {
+     return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(datapoints,
+         Lists.<WeightedVector>newArrayList());
+   }
+
+   int numTest = (int) (testProbability * datapoints.size());
+   // Guava's checkArgument only substitutes %s placeholders, so the message is formatted explicitly.
+   Preconditions.checkArgument(numTest > 0 && numTest < datapoints.size(),
+       String.format(
+           "Must have nonzero number of training and test vectors. Asked for %.1f %% of %d vectors for test",
+           testProbability * 100, datapoints.size()));
+
+   // Shuffle with the instance's generator so the split follows the same seed as the rest of the clustering.
+   Collections.shuffle(datapoints, random);
+   return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(
+       datapoints.subList(numTest, datapoints.size()), datapoints.subList(0, numTest));
+ }
+
+ /**
+  * Clusters the datapoints in the list doing either random seeding of the centroids or k-means++.
+  *
+  * With a single run the seeds are drawn from the training split and all the datapoints are clustered; the result
+  * is returned directly. With several runs each run is trained on the training split and scored by its total
+  * cluster cost; the cheapest clustering is kept. The cost is measured on the test split when one exists and on
+  * all the datapoints otherwise. When correctWeights is set, the weights of the test points are finally added to
+  * their closest centroids.
+  *
+  * Note that the datapoints list is reordered when a test set is requested (see splitTrainTest).
+  *
+  * @param datapoints the points to be clustered.
+  * @return an UpdatableSearcher with the resulting clusters.
+  * @throws RuntimeException if no run produced a clustering with a finite cost.
+  */
+ public UpdatableSearcher cluster(List<? extends WeightedVector> datapoints) {
+   Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> trainTestSplit = splitTrainTest(datapoints);
+   List<Vector> bestCentroids = Lists.newArrayList();
+   double cost = Double.POSITIVE_INFINITY;
+   double bestCost = Double.POSITIVE_INFINITY;
+   for (int i = 0; i < numRuns; ++i) {
+     centroids.clear();
+     if (kMeansPlusPlusInit) {
+       // Use k-means++ to set initial centroids.
+       initializeSeedsKMeansPlusPlus(trainTestSplit.getFirst());
+     } else {
+       // Randomly select the initial centroids.
+       initializeSeedsRandomly(trainTestSplit.getFirst());
+     }
+     // Do k-means iterations with trimmed mean computation (aka ball k-means).
+     if (numRuns > 1) {
+       iterativeAssignment(trainTestSplit.getFirst());
+       // Compute the cost of the clustering and possibly save the centroids. The held-out test points are used
+       // when there are any; without a test set the cost is computed on the entire set of points. (Scoring on
+       // the empty test list would give every run the same cost and make the extra runs pointless.)
+       cost = ClusteringUtils.totalClusterCost(
+           splitTrainTest ? trainTestSplit.getSecond() : datapoints, centroids);
+       if (cost < bestCost) {
+         bestCost = cost;
+         bestCentroids.clear();
+         Iterables.addAll(bestCentroids, centroids);
+       }
+     } else {
+       // If there is only going to be one run, the cost doesn't need to be computed, so we just return the clustering.
+       iterativeAssignment(datapoints);
+       return centroids;
+     }
+   }
+   if (bestCost == Double.POSITIVE_INFINITY) {
+     throw new RuntimeException("No valid clustering was found");
+   }
+   // The searcher still holds the centroids of the last run; restore the best ones unless the last run won.
+   if (cost != bestCost) {
+     centroids.clear();
+     centroids.addAll(bestCentroids);
+   }
+   if (correctWeights) {
+     // The test points took no part in training, so their weights have not been counted yet.
+     for (WeightedVector testDatapoint : trainTestSplit.getSecond()) {
+       WeightedVector closest = (WeightedVector) centroids.searchFirst(testDatapoint, false).getValue();
+       closest.setWeight(closest.getWeight() + testDatapoint.getWeight());
+     }
+   }
+   return centroids;
+ }
+
+ /**
+  * Selects some of the original points randomly with probability proportional to their weights. This is much
+  * less sophisticated than the kmeans++ approach, however it is faster; it is best coupled with multiple runs
+  * (numRuns) so that a poor draw of seeds can be discarded.
+  *
+  * The side effect of this method is to fill the centroids structure itself.
+  *
+  * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind.
+  * @throws IllegalArgumentException if there are fewer datapoints than requested clusters.
+  */
+ private void initializeSeedsRandomly(List<? extends WeightedVector> datapoints) {
+   int numDatapoints = datapoints.size();
+   // Every seed removes one candidate from the selector, so there must be at least numClusters candidates.
+   Preconditions.checkArgument(numDatapoints >= numClusters,
+       String.format("Must have at least as many datapoints [%d] as clusters [%d]", numDatapoints, numClusters));
+
+   double totalWeight = 0;
+   for (WeightedVector datapoint : datapoints) {
+     totalWeight += datapoint.getWeight();
+   }
+   // Selector over the indices of the datapoints, weighted by each point's share of the total weight.
+   Multinomial<Integer> seedSelector = new Multinomial<>();
+   for (int i = 0; i < numDatapoints; ++i) {
+     seedSelector.add(i, datapoints.get(i).getWeight() / totalWeight);
+   }
+   for (int i = 0; i < numClusters; ++i) {
+     int sample = seedSelector.sample();
+     // Don't select this one again.
+     seedSelector.delete(sample);
+     Centroid centroid = new Centroid(datapoints.get(sample));
+     centroid.setIndex(i);
+     centroids.add(centroid);
+   }
+ }
+
+ /**
+ * Selects some of the original points according to the k-means++ algorithm. The basic idea is that
+ * points are selected with probability proportional to their distance from any selected point. In
+ * this version, points have weights which multiply their likelihood of being selected. This is the
+ * same as if there were as many copies of the same point as indicated by the weight.
+ *
+ * This is pretty expensive, but it vastly improves the quality and convergence of the k-means algorithm.
+ * The basic idea can be made much faster by only processing a random subset of the original points.
+ * In the context of streaming k-means, the total number of possible seeds will be about k log n so this
+ * selection will cost O(k^2 (log n)^2) which isn't much worse than the random sampling idea. At
+ * n = 10^9, the cost of this initialization will be about 10x worse than a reasonable random sampling
+ * implementation.
+ *
+ * The side effect of this method is to fill the centroids structure itself.
+ *
+ * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind.
+ * @throws IllegalArgumentException if there are fewer than two datapoints or fewer datapoints than clusters.
+ */
+ private void initializeSeedsKMeansPlusPlus(List<? extends WeightedVector> datapoints) {
+ Preconditions.checkArgument(datapoints.size() > 1, "Must have at least two datapoints points to cluster " +
+ "sensibly");
+ // NOTE(review): the condition accepts datapoints.size() == numClusters although the message says "more".
+ Preconditions.checkArgument(datapoints.size() >= numClusters,
+ String.format("Must have more datapoints [%d] than clusters [%d]", datapoints.size(), numClusters));
+ // Compute the centroid of all of the datapoints. This is then used to compute the squared radius of the datapoints.
+ Centroid center = new Centroid(datapoints.iterator().next());
+ for (WeightedVector row : Iterables.skip(datapoints, 1)) {
+ center.update(row);
+ }
+
+ // Given the centroid, we can compute \Delta_1^2(X), the total squared distance for the datapoints
+ // this accelerates seed selection.
+ // NOTE(review): the sum uses whatever the searcher's DistanceMeasure returns; it is a squared distance only
+ // if that measure is a squared one - confirm against the configured measure.
+ double deltaX = 0;
+ DistanceMeasure distanceMeasure = centroids.getDistanceMeasure();
+ for (WeightedVector row : datapoints) {
+ deltaX += distanceMeasure.distance(row, center);
+ }
+
+ // Find the first seed c_1 (and conceptually the second, c_2) as might be done in the 2-means clustering so that
+ // the probability of selecting c_1 and c_2 is proportional to || c_1 - c_2 ||^2. This is done
+ // by first selecting c_1 with probability:
+ //
+ // p(c_1) = sum_{c_1} || c_1 - c_2 ||^2 \over sum_{c_1, c_2} || c_1 - c_2 ||^2
+ //
+ // This can be simplified to:
+ //
+ // p(c_1) = \Delta_1^2(X) + n || c_1 - c ||^2 / (2 n \Delta_1^2(X))
+ //
+ // where c = \sum x / n and \Delta_1^2(X) = sum || x - c ||^2
+ //
+ // All subsequent seeds c_i (including c_2) can then be selected from the remaining points with probability
+ // proportional to Pr(c_i == x_j) = min_{m < i} || c_m - x_j ||^2.
+
+ // Multinomial distribution of vector indices for the selection seeds. These correspond to
+ // the indices of the vectors in the original datapoints list.
+ Multinomial<Integer> seedSelector = new Multinomial<>();
+ for (int i = 0; i < datapoints.size(); ++i) {
+ double selectionProbability =
+ deltaX + datapoints.size() * distanceMeasure.distance(datapoints.get(i), center);
+ seedSelector.add(i, selectionProbability);
+ }
+
+ // NOTE(review): the first seed is drawn uniformly from the datapoints here; the weights added to seedSelector
+ // above are overwritten by the set() loop below before seedSelector.sample() is ever called, so they do not
+ // influence the choice of c_1. Confirm whether sampling c_1 from seedSelector was intended.
+ int selected = random.nextInt(datapoints.size());
+ Centroid c_1 = new Centroid(datapoints.get(selected).clone());
+ c_1.setIndex(0);
+ // Construct a set of weighted things which can be used for random selection. Initial weights are
+ // set to the distance from c_1 scaled by 2 * log(1 + weight of the point).
+ // NOTE(review): the point chosen as c_1 keeps its entry in seedSelector (it is not deleted).
+ for (int i = 0; i < datapoints.size(); ++i) {
+ WeightedVector row = datapoints.get(i);
+ double w = distanceMeasure.distance(c_1, row) * 2 * Math.log(1 + row.getWeight());
+ seedSelector.set(i, w);
+ }
+
+ // From here, seeds are selected with probability proportional to:
+ //
+ // r_i = min_{c_j} || x_i - c_j ||^2
+ //
+ // when we only have c_1, we have already set these distances and as we select each new
+ // seed, we update the minimum distances.
+ centroids.add(c_1);
+ // Index to assign to the next seed; c_1 already took index 0.
+ int clusterIndex = 1;
+ while (centroids.size() < numClusters) {
+ // Select according to weights.
+ int seedIndex = seedSelector.sample();
+ Centroid nextSeed = new Centroid(datapoints.get(seedIndex));
+ nextSeed.setIndex(clusterIndex++);
+ centroids.add(nextSeed);
+ // Don't select this one again.
+ seedSelector.delete(seedIndex);
+ // Re-weight everything according to the minimum distance to a seed.
+ // NOTE(review): the candidate weight is scaled by the weight of the new seed, whereas the initial weights
+ // above were scaled by a function of the weight of the point itself - confirm this asymmetry is intended.
+ for (int currSeedIndex : seedSelector) {
+ WeightedVector curr = datapoints.get(currSeedIndex);
+ double newWeight = nextSeed.getWeight() * distanceMeasure.distance(nextSeed, curr);
+ // Only ever lower a weight, so that it tracks the minimum over all the seeds chosen so far.
+ if (newWeight < seedSelector.getWeight(currSeedIndex)) {
+ seedSelector.set(currSeedIndex, newWeight);
+ }
+ }
+ }
+ }
+
+ /**
+  * Examines the datapoints and updates cluster centers to be the centroid of the nearest datapoints points. To
+  * compute a new center for cluster c_i, we average all points that are closer than d_i * trimFraction
+  * where d_i is
+  *
+  * d_i = min_j \sqrt ||c_j - c_i||^2
+  *
+  * By ignoring distant points, the centroids converge more quickly to a good approximation of the
+  * optimal k-means solution (given good starting points).
+  *
+  * The loop stops as soon as a pass changes no cluster assignment, or after maxNumIterations passes. When
+  * correctWeights is set, the centroid weights are finally reset to the total weight of all the points
+  * assigned to them, including the trimmed ones.
+  *
+  * The centroids in the searcher are expected to carry the distinct indices 0 .. size - 1, as assigned by the
+  * seeding methods; per-cluster data is addressed by that index rather than by the iteration order of the searcher.
+  *
+  * @param datapoints the points to cluster.
+  */
+ private void iterativeAssignment(List<? extends WeightedVector> datapoints) {
+   DistanceMeasure distanceMeasure = centroids.getDistanceMeasure();
+   int numCentroids = centroids.size();
+   // closestClusterDistances[i] is the distance from the cluster with index i to its closest
+   // neighboring cluster.
+   double[] closestClusterDistances = new double[numCentroids];
+   // clusterAssignments[i] == j means that the i'th point is assigned to the j'th cluster. When
+   // these don't change, we are done.
+   // Each point is assigned to the invalid "-1" cluster initially.
+   List<Integer> clusterAssignments = Lists.newArrayList(Collections.nCopies(datapoints.size(), -1));
+
+   boolean changed = true;
+   for (int i = 0; changed && i < maxNumIterations; i++) {
+     changed = false;
+     // For every cluster, record the distance to its closest neighbor (used to set a proportional distance
+     // threshold for the points involved in calculating the centroid) and prepare a zero-weight copy that
+     // accumulates the new centroid. Both are stored under the centroid's own index: a searcher is free to
+     // iterate in an order other than the index order, and the lookups below are made by index.
+     List<Centroid> newCentroids = Lists.newArrayList(Collections.<Centroid>nCopies(numCentroids, null));
+     for (Vector center : centroids) {
+       // need a deep copy because we will mutate these values
+       Centroid newCentroid = (Centroid) center.clone();
+       int index = newCentroid.getIndex();
+       Vector closestOtherCluster = centroids.searchFirst(center, true).getValue();
+       closestClusterDistances[index] = distanceMeasure.distance(center, closestOtherCluster);
+       newCentroid.setWeight(0);
+       newCentroids.set(index, newCentroid);
+     }
+
+     // Pass over the datapoints computing new centroids.
+     for (int j = 0; j < datapoints.size(); ++j) {
+       WeightedVector datapoint = datapoints.get(j);
+       // Get the closest cluster this point belongs to.
+       WeightedThing<Vector> closestPair = centroids.searchFirst(datapoint, false);
+       int closestIndex = ((WeightedVector) closestPair.getValue()).getIndex();
+       double closestDistance = closestPair.getWeight();
+       // Update its cluster assignment if necessary.
+       if (closestIndex != clusterAssignments.get(j)) {
+         changed = true;
+         clusterAssignments.set(j, closestIndex);
+       }
+       // Only update if the datapoints point is near enough. What this means is that the weight
+       // of outliers is NOT taken into account and the final weights of the centroids will
+       // reflect this (it will be less or equal to the initial sum of the weights).
+       if (closestDistance < trimFraction * closestClusterDistances[closestIndex]) {
+         newCentroids.get(closestIndex).update(datapoint);
+       }
+     }
+     // Add the new centers back into searcher.
+     centroids.clear();
+     centroids.addAll(newCentroids);
+   }
+
+   if (correctWeights) {
+     // Recompute the weights from scratch so that trimmed points are counted as well.
+     for (Vector v : centroids) {
+       ((Centroid) v).setWeight(0);
+     }
+     for (WeightedVector datapoint : datapoints) {
+       Centroid closestCentroid = (Centroid) centroids.searchFirst(datapoint, false).getValue();
+       closestCentroid.setWeight(closestCentroid.getWeight() + datapoint.getWeight());
+     }
+   }
+ }
+
+ @Override
+ public Iterator<Centroid> iterator() {
+ return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
+ @Override
+ public Centroid apply(Vector input) {
+ Preconditions.checkArgument(input instanceof Centroid, "Non-centroid in centroids " +
+ "searcher");
+ //noinspection ConstantConditions
+ return (Centroid)input;
+ }
+ });
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
new file mode 100644
index 0000000..0e3f068
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.cluster;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.jet.math.Constants;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Implements a streaming k-means algorithm for weighted vectors.
+ * The goal is to cluster points one at a time, which is especially useful for MapReduce mappers that get inputs one at a time.
+ *
+ * A rough description of the algorithm:
+ * Suppose there are l clusters at one point and a new point p is added.
+ * The new point can either be added to one of the existing l clusters or become a new cluster. To decide:
+ * - let c be the closest cluster to point p;
+ * - let d be the distance between c and p;
+ * - if d > distanceCutoff, create a new cluster from p (p is too far away from the clusters to be part of them;
+ * distanceCutoff represents the largest distance from a point to its assigned cluster's centroid);
+ * - else (d <= distanceCutoff), create a new cluster with probability d / distanceCutoff (the probability of creating
+ * a new cluster increases as d increases).
+ * There will be either l points or l + 1 points after processing a new point.
+ *
+ * As the number of clusters increases, it will go over the numClusters limit (numClusters represents a recommendation
+ * for the number of clusters that there should be at the end). To decrease the number of clusters the existing clusters
+ * are treated as data points and are re-clustered (collapsed). This tends to make the number of clusters go down.
+ * If the number of clusters is still too high, distanceCutoff is increased.
+ *
+ * For more details, see:
+ * - "Streaming k-means approximation" by N. Ailon, R. Jaiswal, C. Monteleoni
+ * http://books.nips.cc/papers/files/nips22/NIPS2009_1085.pdf
+ * - "Fast and Accurate k-means for Large Datasets" by M. Shindler, A. Wong, A. Meyerson,
+ * http://books.nips.cc/papers/files/nips24/NIPS2011_1271.pdf
+ */
+public class StreamingKMeans implements Iterable<Centroid> {
+ /**
+ * The searcher containing the centroids that resulted from the clustering of points until now. When adding a new
+ * point we either assign it to one of the existing clusters in this searcher or create a new centroid for it.
+ */
+ private final UpdatableSearcher centroids;
+
+ /**
+ * The estimated number of clusters to cluster the data in. If the actual number of clusters increases beyond this
+ * limit, the clusters will be "collapsed" (re-clustered, by treating them as data points). This doesn't happen
+ * recursively and a collapse might not necessarily make the number of actual clusters drop to less than this limit.
+ *
+ * If the goal is clustering a large data set into k clusters, numClusters SHOULD NOT BE SET to k. StreamingKMeans is
+ * useful to reduce the size of the data set by the mappers so that it can fit into memory in one reducer that runs
+ * BallKMeans.
+ *
+ * It is NOT MEANT to cluster the data into k clusters in one pass because it can't guarantee that there will in fact
+ * be k clusters in total. This is because of the dynamic nature of numClusters over the course of the runtime.
+ * To get an exact number of clusters, another clustering algorithm needs to be applied to the results.
+ */
+ private int numClusters;
+
+ /**
+ * The number of data points seen so far. This is important for re-estimating numClusters when deciding to collapse
+ * the existing clusters.
+ */
+ private int numProcessedDatapoints = 0;
+
+ /**
+ * This is the current value of the distance cutoff. Points which are much closer than this to a centroid will stick
+ * to it almost certainly. Points further than this to any centroid will form a new cluster.
+ *
+ * This increases (is multiplied by beta) when a cluster collapse did not make the number of clusters drop to below
+ * numClusters (it effectively increases the tolerance for cluster compactness discouraging the creation of new
+ * clusters). Since a collapse only happens when centroids.size() > clusterOvershoot * numClusters, the cutoff
+ * increases when the collapse didn't at least remove the slack in the number of clusters.
+ */
+ private double distanceCutoff;
+
+ /**
+ * Parameter that controls the growth of the distanceCutoff. After n increases of the
+ * distanceCutoff starting at d_0, the final value is d_0 * beta^n (distance cutoffs increase following a geometric
+ * progression with ratio beta).
+ */
+ private final double beta;
+
+ /**
+ * Multiplying clusterLogFactor with the logarithm of numProcessedDatapoints gets an estimate of the suggested
+ * number of clusters. This mirrors the recommended number of clusters for n points where there should be k actual
+ * clusters, k * log n. In the case of our estimate we use clusterLogFactor * log(numProcessedDataPoints).
+ *
+ * It is important to note that numClusters is NOT k. It is an estimate of k * log n.
+ */
+ private final double clusterLogFactor;
+
+ /**
+ * Centroids are collapsed when the number of clusters becomes greater than clusterOvershoot * numClusters. This
+ * effectively means having a slack in numClusters so that the actual number of centroids, centroids.size() tracks
+ * numClusters approximately. The idea is that the actual number of clusters should be at least numClusters but not
+ * much more (so that we don't end up having 1 cluster / point).
+ */
+ private final double clusterOvershoot;
+
+ /**
+ * Random object to sample values from.
+ */
+ private final Random random = RandomUtils.getRandom();
+
+ /**
+ * Calls StreamingKMeans(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2), i.e. the initial distance cutoff
+ * is the reciprocal of numClusters, beta is 1.3, clusterLogFactor is 20 and clusterOvershoot is 2.
+ * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
+ * double, double, double, double)
+ */
+ public StreamingKMeans(UpdatableSearcher searcher, int numClusters) {
+ this(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2);
+ }
+
+ /**
+ * Calls StreamingKMeans(searcher, numClusters, distanceCutoff, 1.3, 20, 2), i.e. beta is 1.3, clusterLogFactor
+ * is 20 and clusterOvershoot is 2.
+ * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
+ * double, double, double, double)
+ */
+ public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff) {
+ this(searcher, numClusters, distanceCutoff, 1.3, 20, 2);
+ }
+
+ /**
+  * Creates a new StreamingKMeans instance from a searcher and the full set of tuning parameters. The arguments
+  * are stored as given; no validation is performed here.
+  *
+  * @param searcher A Searcher that is used for performing nearest neighbor search. It MUST BE
+  *                 EMPTY initially because it will be used to keep track of the cluster
+  *                 centroids.
+  * @param numClusters An estimated number of clusters to generate for the data points.
+  *                    This can be adjusted, but the actual number will depend on the data.
+  * @param distanceCutoff The initial distance cutoff representing the value of the
+  *                       distance between a point and its closest centroid after which
+  *                       the new point will definitely be assigned to a new cluster.
+  * @param beta Ratio of geometric progression to use when increasing distanceCutoff. After n increases,
+  *             distanceCutoff becomes distanceCutoff * beta^n. A smaller value increases the distanceCutoff
+  *             less aggressively.
+  * @param clusterLogFactor Factor used, together with the number of points counted so far, to estimate the
+  *                         number of clusters to aim for. If the final number of clusters is known and this
+  *                         clustering is only for a sketch of the data, this can be the final number of
+  *                         clusters, k.
+  * @param clusterOvershoot Multiplicative slack factor for slowing down the collapse of the clusters.
+  */
+ public StreamingKMeans(UpdatableSearcher searcher, int numClusters,
+     double distanceCutoff, double beta, double clusterLogFactor, double clusterOvershoot) {
+   // Tuning constants that stay fixed for the lifetime of the clusterer.
+   this.clusterOvershoot = clusterOvershoot;
+   this.clusterLogFactor = clusterLogFactor;
+   this.beta = beta;
+
+   // Starting values of the state that evolves while points are processed.
+   this.distanceCutoff = distanceCutoff;
+   this.numClusters = numClusters;
+
+   // The searcher that holds the centroids.
+   this.centroids = searcher;
+ }
+
+ /**
+ * @return an Iterator to the Centroids contained in this clusterer.
+ */
+ @Override
+ public Iterator<Centroid> iterator() {
+ return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
+ @Override
+ public Centroid apply(Vector input) {
+ return (Centroid)input;
+ }
+ });
+ }
+
+ /**
+ * Cluster the rows of a matrix, treating them as Centroids with weight 1.
+ * @param data matrix whose rows are to be clustered.
+ * @return the UpdatableSearcher containing the resulting centroids.
+ */
+ public UpdatableSearcher cluster(Matrix data) {
+ return cluster(Iterables.transform(data, new Function<MatrixSlice, Centroid>() {
+ @Override
+ public Centroid apply(MatrixSlice input) {
+ // The key in a Centroid is actually the MatrixSlice's index.
+ return Centroid.create(input.index(), input.vector());
+ }
+ }));
+ }
+
+ /**
+  * Clusters every data point produced by the given {@code Iterable<Centroid>}.
+  *
+  * @param datapoints Iterable whose elements are to be clustered.
+  * @return the UpdatableSearcher containing the resulting centroids.
+  */
+ public UpdatableSearcher cluster(Iterable<Centroid> datapoints) {
+   // Public entry points always run the regular (non-collapsing) clustering pass.
+   final boolean collapseClusters = false;
+   return clusterInternal(datapoints, collapseClusters);
+ }
+
+ /**
+  * Clusters a single data point.
+  *
+  * <p>The point is wrapped in an immutable one-element list and passed to
+  * {@link #cluster(Iterable)}. Using the standard singleton list (instead of a hand-written
+  * iterator) guarantees a contract-abiding iterator: {@code next()} throws
+  * {@code NoSuchElementException} once the point has been consumed and {@code remove()} is
+  * unsupported.
+  *
+  * @param datapoint the point to be clustered.
+  * @return the UpdatableSearcher containing the resulting centroids.
+  */
+ public UpdatableSearcher cluster(final Centroid datapoint) {
+   return cluster(Collections.singletonList(datapoint));
+ }
+
+ /**
+ * @return the number of clusters computed from the points until now.
+ */
+ public int getNumClusters() {
+ return centroids.size();
+ }
+
+ /**
+ * Internal clustering method that gets called from the other wrappers, and recursively from
+ * itself when the set of centroids has to be collapsed.
+ *
+ * <p>Each incoming point is either merged into its nearest existing centroid or promoted to a
+ * new centroid; the choice is random, with the probability of a new centroid growing with the
+ * point's weight and its distance to the nearest centroid relative to distanceCutoff. When the
+ * number of centroids exceeds clusterOvershoot * numClusters, the centroids themselves are
+ * shuffled and re-clustered (a "collapse"), and distanceCutoff is multiplied by beta if there
+ * are still more than numClusters centroids afterwards.
+ *
+ * @param datapoints Iterable of data points to be clustered. If it is empty the current
+ * centroids are returned unchanged.
+ * @param collapseClusters whether this is an "inner" clustering and the datapoints are the previously computed
+ * centroids. In that mode the searcher is cleared first, no further collapse is
+ * triggered, and numProcessedDatapoints is restored to its previous value on exit.
+ * @return the UpdatableSearcher containing the resulting centroids (the centroids field itself).
+ */
+ private UpdatableSearcher clusterInternal(Iterable<Centroid> datapoints, boolean collapseClusters) {
+ Iterator<Centroid> datapointsIterator = datapoints.iterator();
+ // Nothing to cluster: leave all state untouched.
+ if (!datapointsIterator.hasNext()) {
+ return centroids;
+ }
+
+ // Remember the running count so it can be restored after a collapse pass (see the end of the method).
+ int oldNumProcessedDataPoints = numProcessedDatapoints;
+ // We clear the centroids we have in case of cluster collapse, the old clusters are the
+ // datapoints but we need to re-cluster them.
+ if (collapseClusters) {
+ centroids.clear();
+ numProcessedDatapoints = 0;
+ }
+
+ if (centroids.size() == 0) {
+ // Assign the first datapoint to the first cluster.
+ // A clone is stored rather than the caller's object because centroids held by the
+ // searcher are mutated later on (see centroid.update(row) below).
+ centroids.add(datapointsIterator.next().clone());
+ ++numProcessedDatapoints;
+ }
+
+ // To cluster, we scan the data and either add each point to the nearest group or create a new group.
+ // When we get too many groups, we need to increase the threshold and rescan our current groups.
+ while (datapointsIterator.hasNext()) {
+ Centroid row = datapointsIterator.next();
+ // Get the closest vector and its weight as a WeightedThing<Vector>.
+ // The weight of the WeightedThing is the distance to the query and the value is a
+ // reference to one of the vectors we added to the searcher previously.
+ WeightedThing<Vector> closestPair = centroids.searchFirst(row, false);
+
+ // We get a uniformly distributed random number between 0 and 1 and compare it with the
+ // distance to the closest cluster, scaled by the weight of the point, divided by the
+ // distanceCutoff.
+ // This is so that if that scaled distance is larger than distanceCutoff, the ratio is > 1,
+ // which will trigger the creation of a new cluster anyway.
+ // However, if the ratio is less than 1, we want to create a new cluster with probability
+ // proportional to the (weighted) distance to the closest cluster.
+ double sample = random.nextDouble();
+ if (sample < row.getWeight() * closestPair.getWeight() / distanceCutoff) {
+ // Add new centroid, note that the vector is copied because we may mutate it later.
+ centroids.add(row.clone());
+ } else {
+ // Merge the new point with the existing centroid. This will update the centroid's actual
+ // position.
+ // We know that all the points we inserted in the centroids searcher are Centroids
+ // (only Centroid clones are ever added above), so the cast will always succeed.
+ Centroid centroid = (Centroid) closestPair.getValue();
+
+ // We will update the centroid by removing it from the searcher and reinserting it to
+ // ensure consistency.
+ if (!centroids.remove(centroid, Constants.EPSILON)) {
+ throw new RuntimeException("Unable to remove centroid");
+ }
+ centroid.update(row);
+ centroids.add(centroid);
+
+ }
+ ++numProcessedDatapoints;
+
+ // Too many centroids: collapse them. This never happens inside a collapse pass itself,
+ // which bounds the recursion depth to one level.
+ if (!collapseClusters && centroids.size() > clusterOvershoot * numClusters) {
+ // The target number of clusters only ever grows, following clusterLogFactor * log(points seen so far).
+ numClusters = (int) Math.max(numClusters, clusterLogFactor * Math.log(numProcessedDatapoints));
+
+ // Copy the centroids out of the searcher first, since the recursive call clears it.
+ List<Centroid> shuffled = Lists.newArrayList();
+ for (Vector vector : centroids) {
+ shuffled.add((Centroid) vector);
+ }
+ // NOTE(review): this uses Collections.shuffle's default source of randomness rather than
+ // the 'random' field used for sampling above - confirm whether that is intended.
+ Collections.shuffle(shuffled);
+ // Re-cluster using the shuffled centroids as data points. The centroids member variable
+ // is modified directly.
+ clusterInternal(shuffled, true);
+
+ // Collapsing was not enough: relax the cutoff so that future points merge more readily.
+ if (centroids.size() > numClusters) {
+ distanceCutoff *= beta;
+ }
+ }
+ }
+
+ // A collapse pass re-processes existing centroids, not new points, so the running count of
+ // processed datapoints is put back to what it was on entry.
+ if (collapseClusters) {
+ numProcessedDatapoints = oldNumProcessedDataPoints;
+ }
+ return centroids;
+ }
+
+ /**
+  * Renumbers the centroids: walking them in this clusterer's iteration order, each one has its
+  * index set to consecutive integers starting at 0.
+  */
+ public void reindexCentroids() {
+   int nextIndex = 0;
+   Iterator<Centroid> remaining = iterator();
+   while (remaining.hasNext()) {
+     remaining.next().setIndex(nextIndex);
+     nextIndex++;
+   }
+ }
+
+ /**
+  * Returns the current distance cutoff. The value may have been increased (multiplied by beta)
+  * during clustering since it was first set.
+  *
+  * @return the distanceCutoff (an upper bound for the maximum distance within a cluster).
+  */
+ public double getDistanceCutoff() {
+   return this.distanceCutoff;
+ }
+
+ /**
+  * Overrides the distance cutoff used when deciding whether a point starts a new cluster.
+  *
+  * @param distanceCutoff the new cutoff; stored as given, without validation.
+  */
+ public void setDistanceCutoff(double distanceCutoff) {
+   this.distanceCutoff = distanceCutoff;
+ }
+
+ public DistanceMeasure getDistanceMeasure() {
+ return centroids.getDistanceMeasure();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
new file mode 100644
index 0000000..a41940b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Hadoop {@link Writable} wrapper around a {@link Centroid}.
+ *
+ * <p>The serialized form is the centroid's index (int), its weight (double) and then its vector
+ * as written by {@link VectorWritable#writeVector}. An instance created with the no-argument
+ * constructor holds no centroid until {@link #readFields(DataInput)} is called; {@code equals},
+ * {@code hashCode} and {@code toString} are safe to call in that state.
+ */
+public class CentroidWritable implements Writable {
+  /** The wrapped centroid; null until one is supplied or read. */
+  private Centroid centroid = null;
+
+  /** Creates an empty wrapper, to be populated through {@link #readFields(DataInput)}. */
+  public CentroidWritable() {}
+
+  /**
+   * @param centroid the centroid to wrap; kept by reference, not copied.
+   */
+  public CentroidWritable(Centroid centroid) {
+    this.centroid = centroid;
+  }
+
+  /**
+   * @return the wrapped centroid, or null if none has been set or read yet.
+   */
+  public Centroid getCentroid() {
+    return centroid;
+  }
+
+  /**
+   * Writes the wrapped centroid as index, weight, vector. A centroid must be present.
+   *
+   * @param dataOutput the sink to serialize to.
+   * @throws IOException if writing fails.
+   */
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    dataOutput.writeInt(centroid.getIndex());
+    dataOutput.writeDouble(centroid.getWeight());
+    VectorWritable.writeVector(dataOutput, centroid.getVector());
+  }
+
+  /**
+   * Reads a centroid in the format produced by {@link #write(DataOutput)}. If this wrapper
+   * already holds a centroid, that object is updated in place; otherwise a new one is created.
+   *
+   * @param dataInput the source to deserialize from.
+   * @throws IOException if reading fails.
+   */
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    if (centroid == null) {
+      centroid = read(dataInput);
+      return;
+    }
+    // Reuse the existing Centroid object instead of allocating a new one.
+    centroid.setIndex(dataInput.readInt());
+    centroid.setWeight(dataInput.readDouble());
+    centroid.assign(VectorWritable.readVector(dataInput));
+  }
+
+  /**
+   * Reads one centroid (index, weight, vector) and returns it as a new object.
+   *
+   * @param dataInput the source to deserialize from.
+   * @return the newly created Centroid.
+   * @throws IOException if reading fails.
+   */
+  public static Centroid read(DataInput dataInput) throws IOException {
+    int index = dataInput.readInt();
+    double weight = dataInput.readDouble();
+    Vector v = VectorWritable.readVector(dataInput);
+    return new Centroid(index, v, weight);
+  }
+
+  /**
+   * Two wrappers are equal when their centroids are equal; two empty wrappers are equal.
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (!(o instanceof CentroidWritable)) {
+      return false;
+    }
+    CentroidWritable writable = (CentroidWritable) o;
+    // Null-safe: a wrapper that has not been populated yet must not throw here.
+    if (centroid == null) {
+      return writable.centroid == null;
+    }
+    return centroid.equals(writable.centroid);
+  }
+
+  /**
+   * @return the centroid's hash code, or 0 for an empty wrapper (consistent with equals).
+   */
+  @Override
+  public int hashCode() {
+    return centroid == null ? 0 : centroid.hashCode();
+  }
+
+  /**
+   * @return the centroid's string form, or "null" for an empty wrapper.
+   */
+  @Override
+  public String toString() {
+    return String.valueOf(centroid);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
new file mode 100644
index 0000000..73776b9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
@@ -0,0 +1,493 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Classifies the vectors into different clusters found by the clustering
+ * algorithm.
+ */
+public final class StreamingKMeansDriver extends AbstractJob {
+ // ---- Streaming KMeans options ----
+
+ /**
+ * Option name for the number of clusters each Mapper should aim for. This should be \(O(k log n)\)
+ * where k is the number of clusters to get at the end and n is the number of points to cluster.
+ * This doesn't need to be exact. It will be adjusted at runtime.
+ */
+ public static final String ESTIMATED_NUM_MAP_CLUSTERS = "estimatedNumMapClusters";
+ /**
+ * Option name for the initial estimated distance cutoff between two points for forming new clusters.
+ * When the option is left at the driver's default ({@link #INVALID_DISTANCE_CUTOFF}) the value is
+ * not written to the Configuration and, per the option's help text, it is estimated from the data set.
+ * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans
+ */
+ public static final String ESTIMATED_DISTANCE_CUTOFF = "estimatedDistanceCutoff";
+
+ // ---- Ball KMeans options ----
+
+ /**
+ * After mapping finishes, we get an intermediate set of vectors that represent approximate
+ * clusterings of the data from each Mapper. These can be clustered by the Reducer using
+ * BallKMeans in memory. This option is the maximum number of iterations in the final
+ * BallKMeans algorithm.
+ * Defaults to 10.
+ */
+ public static final String MAX_NUM_ITERATIONS = "maxNumIterations";
+ /**
+ * The "ball" aspect of ball k-means means that only the closest points to the centroid will actually be used
+ * for updating. The fraction of the points to be used is those points whose distance to the center is within
+ * trimFraction * distance to the closest other center.
+ * Defaults to 0.9.
+ */
+ public static final String TRIM_FRACTION = "trimFraction";
+ /**
+ * Whether to use k-means++ initialization or random initialization of the seed centroids.
+ * Essentially, k-means++ provides better clusters, but takes longer, whereas random initialization takes less
+ * time, but produces worse clusters, and tends to fail more often and needs multiple runs to compare to
+ * k-means++. If the flag is set, random initialization is used.
+ * @see org.apache.mahout.clustering.streaming.cluster.BallKMeans
+ */
+ public static final String RANDOM_INIT = "randomInit";
+ /**
+ * Whether to correct the weights of the centroids after the clustering is done. The weights end up being wrong
+ * because of the trimFraction and possible train/test splits. In some cases, especially in a pipeline, having
+ * an accurate count of the weights is useful. If the flag is set, the final weights are ignored.
+ */
+ public static final String IGNORE_WEIGHTS = "ignoreWeights";
+ /**
+ * The fraction of points (a value in [0, 1)) that go into the "test" set when evaluating BallKMeans runs
+ * in the reducer.
+ * Defaults to 0.1.
+ */
+ public static final String TEST_PROBABILITY = "testProbability";
+ /**
+ * The number of BallKMeans runs to use at the end to try to cluster the points.
+ * Defaults to 4.
+ */
+ public static final String NUM_BALLKMEANS_RUNS = "numBallKMeansRuns";
+
+ // ---- Searcher options ----
+
+ /**
+ * The Searcher class to use when performing nearest neighbor search in StreamingKMeans.
+ * Defaults to ProjectionSearch.
+ */
+ public static final String SEARCHER_CLASS_OPTION = "searcherClass";
+ /**
+ * The number of projections to use when using a projection searcher like ProjectionSearch or
+ * FastProjectionSearch. Projection searches work by projecting all the vectors onto a set of
+ * basis vectors and searching for the projected query in that totally ordered set. This
+ * however can produce false positives (vectors that are closer when projected than they would
+ * actually be).
+ * So, there must be more than one projection vector in the basis. This option is the number
+ * of vectors in a basis.
+ * Defaults to 3.
+ */
+ public static final String NUM_PROJECTIONS_OPTION = "numProjections";
+ /**
+ * When using approximate searches (anything that's not BruteSearch),
+ * more than just the seemingly closest element must be considered. This option has different
+ * meanings depending on the actual Searcher class used but is a measure of how many candidates
+ * will be considered.
+ * See the ProjectionSearch, FastProjectionSearch, LocalitySensitiveHashSearch classes for more
+ * details.
+ * Defaults to 2.
+ */
+ public static final String SEARCH_SIZE_OPTION = "searchSize";
+
+ /**
+ * Whether to run another pass of StreamingKMeans on the reducer's points before BallKMeans. On some data sets
+ * with a large number of mappers, the intermediate number of clusters passed to the reducer is too large to
+ * fit into memory directly, hence the option to collapse the clusters further with StreamingKMeans.
+ */
+ public static final String REDUCE_STREAMING_KMEANS = "reduceStreamingKMeans";
+
+ private static final Logger log = LoggerFactory.getLogger(StreamingKMeansDriver.class);
+
+ /**
+ * Sentinel default for {@link #ESTIMATED_DISTANCE_CUTOFF} meaning "no cutoff supplied". When the
+ * cutoff equals this value it is not stored in the job Configuration.
+ */
+ public static final float INVALID_DISTANCE_CUTOFF = -1;
+
+ /**
+  * Command-line entry point used by ToolRunner: declares the options, parses the arguments,
+  * copies the validated settings into the job Configuration and launches the clustering.
+  *
+  * @param args the command-line arguments.
+  * @return -1 if the arguments could not be parsed; otherwise the status reported by
+  *         {@link #run(Configuration, Path, Path)} (0 on success, -1 on failure).
+  * @throws Exception if option validation or the clustering job fails.
+  */
+ @Override
+ public int run(String[] args) throws Exception {
+   // Standard options for any Mahout job.
+   addInputOption();
+   addOutputOption();
+   addOption(DefaultOptionCreator.overwriteOption().create());
+
+   // The number of clusters to create for the data.
+   addOption(DefaultOptionCreator.numClustersOption().withDescription(
+       "The k in k-Means. Approximately this many clusters will be generated.").create());
+
+   // StreamingKMeans (mapper) options
+   // There will be k final clusters, but in the Map phase to get a good approximation of the data, O(k log n)
+   // clusters are needed. Since n is the number of data points and not knowable until reading all the vectors,
+   // provide a decent estimate.
+   addOption(ESTIMATED_NUM_MAP_CLUSTERS, "km", "The estimated number of clusters to use for the "
+       + "Map phase of the job when running StreamingKMeans. This should be around k * log(n), "
+       + "where k is the final number of clusters and n is the total number of data points to "
+       + "cluster.", String.valueOf(1));
+
+   addOption(ESTIMATED_DISTANCE_CUTOFF, "e", "The initial estimated distance cutoff between two "
+       + "points for forming new clusters. If no value is given, it's estimated from the data set",
+       String.valueOf(INVALID_DISTANCE_CUTOFF));
+
+   // BallKMeans (reducer) options
+   addOption(MAX_NUM_ITERATIONS, "mi", "The maximum number of iterations to run for the "
+       + "BallKMeans algorithm used by the reducer. If no value is given, defaults to 10.", String.valueOf(10));
+
+   addOption(TRIM_FRACTION, "tf", "The 'ball' aspect of ball k-means means that only the closest points "
+       + "to the centroid will actually be used for updating. The fraction of the points to be used is those "
+       + "points whose distance to the center is within trimFraction * distance to the closest other center. "
+       + "If no value is given, defaults to 0.9.", String.valueOf(0.9));
+
+   addFlag(RANDOM_INIT, "ri", "Whether to use k-means++ initialization or random initialization "
+       + "of the seed centroids. Essentially, k-means++ provides better clusters, but takes longer, whereas random "
+       + "initialization takes less time, but produces worse clusters, and tends to fail more often and needs "
+       + "multiple runs to compare to k-means++. If set, uses the random initialization.");
+
+   addFlag(IGNORE_WEIGHTS, "iw", "Whether to correct the weights of the centroids after the clustering is done. "
+       + "The weights end up being wrong because of the trimFraction and possible train/test splits. In some cases, "
+       + "especially in a pipeline, having an accurate count of the weights is useful. If set, ignores the final "
+       + "weights");
+
+   addOption(TEST_PROBABILITY, "testp", "A double value between 0 and 1 that represents the percentage of "
+       + "points to be used for 'testing' different clustering runs in the final BallKMeans "
+       + "step. If no value is given, defaults to 0.1", String.valueOf(0.1));
+
+   addOption(NUM_BALLKMEANS_RUNS, "nbkm", "Number of BallKMeans runs to use at the end to try to cluster the "
+       + "points. If no value is given, defaults to 4", String.valueOf(4));
+
+   // Nearest neighbor search options
+   // The distance measure used for computing the distance between two points. Generally, the
+   // SquaredEuclideanDistance is used for clustering problems (it's equivalent to CosineDistance for normalized
+   // vectors).
+   // WARNING! You can use any metric but most of the literature is for the squared euclidean distance.
+   addOption(DefaultOptionCreator.distanceMeasureOption().create());
+
+   // The default searcher should be something more efficient than BruteSearch (ProjectionSearch, ...). See
+   // o.a.m.math.neighborhood.*
+   addOption(SEARCHER_CLASS_OPTION, "sc", "The type of searcher to be used when performing nearest "
+       + "neighbor searches. Defaults to ProjectionSearch.", ProjectionSearch.class.getCanonicalName());
+
+   // In the original paper, the authors used 1 projection vector.
+   addOption(NUM_PROJECTIONS_OPTION, "np", "The number of projections considered in estimating the "
+       + "distances between vectors. Only used when the searcher requested is either "
+       + "ProjectionSearch or FastProjectionSearch. If no value is given, defaults to 3.", String.valueOf(3));
+
+   // The help text must agree with the default value passed as the last argument (2).
+   addOption(SEARCH_SIZE_OPTION, "s", "In more efficient searches (non BruteSearch), "
+       + "not all distances are calculated for determining the nearest neighbors. The number of "
+       + "elements whose distances from the query vector is actually computed is proportional to "
+       + "searchSize. If no value is given, defaults to 2.", String.valueOf(2));
+
+   addFlag(REDUCE_STREAMING_KMEANS, "rskm", "There might be too many intermediate clusters from the mapper "
+       + "to fit into memory, so the reducer can run another pass of StreamingKMeans to collapse them down to a "
+       + "fewer clusters");
+
+   addOption(DefaultOptionCreator.methodOption().create());
+
+   if (parseArguments(args) == null) {
+     return -1;
+   }
+   Path output = getOutputPath();
+   if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+     HadoopUtil.delete(getConf(), output);
+   }
+   configureOptionsForWorkers();
+   // Propagate the job status instead of unconditionally reporting success, so that a failed
+   // MapReduce job is visible to the caller.
+   return run(getConf(), getInputPath(), output);
+ }
+
+ /**
+  * Reads every parsed command-line option, converts it to its typed value and forwards the
+  * whole set to the static configureOptionsForWorkers overload, which validates the values and
+  * stores them in this job's Configuration.
+  *
+  * @throws ClassNotFoundException if the distance measure or searcher class cannot be loaded.
+  */
+ private void configureOptionsForWorkers() throws ClassNotFoundException {
+   log.info("Starting to configure options for workers");
+
+   String executionMethod = getOption(DefaultOptionCreator.METHOD_OPTION);
+   int k = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
+
+   // StreamingKMeans settings.
+   int mapClusters = Integer.parseInt(getOption(ESTIMATED_NUM_MAP_CLUSTERS));
+   float distanceCutoff = Float.parseFloat(getOption(ESTIMATED_DISTANCE_CUTOFF));
+
+   // BallKMeans settings.
+   int iterations = Integer.parseInt(getOption(MAX_NUM_ITERATIONS));
+   float trim = Float.parseFloat(getOption(TRIM_FRACTION));
+   boolean useRandomInit = hasOption(RANDOM_INIT);
+   boolean dropWeights = hasOption(IGNORE_WEIGHTS);
+   float testFraction = Float.parseFloat(getOption(TEST_PROBABILITY));
+   int ballKMeansRuns = Integer.parseInt(getOption(NUM_BALLKMEANS_RUNS));
+
+   // Nearest neighbor search settings.
+   String measureName = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+   String searcherName = getOption(SEARCHER_CLASS_OPTION);
+
+   // BruteSearch needs no extra parameters. For every other searcher both the search size and
+   // the number of projections are read (projection searches use both; the search size is
+   // quite fuzzy and might end up not being configurable at all).
+   boolean approximateSearch = !searcherName.equals(BruteSearch.class.getName());
+   int candidates = approximateSearch ? Integer.parseInt(getOption(SEARCH_SIZE_OPTION)) : 0;
+   int projections = approximateSearch ? Integer.parseInt(getOption(NUM_PROJECTIONS_OPTION)) : 0;
+
+   boolean streamingInReducer = hasOption(REDUCE_STREAMING_KMEANS);
+
+   configureOptionsForWorkers(getConf(), k,
+       /* StreamingKMeans */
+       mapClusters, distanceCutoff,
+       /* BallKMeans */
+       iterations, trim, useRandomInit, dropWeights, testFraction, ballKMeansRuns,
+       /* Searcher */
+       measureName, searcherName, candidates, projections,
+       executionMethod,
+       streamingInReducer);
+ }
+
+ /**
+  * Checks the parameters for a StreamingKMeans job and prepares a Configuration with them.
+  * All validation (including loading the measure and searcher classes) happens before the
+  * Configuration is touched, so a rejected parameter set leaves {@code conf} unmodified.
+  *
+  * @param conf the Configuration to populate
+  * @param numClusters k, the number of clusters at the end; must be positive
+  * @param estimatedNumMapClusters O(k log n), the number of clusters requested from each mapper; must be greater
+  *                                than numClusters
+  * @param estimatedDistanceCutoff an estimate of the minimum distance that separates two clusters (can be smaller and
+  *                                will be increased dynamically); either positive or {@link #INVALID_DISTANCE_CUTOFF},
+  *                                in which case it is not stored in the Configuration
+  * @param maxNumIterations the maximum number of iterations of BallKMeans; must be positive
+  * @param trimFraction the fraction of the points to be considered in updating a ball k-means; must be positive
+  * @param randomInit whether to initialize the ball k-means seeds randomly
+  * @param ignoreWeights whether to ignore the invalid final ball k-means weights
+  * @param testProbability the fraction of vectors assigned to the test set for selecting the best final centers;
+  *                        must be in [0, 1)
+  * @param numBallKMeansRuns the number of BallKMeans runs in the reducer that determine the centroids to return
+  *                          (clusters are computed for the training set and the error is computed on the test set);
+  *                          must be positive
+  * @param measureClass string, name of the distance measure class; theory works for Euclidean-like distances
+  * @param searcherClass string, name of the searcher that will be used for nearest neighbor search
+  * @param searchSize the number of closest neighbors to look at for selecting the closest one in approximate nearest
+  *                   neighbor searches; only checked when the searcher name does not contain "Brute"
+  * @param numProjections the number of projected vectors to use for faster searching (only useful for ProjectionSearch
+  *                       or FastProjectionSearch); @see org.apache.mahout.math.neighborhood.ProjectionSearch
+  * @param method the execution method, stored under {@link DefaultOptionCreator#METHOD_OPTION}
+  * @param reduceStreamingKMeans whether the reducer runs another StreamingKMeans pass before BallKMeans
+  * @throws IllegalArgumentException if any numeric parameter is outside its valid range
+  * @throws ClassNotFoundException if measureClass or searcherClass cannot be loaded
+  */
+ public static void configureOptionsForWorkers(Configuration conf,
+     int numClusters,
+     /* StreamingKMeans */
+     int estimatedNumMapClusters, float estimatedDistanceCutoff,
+     /* BallKMeans */
+     int maxNumIterations, float trimFraction, boolean randomInit,
+     boolean ignoreWeights, float testProbability, int numBallKMeansRuns,
+     /* Searcher */
+     String measureClass, String searcherClass,
+     int searchSize, int numProjections,
+     String method,
+     boolean reduceStreamingKMeans) throws ClassNotFoundException {
+   // Checking preconditions for the parameters.
+   Preconditions.checkArgument(numClusters > 0,
+       "Invalid number of clusters requested: " + numClusters + ". Must be: numClusters > 0!");
+
+   // StreamingKMeans
+   Preconditions.checkArgument(estimatedNumMapClusters > numClusters, "Invalid number of estimated map "
+       + "clusters; There must be more than the final number of clusters (k log n vs k)");
+   Preconditions.checkArgument(estimatedDistanceCutoff == INVALID_DISTANCE_CUTOFF || estimatedDistanceCutoff > 0,
+       "estimatedDistanceCutoff must be equal to -1 or must be greater then 0!");
+
+   // BallKMeans
+   Preconditions.checkArgument(maxNumIterations > 0, "Must have at least one BallKMeans iteration");
+   Preconditions.checkArgument(trimFraction > 0, "trimFraction must be positive");
+   Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "test probability is not in the "
+       + "interval [0, 1)");
+   // Zero runs are rejected as well, so the message says "positive" rather than "not negative".
+   Preconditions.checkArgument(numBallKMeansRuns > 0, "numBallKMeansRuns must be positive");
+
+   // Searcher
+   if (!searcherClass.contains("Brute")) {
+     // These tests only make sense when a relevant searcher is being used.
+     Preconditions.checkArgument(searchSize > 0, "Invalid searchSize. Must be positive.");
+     if (searcherClass.contains("Projection")) {
+       Preconditions.checkArgument(numProjections > 0, "Invalid numProjections. Must be positive");
+     }
+   }
+
+   // Check that both classes are available before mutating the Configuration, so that a
+   // ClassNotFoundException cannot leave it half-populated.
+   Class.forName(measureClass);
+   Class.forName(searcherClass);
+
+   // Setting the parameters in the Configuration.
+   conf.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, numClusters);
+   /* StreamingKMeans */
+   conf.setInt(ESTIMATED_NUM_MAP_CLUSTERS, estimatedNumMapClusters);
+   // The sentinel means "not supplied": leave the key unset in that case.
+   if (estimatedDistanceCutoff != INVALID_DISTANCE_CUTOFF) {
+     conf.setFloat(ESTIMATED_DISTANCE_CUTOFF, estimatedDistanceCutoff);
+   }
+   /* BallKMeans */
+   conf.setInt(MAX_NUM_ITERATIONS, maxNumIterations);
+   conf.setFloat(TRIM_FRACTION, trimFraction);
+   conf.setBoolean(RANDOM_INIT, randomInit);
+   conf.setBoolean(IGNORE_WEIGHTS, ignoreWeights);
+   conf.setFloat(TEST_PROBABILITY, testProbability);
+   conf.setInt(NUM_BALLKMEANS_RUNS, numBallKMeansRuns);
+   /* Searcher */
+   conf.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, measureClass);
+   conf.set(SEARCHER_CLASS_OPTION, searcherClass);
+   conf.setInt(SEARCH_SIZE_OPTION, searchSize);
+   conf.setInt(NUM_PROJECTIONS_OPTION, numProjections);
+   conf.set(DefaultOptionCreator.METHOD_OPTION, method);
+
+   conf.setBoolean(REDUCE_STREAMING_KMEANS, reduceStreamingKMeans);
+
+   log.info("Parameters are: [k] numClusters {}; "
+       + "[SKM] estimatedNumMapClusters {}; estimatedDistanceCutoff {} "
+       + "[BKM] maxNumIterations {}; trimFraction {}; randomInit {}; ignoreWeights {}; "
+       + "testProbability {}; numBallKMeansRuns {}; "
+       + "[S] measureClass {}; searcherClass {}; searcherSize {}; numProjections {}; "
+       + "method {}; reduceStreamingKMeans {}", numClusters, estimatedNumMapClusters, estimatedDistanceCutoff,
+       maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns,
+       measureClass, searcherClass, searchSize, numProjections, method, reduceStreamingKMeans);
+ }
+
+ /**
+ * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to
+ * cluster the input vectors.
+ *
+ * @param input the directory pathname for input points.
+ * @param output the directory pathname for output points.
+ * @return 0 on success, -1 on failure.
+ */
+ public static int run(Configuration conf, Path input, Path output)
+ throws IOException, InterruptedException, ClassNotFoundException, ExecutionException {
+ log.info("Starting StreamingKMeans clustering for vectors in {}; results are output to {}",
+ input.toString(), output.toString());
+
+ if (conf.get(DefaultOptionCreator.METHOD_OPTION,
+ DefaultOptionCreator.MAPREDUCE_METHOD).equals(DefaultOptionCreator.SEQUENTIAL_METHOD)) {
+ return runSequentially(conf, input, output);
+ } else {
+ return runMapReduce(conf, input, output);
+ }
+ }
+
+ /**
+  * Runs the whole pipeline in this JVM: one StreamingKMeansThread per input file produces
+  * intermediate centroids, which are then reduced to the final centroids by
+  * StreamingKMeansReducer.getBestCentroids and written to {@code part-r-00000} under the
+  * output path.
+  *
+  * <p>The thread pool and the sequence file writer are released even when a task or a write
+  * fails.
+  *
+  * @param conf the job Configuration holding the clustering options.
+  * @param input the directory containing the input files.
+  * @param output the directory the result file is written to.
+  * @return 0 on success; failures are reported as exceptions.
+  * @throws IOException if listing the input or writing the output fails.
+  * @throws ExecutionException if one of the clustering tasks throws.
+  * @throws InterruptedException if the calling thread is interrupted while waiting for the tasks.
+  */
+ private static int runSequentially(Configuration conf, Path input, Path output)
+     throws IOException, ExecutionException, InterruptedException {
+   long start = System.currentTimeMillis();
+   // Run StreamingKMeans step in parallel by spawning 1 thread per input path to process.
+   ExecutorService pool = Executors.newCachedThreadPool();
+   List<Centroid> intermediateCentroids = Lists.newArrayList();
+   try {
+     List<Future<Iterable<Centroid>>> intermediateCentroidFutures = Lists.newArrayList();
+     for (FileStatus status : HadoopUtil.listStatus(FileSystem.get(conf), input, PathFilters.logsCRCFilter())) {
+       intermediateCentroidFutures.add(pool.submit(new StreamingKMeansThread(status.getPath(), conf)));
+     }
+     // Merge the resulting "mapper" centroids; get() blocks until each task has finished.
+     for (Future<Iterable<Centroid>> futureIterable : intermediateCentroidFutures) {
+       for (Centroid centroid : futureIterable.get()) {
+         intermediateCentroids.add(centroid);
+       }
+     }
+     // Only now have all the "mapper" tasks actually completed.
+     log.info("Finished running Mappers");
+     pool.shutdown();
+     pool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
+   } finally {
+     // No-op after a clean shutdown; on failure it cancels the tasks that are still running
+     // so their threads do not outlive this call.
+     pool.shutdownNow();
+   }
+   log.info("Finished StreamingKMeans");
+   SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.get(conf), conf,
+       new Path(output, "part-r-00000"), IntWritable.class, CentroidWritable.class);
+   try {
+     int numCentroids = 0;
+     // Run BallKMeans on the intermediate centroids.
+     for (Vector finalVector : StreamingKMeansReducer.getBestCentroids(intermediateCentroids, conf)) {
+       Centroid finalCentroid = (Centroid) finalVector;
+       writer.append(new IntWritable(numCentroids++), new CentroidWritable(finalCentroid));
+     }
+   } finally {
+     // Always release the file handle, even if clustering or appending fails.
+     writer.close();
+   }
+   long end = System.currentTimeMillis();
+   log.info("Finished BallKMeans. Took {}.", (end - start) / 1000.0);
+   return 0;
+ }
+
+ /**
+  * Runs the clustering as a Hadoop job with StreamingKMeansMapper and a single
+  * StreamingKMeansReducer, blocking until the job finishes.
+  *
+  * @param conf the job Configuration holding the clustering options.
+  * @param input the directory pathname for input points.
+  * @param output the directory pathname for output points.
+  * @return 0 if the job completed successfully, -1 otherwise.
+  */
+ public static int runMapReduce(Configuration conf, Path input, Path output)
+     throws IOException, ClassNotFoundException, InterruptedException {
+   // Assemble the job: sequence files in and out, IntWritable keys and CentroidWritable values.
+   Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class,
+       StreamingKMeansMapper.class, IntWritable.class, CentroidWritable.class,
+       StreamingKMeansReducer.class, IntWritable.class, CentroidWritable.class, SequenceFileOutputFormat.class,
+       conf);
+   String jobName = HadoopUtil.getCustomJobName(StreamingKMeansDriver.class.getSimpleName(), job,
+       StreamingKMeansMapper.class, StreamingKMeansReducer.class);
+   job.setJobName(jobName);
+
+   // A single reducer collects all intermediate centroids on one machine, where they are
+   // clustered in memory to get the right number of clusters.
+   job.setNumReduceTasks(1);
+
+   // Ship the JAR containing this class so that the required libraries are available.
+   job.setJarByClass(StreamingKMeansDriver.class);
+
+   long startMillis = System.currentTimeMillis();
+   boolean succeeded = job.waitForCompletion(true);
+   if (!succeeded) {
+     return -1;
+   }
+   long elapsedMillis = System.currentTimeMillis() - startMillis;
+
+   log.info("StreamingKMeans clustering complete. Results are in {}. Took {} ms", output.toString(), elapsedMillis);
+   return 0;
+ }
+
+ /**
+  * Private constructor: instances are only created by {@link #main(String[])}, which hands
+  * the driver to the ToolRunner.
+  */
+ private StreamingKMeansDriver() {
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new StreamingKMeansDriver(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
new file mode 100644
index 0000000..ced11ea
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+
+/**
+ * Mapper that clusters the vectors it receives online with {@link StreamingKMeans} and, once all of its
+ * input has been consumed, emits the resulting intermediate centroids under the single key {@code 0}.
+ * <p>
+ * When the job configuration does not provide a distance cutoff, the first {@code NUM_ESTIMATE_POINTS}
+ * points are buffered, the cutoff is estimated from them, and the buffered points are then clustered
+ * as one batch before streaming continues point by point.
+ */
+public class StreamingKMeansMapper extends Mapper<Writable, VectorWritable, IntWritable, CentroidWritable> {
+  /**
+   * Number of points buffered before the distance cutoff is estimated.
+   */
+  private static final int NUM_ESTIMATE_POINTS = 1000;
+
+  /**
+   * The clusterer object used to cluster the points received by this mapper online.
+   */
+  private StreamingKMeans clusterer;
+
+  /**
+   * Number of points received so far; also used as the index of the next centroid that is created.
+   */
+  private int numPoints = 0;
+
+  /**
+   * True while the distance cutoff still has to be estimated from buffered points.
+   */
+  private boolean estimateDistanceCutoff = false;
+
+  /**
+   * Points buffered for the cutoff estimate; only allocated when the cutoff must be estimated.
+   */
+  private List<Centroid> estimatePoints;
+
+  /**
+   * Builds the clusterer from the job configuration and decides whether the distance cutoff has to be
+   * estimated from the data.
+   *
+   * @param context task context providing the job configuration
+   */
+  @Override
+  public void setup(Context context) {
+    // At this point the configuration received from the Driver is assumed to be valid.
+    // No other checks are made.
+    Configuration conf = context.getConfiguration();
+    UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+    int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+    double estimatedDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+        StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+    if (estimatedDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+      estimateDistanceCutoff = true;
+      estimatePoints = Lists.newArrayList();
+    }
+    // There is no way of estimating the distance cutoff unless we have some data, so the clusterer
+    // starts with the configured value and is corrected in clusterEstimatePoints() if needed.
+    clusterer = new StreamingKMeans(searcher, numClusters, estimatedDistanceCutoff);
+  }
+
+  /**
+   * Estimates the distance cutoff from the buffered points, clusters those points and switches the
+   * mapper to clustering every further point as it arrives.
+   */
+  private void clusterEstimatePoints() {
+    clusterer.setDistanceCutoff(ClusteringUtils.estimateDistanceCutoff(
+        estimatePoints, clusterer.getDistanceMeasure()));
+    clusterer.cluster(estimatePoints);
+    estimateDistanceCutoff = false;
+  }
+
+  /**
+   * Wraps the incoming vector in a unit-weight centroid and either buffers it for the cutoff estimate or
+   * clusters it immediately.
+   *
+   * @param key     input key; not used
+   * @param point   the vector to cluster
+   * @param context task context; nothing is written here, output happens in {@code cleanup}
+   */
+  @Override
+  public void map(Writable key, VectorWritable point, Context context) {
+    Centroid centroid = new Centroid(numPoints++, point.get(), 1);
+    if (!estimateDistanceCutoff) {
+      clusterer.cluster(centroid);
+      return;
+    }
+    // Always buffer the point before checking the threshold. Previously the point that made the
+    // counter reach NUM_ESTIMATE_POINTS triggered the estimate without being buffered or clustered,
+    // so it was silently lost.
+    estimatePoints.add(centroid);
+    if (estimatePoints.size() >= NUM_ESTIMATE_POINTS) {
+      clusterEstimatePoints();
+    }
+  }
+
+  /**
+   * Clusters any points that are still buffered and writes every centroid held by the clusterer.
+   *
+   * @param context task context used to emit the centroids
+   * @throws IOException          if writing to the context fails
+   * @throws InterruptedException if writing to the context is interrupted
+   */
+  @Override
+  public void cleanup(Context context) throws IOException, InterruptedException {
+    // We should cluster the points at the end if they haven't yet been clustered, i.e. when fewer
+    // than NUM_ESTIMATE_POINTS points were received.
+    if (estimateDistanceCutoff) {
+      clusterEstimatePoints();
+    }
+    // Reindex the centroids before passing them to the reducer.
+    clusterer.reindexCentroids();
+    // All outputs have the same key to go to the same final reducer.
+    for (Centroid centroid : clusterer) {
+      context.write(new IntWritable(0), new CentroidWritable(centroid));
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
new file mode 100644
index 0000000..2b78acc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.streaming.cluster.BallKMeans;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Reducer that gathers the intermediate centroids produced by the mappers, optionally collapses them
+ * with another StreamingKMeans pass, and then runs BallKMeans on them to produce the final centroids.
+ */
+public class StreamingKMeansReducer extends Reducer<IntWritable, CentroidWritable, IntWritable, CentroidWritable> {
+
+  private static final Logger log = LoggerFactory.getLogger(StreamingKMeansReducer.class);
+
+  /**
+   * Configuration for the MapReduce job, captured in {@code setup}.
+   */
+  private Configuration conf;
+
+  /**
+   * Stores the job configuration for later use.
+   *
+   * @param context task context providing the job configuration
+   */
+  @Override
+  public void setup(Context context) {
+    // The configuration handed over by the Driver is assumed to be valid; it is not re-checked here.
+    conf = context.getConfiguration();
+  }
+
+  /**
+   * Turns the incoming centroids into the final set of centroids and writes them out keyed by their
+   * position (0, 1, 2, ...) in the result.
+   *
+   * @param key       the shared key of the intermediate centroids; not used
+   * @param centroids the intermediate centroids
+   * @param context   task context used to emit the final centroids
+   * @throws IOException          if writing to the context fails
+   * @throws InterruptedException if writing to the context is interrupted
+   */
+  @Override
+  public void reduce(IntWritable key, Iterable<CentroidWritable> centroids,
+                     Context context) throws IOException, InterruptedException {
+    boolean collapseFirst = conf.getBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, false);
+
+    List<Centroid> intermediateCentroids;
+    if (!collapseFirst) {
+      intermediateCentroids = centroidWritablesToList(centroids);
+    } else {
+      // There might be too many intermediate centroids to fit into memory, in which case we run
+      // another pass of StreamingKMeans to collapse the clusters further.
+      Iterable<Centroid> copies = Iterables.transform(centroids, copyingExtractor());
+      intermediateCentroids = Lists.newArrayList(new StreamingKMeansThread(copies, conf).call());
+    }
+
+    int nextIndex = 0;
+    for (Vector best : getBestCentroids(intermediateCentroids, conf)) {
+      context.write(new IntWritable(nextIndex++), new CentroidWritable((Centroid) best));
+    }
+  }
+
+  /**
+   * Copies the centroids out of the given writables into a new list.
+   *
+   * @param centroids the writables to read from
+   * @return a new list holding a clone of each wrapped centroid, in iteration order
+   */
+  public static List<Centroid> centroidWritablesToList(Iterable<CentroidWritable> centroids) {
+    // Clones are collected into a fresh list because Hadoop iterators mutate the contents of the
+    // Writable in place, without allocating new references while iterating.
+    return Lists.newArrayList(Iterables.transform(centroids, copyingExtractor()));
+  }
+
+  /**
+   * Runs BallKMeans, parameterised from the configuration, on the given centroids.
+   *
+   * @param centroids the centroids to cluster
+   * @param conf      configuration holding the BallKMeans and searcher settings
+   * @return the result of {@code BallKMeans.cluster} for the given centroids
+   */
+  public static Iterable<Vector> getBestCentroids(List<Centroid> centroids, Configuration conf) {
+
+    if (log.isInfoEnabled()) {
+      log.info("Number of Centroids: {}", centroids.size());
+    }
+
+    // The two boolean options are stored in negated form, hence the inversions below.
+    boolean kMeansPlusPlusInit = !conf.getBoolean(StreamingKMeansDriver.RANDOM_INIT, false);
+    boolean correctWeights = !conf.getBoolean(StreamingKMeansDriver.IGNORE_WEIGHTS, false);
+
+    int numClusters = conf.getInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1);
+    int maxNumIterations = conf.getInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, 10);
+    int numRuns = conf.getInt(StreamingKMeansDriver.NUM_BALLKMEANS_RUNS, 3);
+
+    float trimFraction = conf.getFloat(StreamingKMeansDriver.TRIM_FRACTION, 0.9f);
+    float testProbability = conf.getFloat(StreamingKMeansDriver.TEST_PROBABILITY, 0.1f);
+
+    BallKMeans ballKMeansCluster = new BallKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(conf),
+        numClusters, maxNumIterations, trimFraction, kMeansPlusPlusInit, correctWeights, testProbability, numRuns);
+    return ballKMeansCluster.cluster(centroids);
+  }
+
+  /**
+   * Creates the function that maps a writable to a clone of the centroid it wraps.
+   *
+   * @return a function that rejects null input and returns a clone of the wrapped centroid
+   */
+  private static Function<CentroidWritable, Centroid> copyingExtractor() {
+    return new Function<CentroidWritable, Centroid>() {
+      @Override
+      public Centroid apply(CentroidWritable input) {
+        Preconditions.checkNotNull(input);
+        return input.getCentroid().clone();
+      }
+    };
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
new file mode 100644
index 0000000..acb2b56
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Callable that runs {@link StreamingKMeans} over a sequence of points and returns the clusterer, which
+ * is iterable over the resulting centroids.
+ * <p>
+ * The data points are traversed exactly once, so the supplied {@code Iterable} does not need to support
+ * repeated iteration.
+ */
+public class StreamingKMeansThread implements Callable<Iterable<Centroid>> {
+  private static final Logger log = LoggerFactory.getLogger(StreamingKMeansThread.class);
+
+  /**
+   * Maximum number of points taken from the input to estimate the distance cutoff.
+   */
+  private static final int NUM_ESTIMATE_POINTS = 1000;
+
+  private final Configuration conf;
+  private final Iterable<Centroid> dataPoints;
+
+  /**
+   * Creates a thread that clusters the vectors stored as values in a sequence file.
+   *
+   * @param input path of the sequence file whose values are read as {@code VectorWritable}
+   * @param conf  configuration used to read the file and to parameterise the clustering
+   */
+  public StreamingKMeansThread(Path input, Configuration conf) {
+    this(StreamingKMeansUtilsMR.getCentroidsFromVectorWritable(
+        new SequenceFileValueIterable<VectorWritable>(input, false, conf)), conf);
+  }
+
+  /**
+   * Creates a thread that clusters the given points.
+   *
+   * @param dataPoints the points to cluster
+   * @param conf       configuration used to parameterise the clustering
+   */
+  public StreamingKMeansThread(Iterable<Centroid> dataPoints, Configuration conf) {
+    this.dataPoints = dataPoints;
+    this.conf = conf;
+  }
+
+  /**
+   * Clusters all data points. If the configuration carries no distance cutoff, the cutoff is first
+   * estimated from up to {@code NUM_ESTIMATE_POINTS} leading points.
+   *
+   * @return the clusterer holding the reindexed centroids
+   */
+  @Override
+  public Iterable<Centroid> call() {
+    UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+    int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+    double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+        StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+
+    Iterator<Centroid> dataPointsIterator = dataPoints.iterator();
+
+    // Points pulled off the iterator for the cutoff estimate; stays empty when a cutoff was configured.
+    List<Centroid> estimatePoints;
+    if (estimateDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+      estimatePoints = Lists.newArrayListWithExpectedSize(NUM_ESTIMATE_POINTS);
+      while (dataPointsIterator.hasNext() && estimatePoints.size() < NUM_ESTIMATE_POINTS) {
+        estimatePoints.add(dataPointsIterator.next());
+      }
+
+      if (log.isInfoEnabled()) {
+        log.info("Estimated Points: {}", estimatePoints.size());
+      }
+      estimateDistanceCutoff = ClusteringUtils.estimateDistanceCutoff(estimatePoints, searcher.getDistanceMeasure());
+    } else {
+      estimatePoints = Lists.newArrayList();
+    }
+
+    StreamingKMeans streamingKMeans = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff);
+
+    // The points consumed for the estimate must be clustered too. Previously they were only
+    // revisited when the iterator happened to be exhausted (by asking the Iterable for a second
+    // iterator); with more than NUM_ESTIMATE_POINTS inputs they were silently skipped.
+    for (Centroid buffered : estimatePoints) {
+      streamingKMeans.cluster(buffered);
+    }
+
+    // Continue with whatever the same iterator has left, so the input is only traversed once.
+    while (dataPointsIterator.hasNext()) {
+      streamingKMeans.cluster(dataPointsIterator.next());
+    }
+
+    streamingKMeans.reindexCentroids();
+    return streamingKMeans;
+  }
+
+}