You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/06/27 03:58:22 UTC
svn commit: r788915 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/kmeans/
core/src/test/java/org/apache/mahout/clustering/kmeans/
examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/
Author: gsingers
Date: Sat Jun 27 01:58:21 2009
New Revision: 788915
URL: http://svn.apache.org/viewvc?rev=788915&view=rev
Log:
cleanup options a bit
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Sat Jun 27 01:58:21 2009
@@ -16,48 +16,47 @@
*/
package org.apache.mahout.clustering.kmeans;
-import java.io.IOException;
-
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Writable;
-import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
-import org.apache.mahout.matrix.Vector;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.matrix.Vector;
import org.apache.mahout.utils.CommandLineUtil;
import org.apache.mahout.utils.SquaredEuclideanDistanceMeasure;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.commandline.Parser;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.io.IOException;
+
public class KMeansDriver {
/**
- * The name of the directory used to output final results.
+ * The name of the directory used to output final results.
*/
public static final String DEFAULT_OUTPUT_DIRECTORY = "/points";
-
+
private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
private KMeansDriver() {
}
/**
- *
* @param args Expects 7 args and they all correspond to the order of the params in {@link #runJob}
*/
public static void main(String[] args) throws ClassNotFoundException, IOException, IllegalAccessException, InstantiationException {
@@ -69,36 +68,46 @@
Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
withDescription("The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
Option clustersOpt = obuilder.withLongName("clusters").withRequired(true).withArgument(
abuilder.withName("clusters").withMinimum(1).withMaximum(1).create()).
withDescription("The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. " +
"If k is also specified, then a random set of vectors will be selected and written out to this path first").withShortName("c").create();
+
Option kOpt = obuilder.withLongName("k").withRequired(false).withArgument(
abuilder.withName("k").withMinimum(1).withMaximum(1).create()).
withDescription("The k in k-Means. If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.").withShortName("k").create();
+
Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
withDescription("The Path to put the output in").withShortName("o").create();
- Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withArgument(
- abuilder.withName("overwrite").withMinimum(1).withMaximum(1).create()).
+
+ Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).
withDescription("If set, overwrite the output directory").withShortName("w").create();
+
Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
withDescription("The Distance Measure to use. Default is SquaredEuclidean").withShortName("m").create();
+
Option convergenceDeltaOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(
abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).
withDescription("The threshold below which the clusters are considered to be converged. Default is 0.5").withShortName("d").create();
+
Option maxIterationsOpt = obuilder.withLongName("max").withRequired(false).withArgument(
abuilder.withName("max").withMinimum(1).withMaximum(1).create()).
withDescription("The maximum number of iterations to perform. Default is 20").withShortName("x").create();
+
Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
withDescription("The Vector implementation class name. Default is SparseVector.class").withShortName("v").create();
+
Option numReduceTasksOpt = obuilder.withLongName("numReduce").withRequired(false).withArgument(
abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).
withDescription("The number of reduce tasks").withShortName("r").create();
+
Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt).withOption(measureClassOpt)
- .withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt).create();
+ .withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt)
+ .withOption(vectorClassOpt).withOption(overwriteOutput).create();
Option helpOpt = obuilder.withLongName("help").
withDescription("Print out help").withShortName("h").create();
try {
@@ -124,11 +133,8 @@
Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
SparseVector.class
- : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
+ : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
- if (cmdLine.hasOption(kOpt)) {
- clusters = buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
- }
int maxIterations = 20;
if (cmdLine.hasOption(maxIterationsOpt)) {
@@ -138,21 +144,26 @@
if (cmdLine.hasOption(numReduceTasksOpt)) {
numReduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString());
}
-
+ if (cmdLine.hasOption(overwriteOutput) == true) {
+ overwriteOutput(output);
+ }
+ if (cmdLine.hasOption(kOpt)) {
+ clusters = buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
+ }
runJob(input, clusters, output, measureClass, convergenceDelta,
- maxIterations, numReduceTasks, vectorClass, cmdLine.hasOption(overwriteOutput));
+ maxIterations, numReduceTasks, vectorClass);
} catch (OptionException e) {
log.error("Exception", e);
CommandLineUtil.printHelp(group);
}
}
- private static void overwriteOutput(String output) throws IOException {
+ public static void overwriteOutput(String output) throws IOException {
JobConf conf = new JobConf(KMeansDriver.class);
Path outPath = new Path(output);
FileSystem fs = FileSystem.get(outPath.toUri(), conf);
boolean exists = fs.exists(outPath);
- if (exists == true){
+ if (exists == true) {
log.warn("Deleting " + outPath);
fs.delete(outPath, true);
}
@@ -167,32 +178,32 @@
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for initial & computed clusters
- * @param output the directory pathname for output points
- * @param measureClass the classname of the DistanceMeasure
+ * @param input the directory pathname for input points
+ * @param clustersIn the directory pathname for initial & computed clusters
+ * @param output the directory pathname for output points
+ * @param measureClass the classname of the DistanceMeasure
* @param convergenceDelta the convergence delta value
- * @param maxIterations the maximum number of iterations
- * @param numReduceTasks the number of reducers
+ * @param maxIterations the maximum number of iterations
+ * @param numReduceTasks the number of reducers
* @param vectorClass
- * @param overwrite
*/
public static void runJob(String input, String clustersIn, String output,
String measureClass, double convergenceDelta, int maxIterations,
- int numReduceTasks, Class<? extends Vector> vectorClass, boolean overwrite) throws IOException {
+ int numReduceTasks, Class<? extends Vector> vectorClass) throws IOException {
// iterate until the clusters converge
boolean converged = false;
int iteration = 0;
String delta = Double.toString(convergenceDelta);
- if (overwrite == true){
- overwriteOutput(output);
+ if (log.isInfoEnabled()) {
+ log.info("Input: " + input + " Clusters In: " + clustersIn + " Out: " + output + " Distance: " + measureClass);
+ log.info("convergence: " + convergenceDelta + " max Iterations: " + maxIterations + " num Reduce Tasks: " + numReduceTasks + " Input Vectors: " + vectorClass.getName());
}
while (!converged && iteration < maxIterations) {
log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
String clustersOut = output + "/clusters-" + iteration;
converged = runIteration(input, clustersIn, clustersOut, measureClass,
- delta, numReduceTasks, iteration);
+ delta, numReduceTasks, iteration);
// now point the input to the old output directory
clustersIn = output + "/clusters-" + iteration;
iteration++;
@@ -205,13 +216,13 @@
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for input clusters
- * @param clustersOut the directory pathname for output clusters
- * @param measureClass the classname of the DistanceMeasure
+ * @param input the directory pathname for input points
+ * @param clustersIn the directory pathname for input clusters
+ * @param clustersOut the directory pathname for output clusters
+ * @param measureClass the classname of the DistanceMeasure
* @param convergenceDelta the convergence delta value
- * @param numReduceTasks the number of reducer tasks
- * @param iteration The iteration number
+ * @param numReduceTasks the number of reducer tasks
+ * @param iteration The iteration number
* @return true if the iteration successfully runs
*/
private static boolean runIteration(String input, String clustersIn,
@@ -236,7 +247,7 @@
conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
conf.setInt(Cluster.ITERATION_NUMBER, iteration);
-
+
try {
JobClient.runJob(conf);
@@ -251,14 +262,14 @@
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param clustersIn the directory pathname for input clusters
- * @param output the directory pathname for output points
- * @param measureClass the classname of the DistanceMeasure
+ * @param input the directory pathname for input points
+ * @param clustersIn the directory pathname for input clusters
+ * @param output the directory pathname for output points
+ * @param measureClass the classname of the DistanceMeasure
* @param convergenceDelta the convergence delta value
*/
private static void runClustering(String input, String clustersIn,
- String output, String measureClass, String convergenceDelta, Class<? extends Vector> vectorClass) {
+ String output, String measureClass, String convergenceDelta, Class<? extends Vector> vectorClass) {
JobConf conf = new JobConf(KMeansDriver.class);
conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
@@ -295,7 +306,7 @@
* @return true if all Clusters are converged
* @throws IOException if there was an IO error
*/
- private static boolean isConverged(String filePath, JobConf conf, FileSystem fs)
+ private static boolean isConverged(String filePath, JobConf conf, FileSystem fs)
throws IOException {
Path outPart = new Path(filePath);
SequenceFile.Reader reader = new SequenceFile.Reader(fs, outPart, conf);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Sat Jun 27 01:58:21 2009
@@ -58,7 +58,7 @@
count++;
}
}
- log.info("Wrote " + count + " vectors");
+ log.info("Wrote " + count + " vectors to " + outFile);
reader.close();
writer.close();
}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Sat Jun 27 01:58:21 2009
@@ -402,8 +402,9 @@
}
writer.close();
// now run the Job
+ KMeansDriver.overwriteOutput("output");
KMeansDriver.runJob("testdata/points", "testdata/clusters", "output",
- EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class, true);
+ EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class);
// now compare the expected clusters with actual
File outDir = new File("output/points");
assertTrue("output dir exists?", outDir.exists());
@@ -461,7 +462,7 @@
// now run the KMeans job
KMeansDriver.runJob("testdata/points", "testdata/canopies", "output",
- EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class, true);
+ EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class);
// now compare the expected clusters with actual
File outDir = new File("output/points");
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java Sat Jun 27 01:58:21 2009
@@ -99,6 +99,6 @@
System.out.println("Running KMeans");
KMeansDriver.runJob(directoryContainingConvertedInput, output
+ CanopyClusteringJob.DEFAULT_CANOPIES_OUTPUT_DIRECTORY, output,
- measureClass, convergenceDelta, maxIterations, 1, vectorClass, false);
+ measureClass, convergenceDelta, maxIterations, 1, vectorClass);
}
}