You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/06/27 03:58:22 UTC

svn commit: r788915 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/clustering/kmeans/ core/src/test/java/org/apache/mahout/clustering/kmeans/ examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/

Author: gsingers
Date: Sat Jun 27 01:58:21 2009
New Revision: 788915

URL: http://svn.apache.org/viewvc?rev=788915&view=rev
Log:
cleanup options a bit

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Sat Jun 27 01:58:21 2009
@@ -16,48 +16,47 @@
  */
 package org.apache.mahout.clustering.kmeans;
 
-import java.io.IOException;
-
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapred.FileInputFormat;
 import org.apache.hadoop.mapred.FileOutputFormat;
 import org.apache.hadoop.mapred.JobClient;
 import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Writable;
-import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.hadoop.mapred.SequenceFileInputFormat;
-import org.apache.mahout.matrix.Vector;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.matrix.Vector;
 import org.apache.mahout.utils.CommandLineUtil;
 import org.apache.mahout.utils.SquaredEuclideanDistanceMeasure;
-import org.apache.commons.cli2.builder.DefaultOptionBuilder;
-import org.apache.commons.cli2.builder.ArgumentBuilder;
-import org.apache.commons.cli2.builder.GroupBuilder;
-import org.apache.commons.cli2.Option;
-import org.apache.commons.cli2.Group;
-import org.apache.commons.cli2.CommandLine;
-import org.apache.commons.cli2.OptionException;
-import org.apache.commons.cli2.commandline.Parser;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
+
 public class KMeansDriver {
 
   /**
-   * The name of the directory used to output final results. 
+   * The name of the directory used to output final results.
    */
   public static final String DEFAULT_OUTPUT_DIRECTORY = "/points";
-  
+
   private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
 
   private KMeansDriver() {
   }
 
   /**
-   * 
    * @param args Expects 7 args and they all correspond to the order of the params in {@link #runJob}
    */
   public static void main(String[] args) throws ClassNotFoundException, IOException, IllegalAccessException, InstantiationException {
@@ -69,36 +68,46 @@
     Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
             abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
             withDescription("The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
     Option clustersOpt = obuilder.withLongName("clusters").withRequired(true).withArgument(
             abuilder.withName("clusters").withMinimum(1).withMaximum(1).create()).
             withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  " +
                     "If k is also specified, then a random set of vectors will be selected and written out to this path first").withShortName("c").create();
+
     Option kOpt = obuilder.withLongName("k").withRequired(false).withArgument(
             abuilder.withName("k").withMinimum(1).withMaximum(1).create()).
             withDescription("The k in k-Means.  If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.").withShortName("k").create();
+
     Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
             abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
             withDescription("The Path to put the output in").withShortName("o").create();
-    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withArgument(
-            abuilder.withName("overwrite").withMinimum(1).withMaximum(1).create()).
+
+    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).
             withDescription("If set, overwrite the output directory").withShortName("w").create();
+
     Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
             abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
             withDescription("The Distance Measure to use.  Default is SquaredEuclidean").withShortName("m").create();
+
     Option convergenceDeltaOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(
             abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).
             withDescription("The threshold below which the clusters are considered to be converged.  Default is 0.5").withShortName("d").create();
+
     Option maxIterationsOpt = obuilder.withLongName("max").withRequired(false).withArgument(
             abuilder.withName("max").withMinimum(1).withMaximum(1).create()).
             withDescription("The maximum number of iterations to perform.  Default is 20").withShortName("x").create();
+
     Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
             abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
             withDescription("The Vector implementation class name.  Default is SparseVector.class").withShortName("v").create();
+
     Option numReduceTasksOpt = obuilder.withLongName("numReduce").withRequired(false).withArgument(
             abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).
             withDescription("The number of reduce tasks").withShortName("r").create();
+
     Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt).withOption(measureClassOpt)
