You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/03/01 06:42:36 UTC
svn commit: r917396 [1/3] - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/canopy/
core/src/main/java/org/apache/mahout/clustering/dirichlet/
core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/
core/src/main/java/org/...
Author: robinanil
Date: Mon Mar 1 05:42:35 2010
New Revision: 917396
URL: http://svn.apache.org/viewvc?rev=917396&view=rev
Log:
MAHOUT-307 MAHOUT-295 MAHOUT-304 clustering improvements
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/canopy/DisplayCanopy.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/kmeans/DisplayKMeans.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/meanshift/OutputMapper.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java Mon Mar 1 05:42:35 2010
@@ -18,6 +18,8 @@
package org.apache.mahout.clustering.canopy;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.io.Text;
@@ -174,7 +176,7 @@
isCovered = true;
vw.set(point);
collector.collect(new Text(canopy.getIdentifier()), vw);
- reporter.setStatus("Emit Canopy ID:" +canopy.getIdentifier());
+ reporter.setStatus("Emit Canopy ID:" + canopy.getIdentifier());
} else if (dist < minDist) {
minDist = dist;
closest = canopy;
@@ -199,4 +201,78 @@
public boolean canopyCovers(Canopy canopy, Vector point) {
return measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point) < t1;
}
+
+ /**
+ * Iterate through the points, adding new canopies. Return the canopies.
+ *
+ * @param points
+ * a List<Vector> defining the points to be clustered
+ * @param measure
+ * a DistanceMeasure to use
+ * @param t1
+ * the T1 distance threshold
+ * @param t2
+ * the T2 distance threshold
+ * @return the List<Canopy> created
+ */
+ public static List<Canopy> createCanopies(List<Vector> points, DistanceMeasure measure, double t1, double t2) {
+ List<Canopy> canopies = new ArrayList<Canopy>();
+ /**
+ * Reference Implementation: Given a distance metric, one can create canopies as follows: Start with a
+ * list of the data points in any order, and with two distance thresholds, T1 and T2, where T1 > T2.
+ * (These thresholds can be set by the user, or selected by cross-validation.) Pick a point on the list
+ * and measure its distance to all other points. Put all points that are within distance threshold T1 into
+ * a canopy. Remove from the list all points that are within distance threshold T2. Repeat until the list
+ * is empty.
+ */
+ int nextCanopyId = 0;
+ while (!points.isEmpty()) {
+ Iterator<Vector> ptIter = points.iterator();
+ Vector p1 = ptIter.next();
+ ptIter.remove();
+ Canopy canopy = new Canopy(p1, nextCanopyId++);
+ canopies.add(canopy);
+ while (ptIter.hasNext()) {
+ Vector p2 = ptIter.next();
+ double dist = measure.distance(p1, p2);
+ // Put all points that are within distance threshold T1 into the canopy
+ if (dist < t1) {
+ canopy.addPoint(p2);
+ }
+ // Remove from the list all points that are within distance threshold T2
+ if (dist < t2) {
+ ptIter.remove();
+ }
+ }
+ }
+ return canopies;
+ }
+
+ /**
+ * Iterate through the canopies, adding their centroids to a list
+ *
+ * @param canopies
+ * a List<Canopy>
+ * @return the List<Vector>
+ */
+ public static List<Vector> calculateCentroids(List<Canopy> canopies) {
+ List<Vector> result = new ArrayList<Vector>();
+ for (Canopy canopy : canopies) {
+ result.add(canopy.computeCentroid());
+ }
+ return result;
+ }
+
+ /**
+ * Iterate through the canopies, resetting their center to their centroids
+ *
+ * @param canopies
+ * a List<Canopy>
+ */
+ public static void updateCentroids(List<Canopy> canopies) {
+ for (Canopy canopy : canopies) {
+ canopy.setCenter(canopy.computeCentroid());
+ }
+ }
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java Mon Mar 1 05:42:35 2010
@@ -125,7 +125,7 @@
this.thin = thin;
this.burnin = burnin;
this.numClusters = numClusters;
- state = new DirichletState<O>(modelFactory, numClusters, alpha_0, thin, burnin);
+ state = new DirichletState<O>(modelFactory, numClusters, alpha_0);
}
/**
@@ -198,4 +198,34 @@
return pi;
}
+ /**
+ * Create a new instance on the sample data with the given additional parameters
+ *
+ * @param points
+ * the observed data to be clustered
+ * @param modelFactory
+ * the ModelDistribution to use
+ * @param alpha_0
+ * the double value for the beta distributions
+ * @param numClusters
+ * the int number of clusters
+ * @param thin
+ * the int thinning interval, used to report every n iterations
+ * @param burnin
+ * the int burnin interval, used to suppress early iterations
+ * @param numIterations
+ * number of iterations to be performed
+ */
+ public static List<Model<Vector>[]> clusterPoints(List<Vector> points,
+ ModelDistribution<Vector> modelFactory,
+ double alpha_0,
+ int numClusters,
+ int thin,
+ int burnin,
+ int numIterations) {
+ DirichletClusterer<Vector> clusterer = new DirichletClusterer<Vector>(points, modelFactory, alpha_0,
+ numClusters, thin, burnin);
+ return clusterer.cluster(numIterations);
+
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Mon Mar 1 05:42:35 2010
@@ -65,7 +65,7 @@
private static final Logger log = LoggerFactory.getLogger(DirichletDriver.class);
- private DirichletDriver() { }
+ private DirichletDriver() {}
public static void main(String[] args) throws Exception {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
@@ -84,15 +84,18 @@
Option modelOpt = obuilder.withLongName("modelClass").withRequired(true).withShortName("d").withArgument(
abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create()).withDescription(
- "The ModelDistribution class name.").create();
+ "The ModelDistribution class name. "
+ + "Defaults to org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution").create();
- Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(true).withShortName("p")
+ Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(false).withShortName("p")
.withArgument(abuilder.withName("prototypeClass").withMinimum(1).withMaximum(1).create())
- .withDescription("The ModelDistribution prototype Vector class name.").create();
+ .withDescription(
+ "The ModelDistribution prototype Vector class name. "
+ + "Defaults to org.apache.mahout.math.RandomAccessSparseVector").create();
Option sizeOpt = obuilder.withLongName("prototypeSize").withRequired(true).withShortName("s")
.withArgument(abuilder.withName("prototypeSize").withMinimum(1).withMaximum(1).create())
- .withDescription("The ModelDistribution prototype Vector size.").create();
+ .withDescription("The ModelDistribution prototype Vector size. ").create();
Option numRedOpt = obuilder.withLongName("maxRed").withRequired(true).withShortName("r").withArgument(
abuilder.withName("maxRed").withMinimum(1).withMaximum(1).create()).withDescription(
@@ -113,15 +116,17 @@
String input = cmdLine.getValue(inputOpt).toString();
String output = cmdLine.getValue(outputOpt).toString();
- String modelFactory = cmdLine.getValue(modelOpt).toString();
- String modelPrototype = cmdLine.getValue(prototypeOpt).toString();
+ String modelFactory = "org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution";
+ if (cmdLine.hasOption(modelOpt)) modelFactory = cmdLine.getValue(modelOpt).toString();
+ String modelPrototype = "org.apache.mahout.math.RandomAccessSparseVector";
+ if (cmdLine.hasOption(prototypeOpt)) modelPrototype = cmdLine.getValue(prototypeOpt).toString();
int prototypeSize = Integer.parseInt(cmdLine.getValue(sizeOpt).toString());
int numReducers = Integer.parseInt(cmdLine.getValue(numRedOpt).toString());
int numModels = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
double alpha_0 = Double.parseDouble(cmdLine.getValue(mOpt).toString());
- runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels,
- maxIterations, alpha_0, numReducers);
+ runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels, maxIterations, alpha_0,
+ numReducers);
} catch (OptionException e) {
log.error("Exception parsing command line: ", e);
CommandLineUtil.printHelp(group);
@@ -161,8 +166,8 @@
SecurityException,
NoSuchMethodException,
InvocationTargetException {
- runJob(input, output, modelFactory, "org.apache.mahout.math.DenseVector", 2, numClusters,
- maxIterations, alpha_0, numReducers);
+ runJob(input, output, modelFactory, "org.apache.mahout.math.DenseVector", 2, numClusters, maxIterations,
+ alpha_0, numReducers);
}
/**
@@ -200,15 +205,14 @@
InvocationTargetException {
String stateIn = output + "/state-0";
- writeInitialState(output, stateIn, modelFactory, modelPrototype, prototypeSize,
- numClusters, alpha_0);
+ writeInitialState(output, stateIn, modelFactory, modelPrototype, prototypeSize, numClusters, alpha_0);
for (int iteration = 0; iteration < maxIterations; iteration++) {
log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
String stateOut = output + "/state-" + (iteration + 1);
- runIteration(input, stateIn, stateOut, modelFactory, modelPrototype, prototypeSize,
- numClusters, alpha_0, numReducers);
+ runIteration(input, stateIn, stateOut, modelFactory, modelPrototype, prototypeSize, numClusters,
+ alpha_0, numReducers);
// now point the input to the old output directory
stateIn = stateOut;
}
@@ -228,8 +232,8 @@
NoSuchMethodException,
InvocationTargetException {
- DirichletState<VectorWritable> state = createState(modelFactory, modelPrototype,
- prototypeSize, numModels, alpha_0);
+ DirichletState<VectorWritable> state = createState(modelFactory, modelPrototype, prototypeSize,
+ numModels, alpha_0);
JobConf job = new JobConf(KMeansDriver.class);
Path outPath = new Path(output);
FileSystem fs = FileSystem.get(outPath.toUri(), job);
@@ -278,7 +282,7 @@
Class<? extends Vector> vcl = ccl.loadClass(modelPrototype).asSubclass(Vector.class);
Constructor<? extends Vector> v = vcl.getConstructor(int.class);
factory.setModelPrototype(new VectorWritable(v.newInstance(prototypeSize)));
- return new DirichletState<VectorWritable>(factory, numModels, alpha_0, 1, 1);
+ return new DirichletState<VectorWritable>(factory, numModels, alpha_0);
}
/**
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java Mon Mar 1 05:42:35 2010
@@ -41,7 +41,7 @@
private static final Logger log = LoggerFactory.getLogger(DirichletJob.class);
- private DirichletJob() { }
+ private DirichletJob() {}
public static void main(String[] args) throws Exception {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
@@ -58,13 +58,17 @@
abuilder.withName("alpha").withMinimum(1).withMaximum(1).create()).withDescription(
"The alpha0 value for the DirichletDistribution.").create();
- Option modelOpt = obuilder.withLongName("modelClass").withRequired(true).withShortName("d").withArgument(
- abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create()).withDescription(
- "The ModelDistribution class name.").create();
+ Option modelOpt = obuilder.withLongName("modelClass").withRequired(false).withShortName("d")
+ .withArgument(abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "The ModelDistribution class name."
+ + "Defaults to org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution").create();
- Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(true).withShortName("p")
+ Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(false).withShortName("p")
.withArgument(abuilder.withName("prototypeClass").withMinimum(1).withMaximum(1).create())
- .withDescription("The ModelDistribution prototype Vector class name.").create();
+ .withDescription(
+ "The ModelDistribution prototype Vector class name."
+ + "Defaults to org.apache.mahout.math.RandomAccessSparseVector").create();
Option sizeOpt = obuilder.withLongName("prototypeSize").withRequired(true).withShortName("s")
.withArgument(abuilder.withName("prototypeSize").withMinimum(1).withMaximum(1).create())
@@ -85,14 +89,15 @@
String input = cmdLine.getValue(inputOpt).toString();
String output = cmdLine.getValue(outputOpt).toString();
- String modelFactory = cmdLine.getValue(modelOpt).toString();
- String modelPrototype = cmdLine.getValue(prototypeOpt).toString();
+ String modelFactory = "org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution";
+ if (cmdLine.hasOption(modelOpt)) modelFactory = cmdLine.getValue(modelOpt).toString();
+ String modelPrototype = "org.apache.mahout.math.RandomAccessSparseVector";
+ if (cmdLine.hasOption(prototypeOpt)) modelPrototype = cmdLine.getValue(prototypeOpt).toString();
int prototypeSize = Integer.parseInt(cmdLine.getValue(sizeOpt).toString());
int numModels = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
double alpha_0 = Double.parseDouble(cmdLine.getValue(mOpt).toString());
- runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels,
- maxIterations, alpha_0);
+ runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels, maxIterations, alpha_0);
} catch (OptionException e) {
log.error("Exception parsing command line: ", e);
CommandLineUtil.printHelp(group);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java Mon Mar 1 05:42:35 2010
@@ -39,9 +39,7 @@
public DirichletState(ModelDistribution<O> modelFactory,
int numClusters,
- double alpha_0,
- int thin,
- int burnin) {
+ double alpha_0) {
this.numClusters = numClusters;
this.modelFactory = modelFactory;
this.alpha_0 = alpha_0;
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java Mon Mar 1 05:42:35 2010
@@ -29,19 +29,11 @@
public class FuzzyKMeansClusterer {
- private static final double MINIMAL_VALUE = 0.0000000001; // using it for
- // adding
- // exception
- // this value to any
- // zero valued
- // variable to avoid
- // divide by Zero
-
- // private int nextClusterId = 0;
+ private static final double MINIMAL_VALUE = 0.0000000001;
private DistanceMeasure measure;
- private double convergenceDelta = 0;
+ private double convergenceDelta;
private double m = 2.0; // default value
@@ -65,20 +57,6 @@
}
/**
- * Configure the distance measure directly. Used by unit tests.
- *
- * @param aMeasure
- * the DistanceMeasure
- * @param aConvergenceDelta
- * the delta value used to define convergence
- */
- private void config(DistanceMeasure aMeasure, double aConvergenceDelta) {
- measure = aMeasure;
- convergenceDelta = aConvergenceDelta;
- // nextClusterId = 0;
- }
-
- /**
* Configure the distance measure from the job
*
* @param job
@@ -123,10 +101,8 @@
for (int i = 0; i < clusters.size(); i++) {
double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
- Text key = new Text(clusters.get(i).getIdentifier()); // just output the
- // identifier,avoids
- // too much data
- // traffic
+ Text key = new Text(clusters.get(i).getIdentifier());
+ // just output the identifier, which avoids too much data traffic
/*
* Text value = new Text(Double.toString(probWeight) + FuzzyKMeansDriver.MAPPER_VALUE_SEPARATOR +
* values.toString());
@@ -154,7 +130,7 @@
List<Double> clusterDistanceList = new ArrayList<Double>();
for (SoftCluster cluster : clusters) {
- clusterDistanceList.add(measure.distance(point, cluster.getCenter()));
+ clusterDistanceList.add(measure.distance(cluster.getCenter(), point));
}
FuzzyKMeansOutput fOutput = new FuzzyKMeansOutput(clusters.size());
for (int i = 0; i < clusters.size(); i++) {
@@ -175,9 +151,7 @@
if (eachCDist == 0.0) {
eachCDist = MINIMAL_VALUE;
}
-
denom += Math.pow(clusterDistance / eachCDist, 2.0 / (m - 1));
-
}
return 1.0 / denom;
}
@@ -200,4 +174,81 @@
public DistanceMeasure getMeasure() {
return this.measure;
}
+
+ /**
+ * This is the reference fuzzy k-means implementation. Given its inputs it iterates over the points and clusters
+ * until their centers converge or until the maximum number of iterations is exceeded.
+ *
+ * @param points
+ * the input List<Vector> of points
+ * @param clusters
+ * the initial List<SoftCluster> of clusters
+ * @param measure
+ * the DistanceMeasure to use
+ * @param numIter
+ * the maximum number of iterations
+ */
+ public static List<List<SoftCluster>> clusterPoints(List<Vector> points,
+ List<SoftCluster> clusters,
+ DistanceMeasure measure,
+ double threshold,
+ double m,
+ int numIter) {
+ List<List<SoftCluster>> clustersList = new ArrayList<List<SoftCluster>>();
+ clustersList.add(clusters);
+ FuzzyKMeansClusterer clusterer = new FuzzyKMeansClusterer(measure, threshold, m);
+ boolean converged = false;
+ int iteration = 0;
+ for (int iter = 0; !converged && iter < numIter; iter++) {
+ List<SoftCluster> next = new ArrayList<SoftCluster>();
+ List<SoftCluster> cs = clustersList.get(iteration++);
+ for (SoftCluster c : cs) {
+ next.add(new SoftCluster(c.getCenter(), c.getId()));
+ }
+ clustersList.add(next);
+ converged = runFuzzyKMeansIteration(points, clustersList.get(iteration), clusterer);
+ }
+ return clustersList;
+ }
+
+ /**
+ * Perform a single iteration over the points and clusters, assigning points to clusters and returning
+ * whether every cluster has converged.
+ *
+ * @param points
+ * the List<Vector> having the input points
+ * @param clusterList
+ * the List<SoftCluster> clusters
+ * @return true if every cluster has converged, false otherwise
+ */
+ public static boolean runFuzzyKMeansIteration(List<Vector> points,
+ List<SoftCluster> clusterList,
+ FuzzyKMeansClusterer clusterer) {
+ // for each
+ for (Vector point : points) {
+ List<Double> clusterDistanceList = new ArrayList<Double>();
+ for (SoftCluster cluster : clusterList) {
+ clusterDistanceList.add(clusterer.getMeasure().distance(point, cluster.getCenter()));
+ }
+
+ for (int i = 0; i < clusterList.size(); i++) {
+ double probWeight = clusterer.computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
+ clusterList.get(i).addPoint(point, Math.pow(probWeight, clusterer.getM()));
+ }
+ }
+ boolean converged = true;
+ for (SoftCluster cluster : clusterList) {
+ if (!clusterer.computeConvergence(cluster)) {
+ converged = false;
+ }
+ }
+ // update the cluster centers
+ if (!converged) {
+ for (SoftCluster cluster : clusterList) {
+ cluster.recomputeCenter();
+ }
+ }
+ return converged;
+
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java Mon Mar 1 05:42:35 2010
@@ -21,35 +21,30 @@
import java.io.DataOutput;
import java.io.IOException;
-import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.math.AbstractVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.SquareRootFunction;
-public class SoftCluster implements Writable {
-
- // this cluster's clusterId
- private int clusterId;
-
- // the current center
- private Vector center = new RandomAccessSparseVector(0);
-
+public class SoftCluster extends ClusterBase{
+
// the current centroid is lazy evaluated and may be null
- private Vector centroid = null;
+ private Vector centroid;
// The Probability of belongingness sum
- private double pointProbSum = 0.0;
+ private double pointProbSum;
// the total of all points added to the cluster
- private Vector weightedPointTotal = null;
+ private Vector weightedPointTotal;
// has the centroid converged with the center?
- private boolean converged = false;
+ private boolean converged;
// track membership parameters
- private double s0 = 0;
+ private double s0;
private Vector s1;
@@ -87,10 +82,25 @@
}
return null;
}
+
+ // For Writable
+ public SoftCluster() { }
+
+ /**
+ * Construct a new SoftCluster with the given point as its center
+ *
+ * @param center
+ * the center point
+ */
+ public SoftCluster(Vector center) {
+ setCenter(new RandomAccessSparseVector(center));
+ this.pointProbSum = 0;
+ this.weightedPointTotal = getCenter().like();
+ }
@Override
public void write(DataOutput out) throws IOException {
- out.writeInt(clusterId);
+ out.writeInt(this.getId());
out.writeBoolean(converged);
Vector vector = computeCentroid();
VectorWritable.writeVector(out, vector);
@@ -98,13 +108,13 @@
@Override
public void readFields(DataInput in) throws IOException {
- clusterId = in.readInt();
+ this.setId(in.readInt());
converged = in.readBoolean();
VectorWritable temp = new VectorWritable();
temp.readFields(in);
- center = temp.get();
+ this.setCenter(new RandomAccessSparseVector(temp.get()));
this.pointProbSum = 0;
- this.weightedPointTotal = center.like();
+ this.weightedPointTotal = getCenter().like();
}
/**
@@ -112,6 +122,7 @@
*
* @return the new centroid
*/
+ @Override
public Vector computeCentroid() {
if (pointProbSum == 0) {
return weightedPointTotal;
@@ -122,22 +133,6 @@
return centroid;
}
- // For Writable
- public SoftCluster() { }
-
- /**
- * Construct a new SoftCluster with the given point as its center
- *
- * @param center
- * the center point
- */
- public SoftCluster(Vector center) {
- this.center = center;
- this.pointProbSum = 0;
-
- this.weightedPointTotal = center.like();
- }
-
/**
* Construct a new SoftCluster with the given point as its center
*
@@ -145,8 +140,8 @@
* the center point
*/
public SoftCluster(Vector center, int clusterId) {
- this.clusterId = clusterId;
- this.center = center;
+ this.setId(clusterId);
+ this.setCenter(center);
this.pointProbSum = 0;
this.weightedPointTotal = center.like();
}
@@ -154,7 +149,7 @@
/** Construct a new softcluster with the given clusterID */
public SoftCluster(String clusterId) {
- this.clusterId = Integer.parseInt(clusterId.substring(1));
+ this.setId(Integer.parseInt(clusterId.substring(1)));
this.pointProbSum = 0;
// this.weightedPointTotal = center.like();
this.converged = clusterId.charAt(0) == 'V';
@@ -162,14 +157,15 @@
@Override
public String toString() {
- return getIdentifier() + " - " + center.asFormatString();
+ return getIdentifier() + " - " + getCenter().asFormatString();
}
+ @Override
public String getIdentifier() {
if (converged) {
- return "V" + clusterId;
+ return "V" + this.getId();
} else {
- return "C" + clusterId;
+ return "C" + this.getId();
}
}
@@ -212,9 +208,9 @@
centroid = null;
pointProbSum += ptProb;
if (weightedPointTotal == null) {
- weightedPointTotal = point.clone().times(ptProb);
+ weightedPointTotal = point.clone().assign(Functions.mult, ptProb);
} else {
- weightedPointTotal = weightedPointTotal.plus(point.times(ptProb));
+ point.clone().assign(Functions.mult, ptProb).addTo(weightedPointTotal);
}
}
@@ -230,33 +226,25 @@
if (weightedPointTotal == null) {
weightedPointTotal = delta.clone();
} else {
- weightedPointTotal = weightedPointTotal.plus(delta);
+ delta.addTo(weightedPointTotal);
}
}
- public Vector getCenter() {
- return center;
- }
-
public double getPointProbSum() {
return pointProbSum;
}
/** Compute the centroid and set the center to it. */
public void recomputeCenter() {
- center = computeCentroid();
+ this.setCenter(computeCentroid());
pointProbSum = 0;
- weightedPointTotal = center.like();
+ weightedPointTotal = getCenter().like();
}
public Vector getWeightedPointTotal() {
return weightedPointTotal;
}
-
- public void setWeightedPointTotal(Vector v) {
- this.weightedPointTotal = v;
- }
-
+
public boolean isConverged() {
return converged;
}
@@ -265,8 +253,9 @@
this.converged = converged;
}
- public int getClusterId() {
- return clusterId;
+ @Override
+ public String asFormatString() {
+ return formatCluster(this);
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Mon Mar 1 05:42:35 2010
@@ -43,6 +43,8 @@
/** Has the centroid converged with the center? */
private boolean converged;
+ private double std = 0.00000001;
+
/**
* Format the cluster for output
*
@@ -134,8 +136,8 @@
super();
this.setCenter(new RandomAccessSparseVector(center));
this.setNumPoints(0);
- this.setPointTotal(center.like());
- this.pointSquaredTotal = center.like();
+ this.setPointTotal(getCenter().like());
+ this.pointSquaredTotal = getCenter().like();
}
/** For (de)serialization as a Writable */
@@ -152,7 +154,7 @@
this.setId(clusterId);
this.setCenter(new RandomAccessSparseVector(center));
this.setNumPoints(0);
- this.setPointTotal(center.like());
+ this.setPointTotal(getCenter().like());
this.pointSquaredTotal = getCenter().like();
}
@@ -194,21 +196,24 @@
*/
public void addPoints(int count, Vector delta) {
centroid = null;
- setNumPoints(getNumPoints() + count);
- if (getPointTotal() == null) {
- setPointTotal(delta.clone());
+ if (getNumPoints() == 0) {
+ setPointTotal(new RandomAccessSparseVector(delta.clone()));
pointSquaredTotal = new RandomAccessSparseVector(delta.clone().assign(Functions.square));
} else {
delta.addTo(getPointTotal());
delta.clone().assign(Functions.square).addTo(pointSquaredTotal);
}
+ setNumPoints(getNumPoints() + count);
}
/** Compute the centroid and set the center to it. */
public void recomputeCenter() {
+ std = getStd();
setCenter(computeCentroid());
+ centroid = null;
setNumPoints(0);
- setPointTotal(getCenter().like());
+ this.setPointTotal(getCenter().like());
+ this.pointSquaredTotal = getCenter().like();
}
/**
@@ -236,6 +241,7 @@
/** @return the std */
public double getStd() {
+ if (getNumPoints() == 0) return std;
Vector stds = pointSquaredTotal.times(getNumPoints()).minus(getPointTotal().times(getPointTotal()))
.assign(new SquareRootFunction()).divide(getNumPoints());
return stds.zSum() / stds.size();
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansClusterer.java Mon Mar 1 05:42:35 2010
@@ -17,6 +17,7 @@
package org.apache.mahout.clustering.kmeans;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.io.Text;
@@ -95,4 +96,86 @@
String key = (name != null) && (name.length() != 0) ? name : point.asFormatString();
output.collect(new Text(key), new Text(String.valueOf(nearestCluster.getId())));
}
+
+ /**
+ * This is the reference k-means implementation. Given its inputs it iterates over the points and clusters
+ * until their centers converge or until the maximum number of iterations is exceeded.
+ *
+ * @param points
+ * the input List<Vector> of points
+ * @param clusters
+ * the List<Cluster> of initial clusters
+ * @param measure
+ * the DistanceMeasure to use
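+ * @param maxIter
+ * the maximum number of iterations
+ * @param distanceThreshold
+ * the convergence threshold passed to each k-means iteration
+ * @return a List containing the initial clusters followed by the List<Cluster> built for each iteration
+ */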
+ public static List<List<Cluster>> clusterPoints(List<Vector> points,
+ List<Cluster> clusters,
+ DistanceMeasure measure,
+ int maxIter,
+ double distanceThreshold) {
+ List<List<Cluster>> clustersList = new ArrayList<List<Cluster>>();
+ clustersList.add(clusters);
+
+ boolean converged = false;
+ int iteration = 0;
+ while (!converged && iteration < maxIter) {
+ System.out.println("iteration: " + iteration);
+ List<Cluster> next = new ArrayList<Cluster>();
+ List<Cluster> cs = clustersList.get(iteration++);
+ for (Cluster c : cs) {
+ next.add(new Cluster(c.getCenter()));
+ }
+ clustersList.add(next);
+ converged = runKMeansIteration(points, next, measure, distanceThreshold);
+ }
+ return clustersList;
+ }
+
+ /**
+ * Perform a single iteration over the points and clusters, assigning points to clusters and returning if
+ * the iterations are completed.
+ *
+ * @param points
+ * the List<Vector> having the input points
+ * @param clusters
+ * the List<Cluster> clusters
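+ * @param measure
+ * a DistanceMeasure to use
+ * @param distanceThreshold
+ * the threshold passed to each cluster's convergence test
+ * @return true if every cluster reported convergence, in which case the centers are not recomputed
+ */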
+ public static boolean runKMeansIteration(List<Vector> points,
+ List<Cluster> clusters,
+ DistanceMeasure measure,
+ double distanceThreshold) {
+ // iterate through all points, assigning each to the nearest cluster
+ for (Vector point : points) {
+ Cluster closestCluster = null;
+ double closestDistance = Double.MAX_VALUE;
+ for (Cluster cluster : clusters) {
+ double distance = measure.distance(cluster.getCenter(), point);
+ if (closestCluster == null || closestDistance > distance) {
+ closestCluster = cluster;
+ closestDistance = distance;
+ }
+ }
+ closestCluster.addPoint(point);
+ }
+ // test for convergence
+ boolean converged = true;
+ for (Cluster cluster : clusters) {
+ if (!cluster.computeConvergence(measure, distanceThreshold)) {
+ converged = false;
+ }
+ }
+ // update the cluster centers
+ if (!converged) {
+ for (Cluster cluster : clusters) {
+ cluster.recomputeCenter();
+ }
+ }
+ return converged;
+ }
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Mon Mar 1 05:42:35 2010
@@ -30,6 +30,7 @@
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;
import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -46,7 +47,7 @@
public static final String K = "k";
- private RandomSeedGenerator() { }
+ private RandomSeedGenerator() {}
public static Path buildRandom(String input, String output, int k) throws IOException,
IllegalAccessException,
@@ -117,4 +118,20 @@
return outFile;
}
+
+ public static List<Vector> chooseRandomPoints(List<Vector> vectors, int k) {
+ List<Vector> chosenPoints = new ArrayList<Vector>(k);
+ Random random = RandomUtils.getRandom();
+ for(Vector value : vectors){
+ int currentSize = chosenPoints.size();
+ if (currentSize < k) {
+ chosenPoints.add(value);
+ } else if (random.nextInt(currentSize + 1) == 0) { // with chance 1/(currentSize+1) pick new element
+ int indexToRemove = random.nextInt(currentSize); // evict one chosen randomly
+ chosenPoints.remove(indexToRemove);
+ chosenPoints.add(value);
+ }
+ }
+ return chosenPoints;
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Mon Mar 1 05:42:35 2010
@@ -21,16 +21,15 @@
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.Type;
-import java.util.ArrayList;
-import java.util.List;
import org.apache.mahout.clustering.ClusterBase;
import org.apache.mahout.math.CardinalityException;
-import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.list.IntArrayList;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
@@ -44,7 +43,7 @@
public class MeanShiftCanopy extends ClusterBase {
// TODO: this is problematic, but how else to encode membership?
- private List<Vector> boundPoints = new ArrayList<Vector>();
+ private IntArrayList boundPoints = new IntArrayList();
private boolean converged = false;
@@ -78,9 +77,9 @@
public MeanShiftCanopy(Vector point, int id) {
this.setId(id);
this.setCenter(point);
- this.setPointTotal(point.clone());
+ this.setPointTotal(new RandomAccessSparseVector(point.clone()));
this.setNumPoints(1);
- this.boundPoints.add(point);
+ this.boundPoints.add(id);
}
/**
@@ -91,14 +90,14 @@
* @param id
* an int identifying the canopy local to this process only
* @param boundPoints
- * a List<Vector> containing points bound to the canopy
+ * an IntArrayList containing the ids of the points bound to the canopy
* @param converged
* true if the canopy has converged
*/
- MeanShiftCanopy(Vector point, int id, List<Vector> boundPoints, boolean converged) {
+ MeanShiftCanopy(Vector point, int id, IntArrayList boundPoints, boolean converged) {
this.setId(id);
this.setCenter(point);
- this.setPointTotal(point.clone());
+ this.setPointTotal(new RandomAccessSparseVector(point));
this.setNumPoints(1);
this.boundPoints = boundPoints;
this.converged = converged;
@@ -117,22 +116,14 @@
void addPoints(Vector point, int nPoints) {
setNumPoints(getNumPoints() + nPoints);
Vector subTotal = nPoints == 1 ? point.clone() : point.times(nPoints);
- setPointTotal(getPointTotal() == null ? subTotal : getPointTotal().plus(subTotal));
- }
-
- /**
- * Compute the bound centroid by averaging the bound points
- *
- * @return a Vector which is the new bound centroid
- */
- public Vector computeBoundCentroid() {
- Vector result = new DenseVector(getCenter().size());
- for (Vector v : boundPoints) {
- result.assign(v, Functions.plus);
+ if (getPointTotal() == null) {
+ setPointTotal(new RandomAccessSparseVector(subTotal));
+ } else {
+ subTotal.addTo(getPointTotal());
}
- return result.divide(boundPoints.size());
}
+
/**
* Compute the centroid by normalizing the pointTotal
*
@@ -147,7 +138,7 @@
}
}
- public List<Vector> getBoundPoints() {
+ public IntArrayList getBoundPoints() {
return boundPoints;
}
@@ -164,7 +155,7 @@
setId(canopy.getId());
setCenter(canopy.getCenter());
addPoints(getCenter(), 1);
- boundPoints.addAll(canopy.getBoundPoints());
+ boundPoints.addAllOf(canopy.getBoundPoints());
}
public boolean isConverged() {
@@ -178,7 +169,7 @@
* an existing MeanShiftCanopy
*/
void merge(MeanShiftCanopy canopy) {
- boundPoints.addAll(canopy.boundPoints);
+ boundPoints.addAllOf(canopy.boundPoints);
}
@Override
@@ -204,20 +195,19 @@
temp.readFields(in);
this.setCenter(temp.get());
int numpoints = in.readInt();
- this.boundPoints = new ArrayList<Vector>();
+ this.boundPoints = new IntArrayList();
for (int i = 0; i < numpoints; i++) {
- temp.readFields(in);
- this.boundPoints.add(temp.get());
+ this.boundPoints.add(in.readInt());
}
}
@Override
public void write(DataOutput out) throws IOException {
super.write(out);
- VectorWritable.writeVector(out, computeCentroid());
+ VectorWritable.writeVector(out, new SequentialAccessSparseVector(computeCentroid()));
out.writeInt(boundPoints.size());
- for (Vector v : boundPoints) {
- VectorWritable.writeVector(out, v);
+ for (int v : boundPoints.elements()) {
+ out.writeInt(v);
}
}
@@ -236,7 +226,7 @@
return formatCanopy(this);
}
- public void setBoundPoints(List<Vector> boundPoints) {
+ public void setBoundPoints(IntArrayList boundPoints) {
this.boundPoints = boundPoints;
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyClusterer.java Mon Mar 1 05:42:35 2010
@@ -1,6 +1,7 @@
package org.apache.mahout.clustering.meanshift;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.List;
import org.apache.hadoop.io.Text;
@@ -15,7 +16,7 @@
private double convergenceDelta = 0;
// the next canopyId to be allocated
- //private int nextCanopyId = 0;
+ // private int nextCanopyId = 0;
// the T1 distance threshold
private double t1;
// the T2 distance threshold
@@ -57,7 +58,7 @@
} catch (InstantiationException e) {
throw new IllegalStateException(e);
}
- //nextCanopyId = 0; // never read?
+ // nextCanopyId = 0; // never read?
t1 = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.T1_KEY));
t2 = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.T2_KEY));
convergenceDelta = Double.parseDouble(job.get(MeanShiftCanopyConfigKeys.CLUSTER_CONVERGENCE_KEY));
@@ -70,7 +71,7 @@
* the convergence criteria
*/
public void config(DistanceMeasure aMeasure, double aT1, double aT2, double aDelta) {
- //nextCanopyId = 100; // so canopyIds will sort properly // never read?
+ // nextCanopyId = 100; // so canopyIds will sort properly // never read?
measure = aMeasure;
t1 = aT1;
t2 = aT2;
@@ -111,8 +112,7 @@
}
/** Emit the new canopy to the collector, keyed by the canopy's Id */
- static void emitCanopy(MeanShiftCanopy canopy,
- OutputCollector<Text,WritableComparable<?>> collector) throws IOException {
+ static void emitCanopy(MeanShiftCanopy canopy, OutputCollector<Text,WritableComparable<?>> collector) throws IOException {
String identifier = canopy.getIdentifier();
collector.collect(new Text(identifier), new Text("new " + canopy.toString()));
}
@@ -126,8 +126,7 @@
*/
public boolean shiftToMean(MeanShiftCanopy canopy) {
Vector centroid = canopy.computeCentroid();
- canopy
- .setConverged(new EuclideanDistanceMeasure().distance(centroid, canopy.getCenter()) < convergenceDelta);
+ canopy.setConverged(measure.distance(centroid, canopy.getCenter()) < convergenceDelta);
canopy.setCenter(centroid);
canopy.setNumPoints(1);
canopy.setPointTotal(centroid.clone());
@@ -159,4 +158,82 @@
public boolean closelyBound(MeanShiftCanopy canopy, Vector point) {
return measure.distance(canopy.getCenter(), point) < t2;
}
+
+ /**
+ * Story: User can exercise the reference implementation to verify that the test datapoints are clustered in
+ * a reasonable manner.
+ */
+ public void testReferenceImplementation() {
+ MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(new EuclideanDistanceMeasure(), 4.0,
+ 1.0, 0.5);
+ List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
+ // add all points to the canopies
+
+ boolean done = false;
+ int iter = 1;
+ while (!done) {// shift canopies to their centroids
+ done = true;
+ List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
+ for (MeanShiftCanopy canopy : canopies) {
+ done = clusterer.shiftToMean(canopy) && done;
+ clusterer.mergeCanopy(canopy, migratedCanopies);
+ }
+ canopies = migratedCanopies;
+ System.out.println(iter++);
+ }
+ }
+
+ /**
+ * This is the reference mean-shift implementation. Given its inputs it iterates over the points and
+ * clusters until their centers converge or until the maximum number of iterations is exceeded.
+ *
+ * @param points
+ * the input List<Vector> of points
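+ * @param measure
+ * the DistanceMeasure to use
+ * @param convergenceThreshold
+ * the convergence criteria handed to the clusterer
+ * @param t1
+ * the T1 distance threshold
+ * @param t2
+ * the T2 distance threshold
+ * @param numIter
+ * the maximum number of iterations
+ * @return the List<MeanShiftCanopy> built by merging the input points
+ */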
+ public static List<MeanShiftCanopy> clusterPoints(List<Vector> points,
+ DistanceMeasure measure,
+ double convergenceThreshold,
+ double t1,
+ double t2,
+ int numIter) {
+ MeanShiftCanopyClusterer clusterer = new MeanShiftCanopyClusterer(measure, t1, t2, convergenceThreshold);
+
+ List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>();
+ int nextCanopyId = 0;
+ for (Vector point : points) {
+ clusterer.mergeCanopy(new MeanShiftCanopy(point, nextCanopyId++), canopies);
+ }
+
+ boolean converged = false;
+ for (int iter = 0; !converged && iter < numIter; iter++) {
+ converged = runMeanShiftCanopyIteration(canopies, clusterer);
+ }
+ return canopies;
+ }
+
+ /**
+ * Perform a single iteration over the points and clusters, assigning points to clusters and returning if
+ * the iterations are completed.
+ *
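+ * @param canopies
+ * the List<MeanShiftCanopy> clusters
+ * @param clusterer
+ * the MeanShiftCanopyClusterer used to shift and merge each canopy
+ * @return true if shiftToMean returned true for every canopy
+ */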
+ public static boolean runMeanShiftCanopyIteration(List<MeanShiftCanopy> canopies,
+ MeanShiftCanopyClusterer clusterer) {
+ boolean converged = true;
+ List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>();
+ for (MeanShiftCanopy canopy : canopies) {
+ converged = clusterer.shiftToMean(canopy) && converged;
+ clusterer.mergeCanopy(canopy, migratedCanopies);
+ }
+ canopies = migratedCanopies;
+ return converged;
+
+ }
+
}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java?rev=917396&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyCreatorMapper.java Mon Mar 1 05:42:35 2010
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.meanshift;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.math.VectorWritable;
+
+public class MeanShiftCanopyCreatorMapper extends MapReduceBase implements
+ Mapper<WritableComparable<?>,VectorWritable,Text,MeanShiftCanopy> {
+
+ private static int nextCanopyId = -1;
+
+ @Override
+ public void map(WritableComparable<?> key,
+ VectorWritable vector,
+ OutputCollector<Text,MeanShiftCanopy> output,
+ Reporter reporter) throws IOException {
+ MeanShiftCanopy canopy = new MeanShiftCanopy(vector.get(), nextCanopyId++);
+ output.collect(new Text(key.toString()), canopy);
+ }
+
+ @Override
+ public void configure(JobConf job) {
+ super.configure(job);
+ if (nextCanopyId == -1) {
+ String taskId = job.get("mapred.task.id");
+ String[] parts = taskId.split("_");
+ if (parts.length != 6 || !parts[0].equals("attempt")
+ || (!"m".equals(parts[3]) && !"r".equals(parts[3]))) {
+ throw new IllegalArgumentException("TaskAttemptId string : " + taskId + " is not properly formed");
+ }
+ nextCanopyId = ((1 << 31) / 50000) * (Integer.parseInt(parts[4]));
+ //each mapper has 42,949 ids to give.
+ }
+ }
+}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyDriver.java Mon Mar 1 05:42:35 2010
@@ -45,7 +45,7 @@
private static final Logger log = LoggerFactory.getLogger(MeanShiftCanopyDriver.class);
- private MeanShiftCanopyDriver() { }
+ private MeanShiftCanopyDriver() {}
public static void main(String[] args) {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
@@ -88,7 +88,8 @@
double t1 = Double.parseDouble(cmdLine.getValue(threshold1Opt).toString());
double t2 = Double.parseDouble(cmdLine.getValue(threshold2Opt).toString());
double convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
- runJob(input, output, output + MeanShiftCanopyConfigKeys.CONTROL_PATH_KEY,
+ createCanopyFromVectors(input, output + "/intial-canopies");
+ runJob(output + "/intial-canopies", output, output + MeanShiftCanopyConfigKeys.CONTROL_PATH_KEY,
measureClassName, t1, t2, convergenceDelta);
} catch (OptionException e) {
log.error("Exception parsing command line: ", e);
@@ -150,4 +151,47 @@
log.warn(e.toString(), e);
}
}
+
+ /**
+ * Run the job that converts the input vectors into initial canopies
+ *
+ * @param input
+ * the input pathname String
+ * @param output
+ * the output pathname String
+ */
+ public static void createCanopyFromVectors(String input, String output) {
+
+ Configurable client = new JobClient();
+ JobConf conf = new JobConf(MeanShiftCanopyDriver.class);
+
+ conf.setOutputKeyClass(Text.class);
+ conf.setOutputValueClass(MeanShiftCanopy.class);
+
+ FileInputFormat.setInputPaths(conf, new Path(input));
+ Path outPath = new Path(output);
+ FileOutputFormat.setOutputPath(conf, outPath);
+
+ conf.setMapperClass(MeanShiftCanopyCreatorMapper.class);
+ conf.setNumReduceTasks(0);
+ conf.setInputFormat(SequenceFileInputFormat.class);
+ conf.setOutputFormat(SequenceFileOutputFormat.class);
+
+ client.setConf(conf);
+ try {
+ JobClient.runJob(conf);
+ } catch (IOException e) {
+ log.warn(e.toString(), e);
+ }
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java?rev=917396&r1=917395&r2=917396&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopyJob.java Mon Mar 1 05:42:35 2010
@@ -127,10 +127,13 @@
fs.delete(outPath, true);
}
fs.mkdirs(outPath);
+
+ MeanShiftCanopyDriver.createCanopyFromVectors(input, output+"/initial-canopies");
+
// iterate until the clusters converge
boolean converged = false;
int iteration = 0;
- String clustersIn = input;
+ String clustersIn = output+"/initial-canopies";
while (!converged && (iteration < maxIterations)) {
log.info("Iteration {}", iteration);
// point the output to a new directory per iteration