You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/06/26 22:25:07 UTC
svn commit: r788858 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/kmeans/
core/src/test/java/org/apache/mahout/clustering/kmeans/
examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/
Author: gsingers
Date: Fri Jun 26 20:25:07 2009
New Revision: 788858
URL: http://svn.apache.org/viewvc?rev=788858&view=rev
Log:
MAHOUT-138: KMeans options, Random Seed Generator and delete KMeansJob
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
Removed:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=788858&r1=788857&r2=788858&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Fri Jun 26 20:25:07 2009
@@ -26,9 +26,21 @@
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.utils.CommandLineUtil;
+import org.apache.mahout.utils.SquaredEuclideanDistanceMeasure;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.commandline.Parser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -48,17 +60,108 @@
*
* @param args command-line options; input (-i), clusters (-c) and output (-o) are required, the others are optional and default as described in their help text
*/
- public static void main(String[] args) throws ClassNotFoundException {
- String input = args[0];
- String clusters = args[1];
- String output = args[2];
- String measureClass = args[3];
- double convergenceDelta = Double.parseDouble(args[4]);
- int maxIterations = Integer.parseInt(args[5]);
- String vectorClassName = args[6];
- Class<? extends Vector> vectorClass = (Class<? extends Vector>) Class.forName(vectorClassName);
- runJob(input, clusters, output, measureClass, convergenceDelta,
- maxIterations, 2, vectorClass);
+  public static void main(String[] args) throws ClassNotFoundException, IOException, IllegalAccessException, InstantiationException {
+    // Only input, clusters and output are required; the defaults for the other options are applied below.
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
+        withDescription("The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+    Option clustersOpt = obuilder.withLongName("clusters").withRequired(true).withArgument(
+        abuilder.withName("clusters").withMinimum(1).withMaximum(1).create()).
+        withDescription("The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. " +
+            "If k is also specified, then a random set of vectors will be selected and written out to this path first").withShortName("c").create();
+    Option kOpt = obuilder.withLongName("k").withRequired(false).withArgument(
+        abuilder.withName("k").withMinimum(1).withMaximum(1).create()).
+        withDescription("The k in k-Means. If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.").withShortName("k").create();
+    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
+        withDescription("The Path to put the output in").withShortName("o").create();
+    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).
+        withDescription("If set, overwrite the output directory").withShortName("w").create();
+    Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
+        abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
+        withDescription("The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
+    Option convergenceDeltaOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(
+        abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).
+        withDescription("The threshold below which the clusters are considered to be converged. Default is 0.5").withShortName("d").create();
+    Option maxIterationsOpt = obuilder.withLongName("max").withRequired(false).withArgument(
+        abuilder.withName("max").withMinimum(1).withMaximum(1).create()).
+        withDescription("The maximum number of iterations to perform. Default is 20").withShortName("x").create();
+    Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
+        abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
+        withDescription("The Vector implementation class name. Default is SparseVector.class").withShortName("v").create();
+    Option numReduceTasksOpt = obuilder.withLongName("numReduce").withRequired(false).withArgument(
+        abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).
+        withDescription("The number of reduce tasks").withShortName("r").create();
+    // Every option, including the flag-style help and overwrite options, must be registered in the group or the parser will not recognise it.
+    Option helpOpt = obuilder.withLongName("help").
+        withDescription("Print out help").withShortName("h").create();
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt).withOption(measureClassOpt).withOption(convergenceDeltaOpt)
+        .withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt).withOption(overwriteOutput).withOption(helpOpt).create();
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      String input = cmdLine.getValue(inputOpt).toString();
+      String clusters = cmdLine.getValue(clustersOpt).toString();
+      String output = cmdLine.getValue(outputOpt).toString();
+      String measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+      if (cmdLine.hasOption(measureClassOpt)) {
+        measureClass = cmdLine.getValue(measureClassOpt).toString();
+      }
+      double convergenceDelta = 0.5;
+      if (cmdLine.hasOption(convergenceDeltaOpt)) {
+        convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
+      }
+      Class<? extends Vector> vectorClass = SparseVector.class;
+      if (cmdLine.hasOption(vectorClassOpt)) {
+        vectorClass = Class.forName(cmdLine.getValue(vectorClassOpt).toString()).asSubclass(Vector.class); // fails fast if not a Vector
+      }
+      // With k, randomly sampled input vectors become the initial centroids; clusters then points at the generated seed file.
+      if (cmdLine.hasOption(kOpt)) {
+        clusters = buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
+      }
+
+      int maxIterations = 20;
+      if (cmdLine.hasOption(maxIterationsOpt)) {
+        maxIterations = Integer.parseInt(cmdLine.getValue(maxIterationsOpt).toString());
+      }
+      int numReduceTasks = 2;
+      if (cmdLine.hasOption(numReduceTasksOpt)) {
+        numReduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString());
+      }
+
+      runJob(input, clusters, output, measureClass, convergenceDelta,
+          maxIterations, numReduceTasks, vectorClass, cmdLine.hasOption(overwriteOutput));
+    } catch (OptionException e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+    }
+  }
+
+  /**
+   * Recursively deletes {@code output} if it already exists, then recreates it as an empty directory.
+   */
+  private static void overwriteOutput(String output) throws IOException {
+    Path target = new Path(output);
+    FileSystem fileSystem = FileSystem.get(target.toUri(), new JobConf(KMeansDriver.class));
+    if (fileSystem.exists(target)) {
+      log.warn("Deleting " + target);
+      fileSystem.delete(target, true);
+    }
+    fileSystem.mkdirs(target);
+  }
+
+  /**
+   * Samples input vectors as the initial clusters by delegating to
+   * {@link RandomSeedGenerator#runJob(String, String, int)}.
+   *
+   * @param input    path of the input vectors
+   * @param clusters directory the seed clusters are written to; existing contents are replaced
+   * @param k        number of vectors to sample (fewer may be written)
+   * @return the path of the seed file that was written
+   */
+  public static Path buildRandom(String input, String clusters, int k) throws IOException, IllegalAccessException, InstantiationException {
+    Path seeds = RandomSeedGenerator.runJob(input, clusters, k);
+    return seeds;
}
/**
@@ -72,15 +175,18 @@
* @param maxIterations the maximum number of iterations
* @param numReduceTasks the number of reducers
* @param vectorClass
+ * @param overwrite
*/
public static void runJob(String input, String clustersIn, String output,
String measureClass, double convergenceDelta, int maxIterations,
- int numReduceTasks, Class<? extends Vector> vectorClass) {
+ int numReduceTasks, Class<? extends Vector> vectorClass, boolean overwrite) throws IOException {
// iterate until the clusters converge
boolean converged = false;
int iteration = 0;
String delta = Double.toString(convergenceDelta);
-
+ if (overwrite == true){
+ overwriteOutput(output);
+ }
while (!converged && iteration < maxIterations) {
log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
@@ -111,7 +217,6 @@
private static boolean runIteration(String input, String clustersIn,
String clustersOut, String measureClass, String convergenceDelta,
int numReduceTasks, int iteration) {
- JobClient client = new JobClient();
JobConf conf = new JobConf(KMeansDriver.class);
conf.setMapOutputKeyClass(Text.class);
conf.setMapOutputValueClass(KMeansInfo.class);
@@ -132,7 +237,7 @@
conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
conf.setInt(Cluster.ITERATION_NUMBER, iteration);
- client.setConf(conf);
+
try {
JobClient.runJob(conf);
FileSystem fs = FileSystem.get(outPath.toUri(), conf);
@@ -154,7 +259,6 @@
*/
private static void runClustering(String input, String clustersIn,
String output, String measureClass, String convergenceDelta, Class<? extends Vector> vectorClass) {
- JobClient client = new JobClient();
JobConf conf = new JobConf(KMeansDriver.class);
conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
@@ -175,7 +279,6 @@
conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
- client.setConf(conf);
try {
JobClient.runJob(conf);
} catch (IOException e) {
@@ -196,7 +299,16 @@
throws IOException {
Path outPart = new Path(filePath);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, outPart, conf);
- Text key = new Text();
+ Writable key = null;
+ try {
+ key = (Writable) reader.getKeyClass().newInstance();
+ } catch (InstantiationException e) {//shouldn't happen
+ log.error("Exception", e);
+ throw new RuntimeException(e);
+ } catch (IllegalAccessException e) {
+ log.error("Exception", e);
+ throw new RuntimeException(e);
+ }
Cluster value = new Cluster();
boolean converged = true;
while (converged && reader.next(key, value)) {
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=788858&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Fri Jun 26 20:25:07 2009
@@ -0,0 +1,125 @@
+package org.apache.mahout.clustering.kmeans;
+
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.matrix.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Set;
+
+
+/**
+ * Given an input path containing a {@link SequenceFile} of (Writable, Vector) pairs, selects up to
+ * k of the vectors uniformly at random and writes each one to the output directory as a
+ * {@link Cluster} representing an initial centroid.
+ * <p/>
+ * The input is scanned twice (once to count the records, once to copy the chosen ones), so
+ * the selected vectors never have to be buffered in memory.
+ */
+public class RandomSeedGenerator {
+  private static final Logger log = LoggerFactory.getLogger(RandomSeedGenerator.class);
+  public static final String K = "k";
+
+  /**
+   * Writes up to {@code k} randomly chosen input vectors to {@code output/part-randomSeed}.
+   * Every subset of {@code min(k, n)} records is equally likely, where n is the number of
+   * records in the input.
+   *
+   * @param input  path of a SequenceFile whose values are {@link Vector}s
+   * @param output directory to write the seeds to; it is deleted first if it already exists
+   * @param k      the number of seeds to select; must be positive
+   * @return the path of the written seed file
+   * @throws IllegalArgumentException if {@code k < 1}
+   * @throws IOException              if reading the input or writing the output fails
+   * @throws IllegalAccessException   if the input's key or value class has no accessible no-arg constructor
+   * @throws InstantiationException   if the input's key or value class cannot be instantiated
+   */
+  public static Path runJob(String input, String output,
+                            int k) throws IOException, IllegalAccessException, InstantiationException {
+    if (k < 1) {
+      throw new IllegalArgumentException("k must be positive, was " + k);
+    }
+    JobConf conf = new JobConf(RandomSeedGenerator.class);
+    Path inputPath = new Path(input);
+    FileSystem inFs = FileSystem.get(inputPath.toUri(), conf);
+    Path outPath = new Path(output);
+    FileSystem outFs = FileSystem.get(outPath.toUri(), conf);
+    // Start from an empty output directory so stale seed files are never picked up.
+    if (outFs.exists(outPath)) {
+      log.warn("Deleting {}", outPath);
+      outFs.delete(outPath, true);
+    }
+    outFs.mkdirs(outPath);
+    Path outFile = new Path(outPath, "part-randomSeed");
+
+    int total = countRecords(inFs, inputPath, conf);
+    if (total < k) {
+      log.warn("Input {} holds only {} vectors; fewer than k={} seeds will be written",
+          new Object[] {inputPath, total, k});
+    }
+    Set<Integer> chosen = chooseIndices(total, k, new Random());
+
+    SequenceFile.Reader reader = new SequenceFile.Reader(inFs, inputPath, conf);
+    try {
+      SequenceFile.Writer writer = SequenceFile.createWriter(outFs, conf, outFile, Text.class, Cluster.class);
+      try {
+        Writable key = (Writable) reader.getKeyClass().newInstance();
+        Vector value = (Vector) reader.getValueClass().newInstance();
+        int index = 0;
+        int written = 0;
+        // Stop reading as soon as every chosen record has been copied.
+        while (written < chosen.size() && reader.next(key, value)) {
+          if (chosen.contains(index)) {
+            writer.append(new Text(key.toString()), new Cluster(value));
+            written++;
+          }
+          index++;
+        }
+        log.info("Wrote {} vectors", written);
+      } finally {
+        writer.close();
+      }
+    } finally {
+      reader.close();
+    }
+    return outFile;
+  }
+
+  /** Counts the records in the SequenceFile at {@code path}, skipping over the values. */
+  private static int countRecords(FileSystem fs, Path path, JobConf conf)
+      throws IOException, IllegalAccessException, InstantiationException {
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+    try {
+      Writable key = (Writable) reader.getKeyClass().newInstance();
+      int count = 0;
+      while (reader.next(key)) {
+        count++;
+      }
+      return count;
+    } finally {
+      reader.close();
+    }
+  }
+
+  /**
+   * Picks {@code min(k, n)} distinct indices from {@code [0, n)} uniformly at random using
+   * Floyd's sampling algorithm, which needs O(k) memory regardless of n.
+   */
+  private static Set<Integer> chooseIndices(int n, int k, Random random) {
+    Set<Integer> chosen = new HashSet<Integer>();
+    for (int j = Math.max(0, n - k); j < n; j++) {
+      int candidate = random.nextInt(j + 1);
+      // If candidate was already taken, j itself cannot have been chosen yet.
+      chosen.add(chosen.contains(candidate) ? j : candidate);
+    }
+    return chosen;
+  }
+}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=788858&r1=788857&r2=788858&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Fri Jun 26 20:25:07 2009
@@ -402,8 +402,8 @@
}
writer.close();
// now run the Job
- KMeansJob.runJob("testdata/points", "testdata/clusters", "output",
- EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class);
+ KMeansDriver.runJob("testdata/points", "testdata/clusters", "output",
+ EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class, true);
// now compare the expected clusters with actual
File outDir = new File("output/points");
assertTrue("output dir exists?", outDir.exists());
@@ -460,8 +460,8 @@
ManhattanDistanceMeasure.class.getName(), 3.1, 2.1, SparseVector.class);
// now run the KMeans job
- KMeansJob.runJob("testdata/points", "testdata/canopies", "output",
- EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class);
+ KMeansDriver.runJob("testdata/points", "testdata/canopies", "output",
+ EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class, true);
// now compare the expected clusters with actual
File outDir = new File("output/points");
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java?rev=788858&r1=788857&r2=788858&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java Fri Jun 26 20:25:07 2009
@@ -99,6 +99,6 @@
System.out.println("Running KMeans");
KMeansDriver.runJob(directoryContainingConvertedInput, output
+ CanopyClusteringJob.DEFAULT_CANOPIES_OUTPUT_DIRECTORY, output,
- measureClass, convergenceDelta, maxIterations, 1, vectorClass);
+ measureClass, convergenceDelta, maxIterations, 1, vectorClass, true);
}
}