-            .withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt).create();
+            .withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt)
+            .withOption(vectorClassOpt).withOption(overwriteOutput).create();
     Option helpOpt = obuilder.withLongName("help").
             withDescription("Print out help").withShortName("h").create();
     try {
@@ -124,11 +133,8 @@
 
       Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
               SparseVector.class
-                : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
+              : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
 
-      if (cmdLine.hasOption(kOpt)) {
-        clusters = buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
-      }
 
       int maxIterations = 20;
       if (cmdLine.hasOption(maxIterationsOpt)) {
@@ -138,21 +144,26 @@
       if (cmdLine.hasOption(numReduceTasksOpt)) {
         numReduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString());
       }
-
+      if (cmdLine.hasOption(overwriteOutput) == true) {
+        overwriteOutput(output);
+      }
+      if (cmdLine.hasOption(kOpt)) {
+        clusters = buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
+      }
       runJob(input, clusters, output, measureClass, convergenceDelta,
-              maxIterations, numReduceTasks, vectorClass, cmdLine.hasOption(overwriteOutput));
+              maxIterations, numReduceTasks, vectorClass);
     } catch (OptionException e) {
       log.error("Exception", e);
       CommandLineUtil.printHelp(group);
     }
   }
 
-  private static void overwriteOutput(String output) throws IOException {
+  public static void overwriteOutput(String output) throws IOException {
     JobConf conf = new JobConf(KMeansDriver.class);
     Path outPath = new Path(output);
     FileSystem fs = FileSystem.get(outPath.toUri(), conf);
     boolean exists = fs.exists(outPath);
-    if (exists == true){
+    if (exists == true) {
       log.warn("Deleting " + outPath);
       fs.delete(outPath, true);
     }
@@ -167,32 +178,32 @@
   /**
    * Run the job using supplied arguments
    *
-   * @param input the directory pathname for input points
-   * @param clustersIn the directory pathname for initial & computed clusters
-   * @param output the directory pathname for output points
-   * @param measureClass the classname of the DistanceMeasure
+   * @param input            the directory pathname for input points
+   * @param clustersIn       the directory pathname for initial & computed clusters
+   * @param output           the directory pathname for output points
+   * @param measureClass     the classname of the DistanceMeasure
    * @param convergenceDelta the convergence delta value
-   * @param maxIterations the maximum number of iterations
-   * @param numReduceTasks the number of reducers
+   * @param maxIterations    the maximum number of iterations
+   * @param numReduceTasks   the number of reducers
    * @param vectorClass
-   * @param overwrite
    */
   public static void runJob(String input, String clustersIn, String output,
                             String measureClass, double convergenceDelta, int maxIterations,
-                            int numReduceTasks, Class<? extends Vector> vectorClass, boolean overwrite) throws IOException {
+                            int numReduceTasks, Class<? extends Vector> vectorClass) throws IOException {
     // iterate until the clusters converge
     boolean converged = false;
     int iteration = 0;
     String delta = Double.toString(convergenceDelta);
-    if (overwrite == true){
-      overwriteOutput(output);
+    if (log.isInfoEnabled()) {
+      log.info("Input: " + input + " Clusters In: " + clustersIn + " Out: " + output + " Distance: " + measureClass);
+      log.info("convergence: " + convergenceDelta + " max Iterations: " + maxIterations + " num Reduce Tasks: " + numReduceTasks + " Input Vectors: " + vectorClass.getName());
     }
     while (!converged && iteration < maxIterations) {
       log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration
       String clustersOut = output + "/clusters-" + iteration;
       converged = runIteration(input, clustersIn, clustersOut, measureClass,
-          delta, numReduceTasks, iteration);
+              delta, numReduceTasks, iteration);
       // now point the input to the old output directory
       clustersIn = output + "/clusters-" + iteration;
       iteration++;
@@ -205,13 +216,13 @@
   /**
    * Run the job using supplied arguments
    *
-   * @param input the directory pathname for input points
-   * @param clustersIn the directory pathname for input clusters
-   * @param clustersOut the directory pathname for output clusters
-   * @param measureClass the classname of the DistanceMeasure
+   * @param input            the directory pathname for input points
+   * @param clustersIn       the directory pathname for input clusters
+   * @param clustersOut      the directory pathname for output clusters
+   * @param measureClass     the classname of the DistanceMeasure
    * @param convergenceDelta the convergence delta value
-   * @param numReduceTasks the number of reducer tasks
-   * @param iteration The iteration number
+   * @param numReduceTasks   the number of reducer tasks
+   * @param iteration        The iteration number
    * @return true if the iteration successfully runs
    */
   private static boolean runIteration(String input, String clustersIn,
@@ -236,7 +247,7 @@
     conf.set(Cluster.DISTANCE_MEASURE_KEY, measureClass);
     conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
     conf.setInt(Cluster.ITERATION_NUMBER, iteration);
-    
+
 
     try {
       JobClient.runJob(conf);
@@ -251,14 +262,14 @@
   /**
    * Run the job using supplied arguments
    *
-   * @param input the directory pathname for input points
-   * @param clustersIn the directory pathname for input clusters
-   * @param output the directory pathname for output points
-   * @param measureClass the classname of the DistanceMeasure
+   * @param input            the directory pathname for input points
+   * @param clustersIn       the directory pathname for input clusters
+   * @param output           the directory pathname for output points
+   * @param measureClass     the classname of the DistanceMeasure
    * @param convergenceDelta the convergence delta value
    */
   private static void runClustering(String input, String clustersIn,
-      String output, String measureClass, String convergenceDelta, Class<? extends Vector> vectorClass) {
+                                    String output, String measureClass, String convergenceDelta, Class<? extends Vector> vectorClass) {
     JobConf conf = new JobConf(KMeansDriver.class);
     conf.setInputFormat(SequenceFileInputFormat.class);
     conf.setOutputFormat(SequenceFileOutputFormat.class);
@@ -295,7 +306,7 @@
    * @return true if all Clusters are converged
    * @throws IOException if there was an IO error
    */
- private static boolean isConverged(String filePath, JobConf conf, FileSystem fs)
+  private static boolean isConverged(String filePath, JobConf conf, FileSystem fs)
           throws IOException {
     Path outPart = new Path(filePath);
     SequenceFile.Reader reader = new SequenceFile.Reader(fs, outPart, conf);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Sat Jun 27 01:58:21 2009
@@ -58,7 +58,7 @@
         count++;
       }
     }
-    log.info("Wrote " + count + " vectors");
+    log.info("Wrote " + count + " vectors to " + outFile);
     reader.close();
     writer.close();
     }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java Sat Jun 27 01:58:21 2009
@@ -402,8 +402,9 @@
       }
       writer.close();
       // now run the Job
+      KMeansDriver.overwriteOutput("output");
       KMeansDriver.runJob("testdata/points", "testdata/clusters", "output",
-          EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class, true);
+          EuclideanDistanceMeasure.class.getName(), 0.001, 10, k + 1, SparseVector.class);
       // now compare the expected clusters with actual
       File outDir = new File("output/points");
       assertTrue("output dir exists?", outDir.exists());
@@ -461,7 +462,7 @@
 
     // now run the KMeans job
     KMeansDriver.runJob("testdata/points", "testdata/canopies", "output",
-        EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class, true);
+        EuclideanDistanceMeasure.class.getName(), 0.001, 10, 1, SparseVector.class);
 
     // now compare the expected clusters with actual
     File outDir = new File("output/points");

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java?rev=788915&r1=788914&r2=788915&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/kmeans/Job.java Sat Jun 27 01:58:21 2009
@@ -99,6 +99,6 @@
     System.out.println("Running KMeans");
     KMeansDriver.runJob(directoryContainingConvertedInput, output
         + CanopyClusteringJob.DEFAULT_CANOPIES_OUTPUT_DIRECTORY, output,
-        measureClass, convergenceDelta, maxIterations, 1, vectorClass, false);
+        measureClass, convergenceDelta, maxIterations, 1, vectorClass);
   }
 }