You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/06/26 22:25:07 UTC

svn commit: r788858 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/clustering/kmeans/ core/src/test/java/org/apache/mahout/clustering/kmeans/ examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/

Author: gsingers
Date: Fri Jun 26 20:25:07 2009
New Revision: 788858

URL: http://svn.apache.org/viewvc?rev=788858&view=rev
Log:
MAHOUT-138: KMeans options, Random Seed Generator and delete KMeansJob

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
Removed:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansJob.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=788858&r1=788857&r2=788858&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Fri Jun 26 20:25:07 2009
@@ -26,9 +26,21 @@
 import org.apache.hadoop.mapred.JobClient;
 import org.apache.hadoop.mapred.JobConf;
 import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.hadoop.mapred.SequenceFileInputFormat;
 import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.utils.CommandLineUtil;
+import org.apache.mahout.utils.SquaredEuclideanDistanceMeasure;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.commandline.Parser;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -48,17 +60,108 @@
    * 
    * @param args Expects 7 args and they all correspond to the order of the params in {@link #runJob}
    */
-  public static void main(String[] args) throws ClassNotFoundException {
-    String input = args[0];
-    String clusters = args[1];
-    String output = args[2];
-    String measureClass = args[3];
-    double convergenceDelta = Double.parseDouble(args[4]);
-    int maxIterations = Integer.parseInt(args[5]);
-    String vectorClassName = args[6];
-    Class<? extends Vector> vectorClass = (Class<? extends Vector>) Class.forName(vectorClassName);
-    runJob(input, clusters, output, measureClass, convergenceDelta,
-        maxIterations, 2, vectorClass);
+  public static void main(String[] args) throws ClassNotFoundException, IOException, IllegalAccessException, InstantiationException {
+    // Entry point: parse the options below, default the optional ones, optionally seed the
+    // clusters path with k randomly chosen input vectors, then hand everything to runJob.
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+            abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
+            withDescription("The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+    Option clustersOpt = obuilder.withLongName("clusters").withRequired(true).withArgument(
+            abuilder.withName("clusters").withMinimum(1).withMaximum(1).create()).
+            withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  " +
+                    "If k is also specified, then a random set of vectors will be selected and written out to this path first").withShortName("c").create();
+    Option kOpt = obuilder.withLongName("k").withRequired(false).withArgument(
+            abuilder.withName("k").withMinimum(1).withMaximum(1).create()).
+            withDescription("The k in k-Means.  If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.").withShortName("k").create();
+    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+            abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
+            withDescription("The Path to put the output in").withShortName("o").create();
+    // A bare flag: its presence alone requests the overwrite, so it takes no argument.
+    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).
+            withDescription("If set, overwrite the output directory").withShortName("w").create();
+    Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
+            abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
+            withDescription("The Distance Measure to use.  Default is SquaredEuclidean").withShortName("m").create();
+    Option convergenceDeltaOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(
+            abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).
+            withDescription("The threshold below which the clusters are considered to be converged.  Default is 0.5").withShortName("d").create();
+    Option maxIterationsOpt = obuilder.withLongName("max").withRequired(false).withArgument(
+            abuilder.withName("max").withMinimum(1).withMaximum(1).create()).
+            withDescription("The maximum number of iterations to perform.  Default is 20").withShortName("x").create();
+    Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
+            abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
+            withDescription("The Vector implementation class name.  Default is SparseVector.class").withShortName("v").create();
+    Option numReduceTasksOpt = obuilder.withLongName("numReduce").withRequired(false).withArgument(
+            abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).
+            withDescription("The number of reduce tasks").withShortName("r").create();
+    Option helpOpt = obuilder.withLongName("help").
+            withDescription("Print out help").withShortName("h").create();
+    // Every option has to be registered here, otherwise the parser rejects it as unknown.
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt).withOption(measureClassOpt)
+            .withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt)
+            .withOption(overwriteOutput).withOption(helpOpt).create();
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      String input = cmdLine.getValue(inputOpt).toString();
+      String clusters = cmdLine.getValue(clustersOpt).toString();
+      String output = cmdLine.getValue(outputOpt).toString();
+      String measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+      if (cmdLine.hasOption(measureClassOpt)) {
+        measureClass = cmdLine.getValue(measureClassOpt).toString();
+      }
+      double convergenceDelta = 0.5;
+      if (cmdLine.hasOption(convergenceDeltaOpt)) {
+        convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
+      }
+      int maxIterations = 20;
+      if (cmdLine.hasOption(maxIterationsOpt)) {
+        maxIterations = Integer.parseInt(cmdLine.getValue(maxIterationsOpt).toString());
+      }
+      int numReduceTasks = 2;
+      if (cmdLine.hasOption(numReduceTasksOpt)) {
+        numReduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString());
+      }
+      Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt)
+              ? Class.forName(cmdLine.getValue(vectorClassOpt).toString()).asSubclass(Vector.class)
+              : SparseVector.class;
+      // Seed last, so a malformed option above fails before the clusters path is rewritten.
+      if (cmdLine.hasOption(kOpt)) {
+        clusters = buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
+      }
+
+      runJob(input, clusters, output, measureClass, convergenceDelta,
+              maxIterations, numReduceTasks, vectorClass, cmdLine.hasOption(overwriteOutput));
+    } catch (OptionException e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+    }
+  }
+
+  private static void overwriteOutput(String output) throws IOException {
+    JobConf conf = new JobConf(KMeansDriver.class);
+    Path outPath = new Path(output);
+    FileSystem fs = FileSystem.get(outPath.toUri(), conf);
+    boolean exists = fs.exists(outPath);
+    if (exists == true){
+      log.warn("Deleting " + outPath);
+      fs.delete(outPath, true);
+    }
+    fs.mkdirs(outPath);
+
+  }
+
+  public static Path buildRandom(String input, String clusters, int k) throws IOException, IllegalAccessException, InstantiationException {
+    return RandomSeedGenerator.runJob(input, clusters, k);
   }
 
   /**
@@ -72,15 +175,18 @@
    * @param maxIterations the maximum number of iterations
    * @param numReduceTasks the number of reducers
    * @param vectorClass
+   * @param overwrite
    */
   public static void runJob(String input, String clustersIn, String output,
                             String measureClass, double convergenceDelta, int maxIterations,
-                            int numReduceTasks, Class<? extends Vector> vectorClass) {
+                            int numReduceTasks, Class<? extends Vector> vectorClass, boolean overwrite) throws IOException {
     // iterate until the clusters converge
     boolean converged = false;
     int iteration = 0;
     String delta = Double.toString(convergenceDelta);
-
+    if (overwrite == true){
+      overwriteOutput(output);
+    }
     while (!converged && iteration < maxIterations) {
       log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration
@@ -111,7 +217,6 @@
   private static boolean runIteration(String input, String clustersIn,
                                       String clustersOut, String measureClass, String convergenceDelta,
                                       int numReduceTasks, int iteration) {
-    JobClient client = new JobClient();
     JobConf conf = new JobConf(KMeansDriver.class);
     conf.setMapOutputKeyClass(Text.class);
     conf.setMapOutputValueClass(KMeansInfo.class);
@@ -132,7 +237,7 @@
     conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
     conf.setInt(Cluster.ITERATION_NUMBER, iteration);
     
-    client.setConf(conf);
+
     try {
       JobClient.runJob(conf);
       FileSystem fs = FileSystem.get(outPath.toUri(), conf);
@@ -154,7 +259,6 @@
    */
   private static void runClustering(String input, String clustersIn,
       String output, String measureClass, String convergenceDelta, Class<? extends Vector> vectorClass) {
-    JobClient client = new JobClient();
     JobConf conf = new JobConf(KMeansDriver.class);
     conf.setInputFormat(SequenceFileInputFormat.class);
     conf.setOutputFormat(SequenceFileOutputFormat.class);
@@ -175,7 +279,6 @@
     conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
     conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
 
-    client.setConf(conf);
     try {
       JobClient.runJob(conf);
     } catch (IOException e) {
@@ -196,7 +299,16 @@
           throws IOException {
     Path outPart = new Path(filePath);
     SequenceFile.Reader reader = new SequenceFile.Reader(fs, outPart, conf);
-    Text key = new Text();
+    Writable key = null;
+    try {
+      key = (Writable) reader.getKeyClass().newInstance();
+    } catch (InstantiationException e) {//shouldn't happen
+      log.error("Exception", e);
+      throw new RuntimeException(e);
+    } catch (IllegalAccessException e) {
+      log.error("Exception", e);
+      throw new RuntimeException(e);
+    }
     Cluster value = new Cluster();
     boolean converged = true;
     while (converged && reader.next(key, value)) {

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=788858&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Fri Jun 26 20:25:07 2009
@@ -0,0 +1,65 @@
+package org.apache.mahout.clustering.kmeans;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.matrix.Vector;
+
+import java.io.IOException;
+import java.util.Random;
+
+
+/**
+ * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, randomly select up to k vectors
+ * and write them to the output file as a {@link org.apache.mahout.clustering.kmeans.Cluster} representing
+ * the initial centroid to use. Records are read in order and each one is kept on a coin flip until k are
+ * kept or the input ends, so fewer than k seeds may be written.
+ */
+public class RandomSeedGenerator {
+  private static final Log log = LogFactory.getLog(RandomSeedGenerator.class);
+  public static final String K = "k";
+  /**
+   * Deletes and recreates the output directory, then writes up to k seeds sampled from input into it.
+   * @return the path of the seed file, "part-randomSeed" inside output
+   * @throws IOException if the seed file cannot be created or reading/writing fails
+   */
+  public static Path runJob(String input, String output, int k) throws IOException, IllegalAccessException, InstantiationException {
+    JobConf conf = new JobConf(RandomSeedGenerator.class);
+    Path outPath = new Path(output);
+    FileSystem fs = FileSystem.get(outPath.toUri(), conf);
+    if (fs.exists(outPath)) {
+      fs.delete(outPath, true);
+    }
+    fs.mkdirs(outPath);
+    Path outFile = new Path(outPath, "part-randomSeed");
+    if (!fs.createNewFile(outFile)) {
+      throw new IOException("Unable to create " + outFile);
+    }
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, new Path(input), conf);
+    try {
+      Writable key = (Writable) reader.getKeyClass().newInstance();
+      Vector value = (Vector) reader.getValueClass().newInstance();
+      SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outFile, Text.class, Cluster.class);
+      try {
+        Random random = new Random();
+        int count = 0;
+        while (count < k && reader.next(key, value)) {
+          if (random.nextBoolean()) {
+            writer.append(new Text(key.toString()), new Cluster(value));
+            count++;
+          }
+        }
+        log.info("Wrote " + count + " vectors");
+      } finally {
+        writer.close();
+      }
+    } finally {
+      reader.close();
+    }
+    return outFile;
+  }
+}

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=788858&r1=788857&r2=788858&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Fri Jun 26 20:25:07 2009
@@ -402,8 +402,8 @@
       }
       writer.close();
       // now run the Job
-      KMeansJob.runJob("testdata/points", "testdata/clusters", "output",
-          EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class);
+      KMeansDriver.runJob("testdata/points", "testdata/clusters", "output",
+          EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class, true);
       // now compare the expected clusters with actual
       File outDir = new File("output/points");
       assertTrue("output dir exists?", outDir.exists());
@@ -460,8 +460,8 @@
         ManhattanDistanceMeasure.class.getName(), 3.1, 2.1, SparseVector.class);
 
     // now run the KMeans job
-    KMeansJob.runJob("testdata/points", "testdata/canopies", "output",
-        EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class);
+    KMeansDriver.runJob("testdata/points", "testdata/canopies", "output",
+        EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class, true);
 
     // now compare the expected clusters with actual
     File outDir = new File("output/points");

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java?rev=788858&r1=788857&r2=788858&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java Fri Jun 26 20:25:07 2009
@@ -99,6 +99,6 @@
     System.out.println("Running KMeans");
     KMeansDriver.runJob(directoryContainingConvertedInput, output
         + CanopyClusteringJob.DEFAULT_CANOPIES_OUTPUT_DIRECTORY, output,
-        measureClass, convergenceDelta, maxIterations, 1, vectorClass);
+        measureClass, convergenceDelta, maxIterations, 1, vectorClass, true);
   }
 }