Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/13 22:08:12 UTC

svn commit: r909914 [4/5] - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/clustering/ main/java/org/apache/mahout/clustering/canopy/ main/java/org/apache/mahout/clustering/dirichlet/ main/java/org/apache/mahout/clustering/dirichlet/mode...

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Sat Feb 13 21:07:53 2010
@@ -16,6 +16,8 @@
  */
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+
 import org.apache.commons.cli2.CommandLine;
 import org.apache.commons.cli2.Group;
 import org.apache.commons.cli2.Option;
@@ -43,85 +45,87 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-
-public class KMeansDriver {
-
+public final class KMeansDriver {
+  
   /** The name of the directory used to output final results. */
   public static final String DEFAULT_OUTPUT_DIRECTORY = "/points";
-
+  
   private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
-
-  private KMeansDriver() {
-  }
-
-  /** @param args Expects 7 args and they all correspond to the order of the params in {@link #runJob} */
+  
+  private KMeansDriver() { }
+  
+  /**
+   * @param args
+   *          Expects 7 args and they all correspond to the order of the params in {@link #runJob}
+   */
   public static void main(String[] args) throws Exception {
-
+    
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
     ArgumentBuilder abuilder = new ArgumentBuilder();
     GroupBuilder gbuilder = new GroupBuilder();
-
+    
     Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
-        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
-
+      abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+    
     Option clustersOpt = obuilder
         .withLongName("clusters")
         .withRequired(true)
         .withArgument(abuilder.withName("clusters").withMinimum(1).withMaximum(1).create())
         .withDescription(
-            "The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
-                + "If k is also specified, then a random set of vectors will be selected and written out to this path first")
+          "The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
+              + "If k is also specified, then a random set of vectors will be selected and written out to this path first")
         .withShortName("c").create();
-
+    
     Option kOpt = obuilder
         .withLongName("k")
         .withRequired(false)
         .withArgument(abuilder.withName("k").withMinimum(1).withMaximum(1).create())
         .withDescription(
-            "The k in k-Means.  If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.")
+          "The k in k-Means.  If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.")
         .withShortName("k").create();
-
+    
     Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
-        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Path to put the output in").withShortName("o").create();
-
+      abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The Path to put the output in").withShortName("o").create();
+    
     Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
-        "If set, overwrite the output directory").withShortName("w").create();
-
+      "If set, overwrite the output directory").withShortName("w").create();
+    
     Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
-        abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Distance Measure to use.  Default is SquaredEuclidean").withShortName("m").create();
-
+      abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The Distance Measure to use.  Default is SquaredEuclidean").withShortName("m").create();
+    
     Option convergenceDeltaOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(
-        abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The threshold below which the clusters are considered to be converged.  Default is 0.5").withShortName("d")
-        .create();
-
+      abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The threshold below which the clusters are considered to be converged.  Default is 0.5")
+        .withShortName("d").create();
+    
     Option maxIterationsOpt = obuilder.withLongName("max").withRequired(false).withArgument(
-        abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The maximum number of iterations to perform.  Default is 20").withShortName("x").create();
-
+      abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The maximum number of iterations to perform.  Default is 20").withShortName("x").create();
+    
     Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
-        abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Vector implementation class name.  Default is RandomAccessSparseVector.class").withShortName("v").create();
-
+      abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The Vector implementation class name.  Default is RandomAccessSparseVector.class").withShortName("v")
+        .create();
+    
     Option numReduceTasksOpt = obuilder.withLongName("numReduce").withRequired(false).withArgument(
-        abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The number of reduce tasks").withShortName("r").create();
-
-    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
-
-    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt)
-        .withOption(measureClassOpt).withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(
-            numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt).withOption(overwriteOutput).withOption(
-            helpOpt).create();
+      abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The number of reduce tasks").withShortName("r").create();
+    
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+    
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(
+      outputOpt).withOption(measureClassOpt).withOption(convergenceDeltaOpt).withOption(maxIterationsOpt)
+        .withOption(numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt)
+        .withOption(overwriteOutput).withOption(helpOpt).create();
     try {
       Parser parser = new Parser();
       parser.setGroup(group);
       CommandLine cmdLine = parser.parse(args);
-
+      
       if (cmdLine.hasOption(helpOpt)) {
         CommandLineUtil.printHelp(group);
         return;
@@ -137,10 +141,11 @@
       if (cmdLine.hasOption(convergenceDeltaOpt)) {
         convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
       }
-
-      //Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ? RandomAccessSparseVector.class
-      //    : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
-
+      
+      // Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
+      // RandomAccessSparseVector.class
+      // : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
+      
       int maxIterations = 20;
       if (cmdLine.hasOption(maxIterationsOpt)) {
         maxIterations = Integer.parseInt(cmdLine.getValue(maxIterationsOpt).toString());
@@ -153,72 +158,102 @@
         HadoopUtil.overwriteOutput(output);
       }
       if (cmdLine.hasOption(kOpt)) {
-        clusters = RandomSeedGenerator
-            .buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
+        clusters = RandomSeedGenerator.buildRandom(input, clusters,
+          Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
       }
-      runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations, numReduceTasks);
+      KMeansDriver.runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations,
+        numReduceTasks);
     } catch (OptionException e) {
-      log.error("Exception", e);
+      KMeansDriver.log.error("Exception", e);
       CommandLineUtil.printHelp(group);
     }
   }
-
+  
   /**
    * Run the job using supplied arguments
    * 
-   * @param input the directory pathname for input points
-   * @param clustersIn the directory pathname for initial & computed clusters
-   * @param output the directory pathname for output points
-   * @param measureClass the classname of the DistanceMeasure
-   * @param convergenceDelta the convergence delta value
-   * @param maxIterations the maximum number of iterations
-   * @param numReduceTasks the number of reducers
+   * @param input
+   *          the directory pathname for input points
+   * @param clustersIn
+   *          the directory pathname for initial & computed clusters
+   * @param output
+   *          the directory pathname for output points
+   * @param measureClass
+   *          the classname of the DistanceMeasure
+   * @param convergenceDelta
+   *          the convergence delta value
+   * @param maxIterations
+   *          the maximum number of iterations
+   * @param numReduceTasks
+   *          the number of reducers
    */
-  public static void runJob(String input, String clustersIn, String output, String measureClass,
-      double convergenceDelta, int maxIterations, int numReduceTasks) {
+  public static void runJob(String input,
+                            String clustersIn,
+                            String output,
+                            String measureClass,
+                            double convergenceDelta,
+                            int maxIterations,
+                            int numReduceTasks) {
     // iterate until the clusters converge
     String delta = Double.toString(convergenceDelta);
-    if (log.isInfoEnabled()) {
-      log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] {input, clustersIn, output, measureClass});
-      log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}",
-               new Object[] {convergenceDelta, maxIterations, numReduceTasks, VectorWritable.class.getName()});
+    if (KMeansDriver.log.isInfoEnabled()) {
+      KMeansDriver.log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] {input,
+                                                                                            clustersIn,
+                                                                                            output,
+                                                                                            measureClass});
+      KMeansDriver.log.info("convergence: {} max Iterations: {} num Reduce Tasks: {} Input Vectors: {}",
+        new Object[] {convergenceDelta, maxIterations, numReduceTasks, VectorWritable.class.getName()});
     }
     boolean converged = false;
     int iteration = 0;
-    while (!converged && iteration < maxIterations) {
-      log.info("Iteration {}", iteration);
+    while (!converged && (iteration < maxIterations)) {
+      KMeansDriver.log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration
       String clustersOut = output + "/clusters-" + iteration;
-      converged = runIteration(input, clustersIn, clustersOut, measureClass, delta, numReduceTasks, iteration);
+      converged = KMeansDriver.runIteration(input, clustersIn, clustersOut, measureClass, delta,
+        numReduceTasks, iteration);
       // now point the input to the old output directory
       clustersIn = output + "/clusters-" + iteration;
       iteration++;
     }
     // now actually cluster the points
-    log.info("Clustering ");
-    runClustering(input, clustersIn, output + DEFAULT_OUTPUT_DIRECTORY, measureClass, delta);
+    KMeansDriver.log.info("Clustering ");
+    KMeansDriver.runClustering(input, clustersIn, output + KMeansDriver.DEFAULT_OUTPUT_DIRECTORY,
+      measureClass, delta);
   }
-
+  
   /**
    * Run the job using supplied arguments
    * 
-   * @param input the directory pathname for input points
-   * @param clustersIn the directory pathname for input clusters
-   * @param clustersOut the directory pathname for output clusters
-   * @param measureClass the classname of the DistanceMeasure
-   * @param convergenceDelta the convergence delta value
-   * @param numReduceTasks the number of reducer tasks
-   * @param iteration The iteration number
+   * @param input
+   *          the directory pathname for input points
+   * @param clustersIn
+   *          the directory pathname for input clusters
+   * @param clustersOut
+   *          the directory pathname for output clusters
+   * @param measureClass
+   *          the classname of the DistanceMeasure
+   * @param convergenceDelta
+   *          the convergence delta value
+   * @param numReduceTasks
+   *          the number of reducer tasks
+   * @param iteration
+   *          The iteration number
    * @return true if the iteration successfully runs
    */
-  private static boolean runIteration(String input, String clustersIn, String clustersOut, String measureClass,
-      String convergenceDelta, int numReduceTasks, int iteration) {
+  private static boolean runIteration(String input,
+                                      String clustersIn,
+                                      String clustersOut,
+                                      String measureClass,
+                                      String convergenceDelta,
+                                      int numReduceTasks,
+                                      int iteration) {
     JobConf conf = new JobConf(KMeansDriver.class);
     conf.setMapOutputKeyClass(Text.class);
     conf.setMapOutputValueClass(KMeansInfo.class);
     conf.setOutputKeyClass(Text.class);
     conf.setOutputValueClass(Cluster.class);
-
+    
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(clustersOut);
     FileOutputFormat.setOutputPath(conf, outPath);
@@ -236,65 +271,80 @@
     try {
       JobClient.runJob(conf);
       FileSystem fs = FileSystem.get(outPath.toUri(), conf);
-      return isConverged(clustersOut, conf, fs);
+      return KMeansDriver.isConverged(clustersOut, conf, fs);
     } catch (IOException e) {
-      log.warn(e.toString(), e);
+      KMeansDriver.log.warn(e.toString(), e);
       return true;
     }
   }
-
+  
   /**
    * Run the job using supplied arguments
    * 
-   * @param input the directory pathname for input points
-   * @param clustersIn the directory pathname for input clusters
-   * @param output the directory pathname for output points
-   * @param measureClass the classname of the DistanceMeasure
-   * @param convergenceDelta the convergence delta value
+   * @param input
+   *          the directory pathname for input points
+   * @param clustersIn
+   *          the directory pathname for input clusters
+   * @param output
+   *          the directory pathname for output points
+   * @param measureClass
+   *          the classname of the DistanceMeasure
+   * @param convergenceDelta
+   *          the convergence delta value
    */
-  private static void runClustering(String input, String clustersIn, String output, String measureClass,
-      String convergenceDelta) {
-    if (log.isInfoEnabled()) {
-      log.info("Running Clustering");
-      log.info("Input: {} Clusters In: {} Out: {} Distance: {}",
-               new Object[] {input, clustersIn, output, measureClass});
-      log.info("convergence: {} Input Vectors: {}", convergenceDelta, VectorWritable.class.getName());
+  private static void runClustering(String input,
+                                    String clustersIn,
+                                    String output,
+                                    String measureClass,
+                                    String convergenceDelta) {
+    if (KMeansDriver.log.isInfoEnabled()) {
+      KMeansDriver.log.info("Running Clustering");
+      KMeansDriver.log.info("Input: {} Clusters In: {} Out: {} Distance: {}", new Object[] {input,
+                                                                                            clustersIn,
+                                                                                            output,
+                                                                                            measureClass});
+      KMeansDriver.log.info("convergence: {} Input Vectors: {}", convergenceDelta, VectorWritable.class
+          .getName());
     }
     JobConf conf = new JobConf(KMeansDriver.class);
     conf.setInputFormat(SequenceFileInputFormat.class);
     conf.setOutputFormat(SequenceFileOutputFormat.class);
-
+    
     conf.setMapOutputKeyClass(Text.class);
     conf.setMapOutputValueClass(VectorWritable.class);
     conf.setOutputKeyClass(Text.class);
     // the output is the cluster id
     conf.setOutputValueClass(Text.class);
-
+    
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(output);
     FileOutputFormat.setOutputPath(conf, outPath);
-
+    
     conf.setMapperClass(KMeansClusterMapper.class);
     conf.setNumReduceTasks(0);
     conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, clustersIn);
     conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY, measureClass);
     conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
-
+    
     try {
       JobClient.runJob(conf);
     } catch (IOException e) {
-      log.warn(e.toString(), e);
+      KMeansDriver.log.warn(e.toString(), e);
     }
   }
-
+  
   /**
    * Return if all of the Clusters in the parts in the filePath have converged or not
    * 
-   * @param filePath the file path to the single file containing the clusters
-   * @param conf the JobConf
-   * @param fs the FileSystem
+   * @param filePath
+   *          the file path to the single file containing the clusters
+   * @param conf
+   *          the JobConf
+   * @param fs
+   *          the FileSystem
    * @return true if all Clusters are converged
-   * @throws IOException if there was an IO error
+   * @throws IOException
+   *           if there was an IO error
    */
   private static boolean isConverged(String filePath, JobConf conf, FileSystem fs) throws IOException {
     FileStatus[] parts = fs.listStatus(new Path(filePath));
@@ -305,11 +355,11 @@
         Writable key;
         try {
           key = (Writable) reader.getKeyClass().newInstance();
-        } catch (InstantiationException e) {// shouldn't happen
-          log.error("Exception", e);
+        } catch (InstantiationException e) { // shouldn't happen
+          KMeansDriver.log.error("Exception", e);
           throw new IllegalStateException(e);
         } catch (IllegalAccessException e) {
-          log.error("Exception", e);
+          KMeansDriver.log.error("Exception", e);
           throw new IllegalStateException(e);
         }
         Cluster value = new Cluster();
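
[For orientation, a minimal sketch of invoking the reformatted runJob entry point above. The paths, the distance-measure class name and the numeric values are placeholders chosen for illustration, not values taken from this commit. When -k is given on the command line, main() first calls RandomSeedGenerator.buildRandom to populate the clusters directory before delegating to runJob.]

    import org.apache.mahout.clustering.kmeans.KMeansDriver;

    public class KMeansDriverExample {
      public static void main(String[] args) {
        // Iterate k-means until centers move less than 0.5, for at most 20 iterations,
        // using one reducer. Input must be a SequenceFile of Writable, Vector.
        KMeansDriver.runJob("/path/to/vectors",           // input points (placeholder path)
                            "/path/to/initial-clusters",  // initial centroids (placeholder path)
                            "/path/to/kmeans-output",     // output directory (placeholder path)
                            "org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure",
                            0.5,   // convergence delta
                            20,    // max iterations
                            1);    // number of reduce tasks
      }
    }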

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansInfo.java Sat Feb 13 21:07:53 2010
@@ -17,41 +17,40 @@
 
 package org.apache.mahout.clustering.kmeans;
 
-import org.apache.hadoop.io.Writable;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.VectorWritable;
-
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
 
-public class KMeansInfo implements Writable {
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 
+public class KMeansInfo implements Writable {
+  
   private int points;
   private Vector pointTotal;
-
-  public KMeansInfo() {
-  }
-
+  
+  public KMeansInfo() { }
+  
   public KMeansInfo(int points, Vector pointTotal) {
     this.points = points;
     this.pointTotal = pointTotal;
   }
-
+  
   public int getPoints() {
     return points;
   }
-
+  
   public Vector getPointTotal() {
     return pointTotal;
   }
-
+  
   @Override
   public void write(DataOutput out) throws IOException {
     out.writeInt(points);
     VectorWritable.writeVector(out, pointTotal);
   }
-
+  
   @Override
   public void readFields(DataInput in) throws IOException {
     this.points = in.readInt();
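
[KMeansInfo is the map-output value type set by the driver above: a point count plus a running vector total. A small hedged sketch of the serialization round trip implied by write()/readFields(); the vector values are arbitrary.]

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;

    import org.apache.mahout.clustering.kmeans.KMeansInfo;
    import org.apache.mahout.math.DenseVector;

    public class KMeansInfoExample {
      public static void main(String[] args) throws IOException {
        KMeansInfo info = new KMeansInfo(3, new DenseVector(new double[] {1.0, 2.0, 3.0}));

        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        info.write(new DataOutputStream(bytes));   // writes the int count, then the vector

        KMeansInfo copy = new KMeansInfo();
        copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(copy.getPoints() + " points, total = " + copy.getPointTotal());
      }
    }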

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansMapper.java Sat Feb 13 21:07:53 2010
@@ -16,6 +16,10 @@
  */
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapred.JobConf;
@@ -26,23 +30,20 @@
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.math.VectorWritable;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
 public class KMeansMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>, VectorWritable, Text, KMeansInfo> {
-
+    Mapper<WritableComparable<?>,VectorWritable,Text,KMeansInfo> {
+  
   private KMeansClusterer clusterer;
   private final List<Cluster> clusters = new ArrayList<Cluster>();
-
+  
   @Override
-  public void map(WritableComparable<?> key, VectorWritable point,
-      OutputCollector<Text, KMeansInfo> output, Reporter reporter)
-      throws IOException {
-   this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, output);
+  public void map(WritableComparable<?> key,
+                  VectorWritable point,
+                  OutputCollector<Text,KMeansInfo> output,
+                  Reporter reporter) throws IOException {
+    this.clusterer.emitPointToNearestCluster(point.get(), this.clusters, output);
   }
-
+  
   /**
    * Configure the mapper by providing its clusters. Used by unit tests.
    * 
@@ -53,27 +54,26 @@
     this.clusters.clear();
     this.clusters.addAll(clusters);
   }
-
+  
   @Override
   public void configure(JobConf job) {
     super.configure(job);
     try {
       ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-      Class<?> cl = ccl.loadClass(job
-          .get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+      Class<?> cl = ccl.loadClass(job.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
       DistanceMeasure measure = (DistanceMeasure) cl.newInstance();
       measure.configure(job);
-
+      
       this.clusterer = new KMeansClusterer(measure);
-
+      
       String clusterPath = job.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
-      if (clusterPath != null && clusterPath.length() > 0) {
+      if ((clusterPath != null) && (clusterPath.length() > 0)) {
         KMeansUtil.configureWithClusterInfo(clusterPath, clusters);
         if (clusters.isEmpty()) {
           throw new IllegalStateException("Cluster is empty!");
         }
       }
-
+      
     } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
     } catch (IllegalAccessException e) {
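
[The mapper's configure() above pulls everything it needs from the JobConf: the distance-measure class and the path of the current clusters; the reducer additionally reads the convergence threshold. A hedged sketch of those settings using the keys that appear in this commit; the concrete measure class name and the paths are assumptions for illustration only.]

    import org.apache.hadoop.mapred.JobConf;
    import org.apache.mahout.clustering.kmeans.KMeansConfigKeys;
    import org.apache.mahout.clustering.kmeans.KMeansMapper;

    public class KMeansTaskConfigExample {
      public static void main(String[] args) {
        JobConf conf = new JobConf(KMeansMapper.class);
        // Read by KMeansMapper.configure() to instantiate the DistanceMeasure
        conf.set(KMeansConfigKeys.DISTANCE_MEASURE_KEY,
                 "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
        // Directory of part-* files holding the current Clusters, loaded via KMeansUtil
        conf.set(KMeansConfigKeys.CLUSTER_PATH_KEY, "/path/to/kmeans-output/clusters-0");
        // Read by KMeansReducer.configure() to decide per-cluster convergence
        conf.set(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY, "0.5");
      }
    }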

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansReducer.java Sat Feb 13 21:07:53 2010
@@ -16,14 +16,6 @@
  */
 package org.apache.mahout.clustering.kmeans;
 
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapred.JobConf;
-import org.apache.hadoop.mapred.MapReduceBase;
-import org.apache.hadoop.mapred.OutputCollector;
-import org.apache.hadoop.mapred.Reducer;
-import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.common.distance.DistanceMeasure;
-
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -31,18 +23,27 @@
 import java.util.List;
 import java.util.Map;
 
-public class KMeansReducer extends MapReduceBase implements
-    Reducer<Text, KMeansInfo, Text, Cluster> {
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.common.distance.DistanceMeasure;
 
-  private Map<String, Cluster> clusterMap;
+public class KMeansReducer extends MapReduceBase implements Reducer<Text,KMeansInfo,Text,Cluster> {
+  
+  private Map<String,Cluster> clusterMap;
   private double convergenceDelta;
   private DistanceMeasure measure;
-
+  
   @Override
-  public void reduce(Text key, Iterator<KMeansInfo> values,
-                     OutputCollector<Text, Cluster> output, Reporter reporter) throws IOException {
+  public void reduce(Text key,
+                     Iterator<KMeansInfo> values,
+                     OutputCollector<Text,Cluster> output,
+                     Reporter reporter) throws IOException {
     Cluster cluster = clusterMap.get(key.toString());
-
+    
     while (values.hasNext()) {
       KMeansInfo delta = values.next();
       cluster.addPoints(delta.getPoints(), delta.getPointTotal());
@@ -54,23 +55,21 @@
     }
     output.collect(new Text(cluster.getIdentifier()), cluster);
   }
-
+  
   @Override
   public void configure(JobConf job) {
-
+    
     super.configure(job);
     try {
       ClassLoader ccl = Thread.currentThread().getContextClassLoader();
-      Class<?> cl = ccl.loadClass(job
-          .get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
+      Class<?> cl = ccl.loadClass(job.get(KMeansConfigKeys.DISTANCE_MEASURE_KEY));
       this.measure = (DistanceMeasure) cl.newInstance();
       this.measure.configure(job);
-
-      this.convergenceDelta = Double.parseDouble(job
-          .get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
-
-      this.clusterMap = new HashMap<String, Cluster>();
-
+      
+      this.convergenceDelta = Double.parseDouble(job.get(KMeansConfigKeys.CLUSTER_CONVERGENCE_KEY));
+      
+      this.clusterMap = new HashMap<String,Cluster>();
+      
       String path = job.get(KMeansConfigKeys.CLUSTER_PATH_KEY);
       if (path.length() > 0) {
         List<Cluster> clusters = new ArrayList<Cluster>();
@@ -88,18 +87,18 @@
       throw new IllegalStateException(e);
     }
   }
-
+  
   private void setClusterMap(List<Cluster> clusters) {
-    clusterMap = new HashMap<String, Cluster>();
+    clusterMap = new HashMap<String,Cluster>();
     for (Cluster cluster : clusters) {
       clusterMap.put(cluster.getIdentifier(), cluster);
     }
     clusters.clear();
   }
-
+  
   public void config(List<Cluster> clusters) {
     setClusterMap(clusters);
-
+    
   }
-
+  
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,10 @@
 
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.FileUtil;
@@ -30,26 +34,20 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
 final class KMeansUtil {
-
+  
   private static final Logger log = LoggerFactory.getLogger(KMeansUtil.class);
-
-  private KMeansUtil() {
-  }
-
+  
+  private KMeansUtil() { }
+  
   /** Configure the mapper with the cluster info */
-  public static void configureWithClusterInfo(String clusterPathStr,
-                                              List<Cluster> clusters) {
-
+  public static void configureWithClusterInfo(String clusterPathStr, List<Cluster> clusters) {
+    
     // Get the path location where the cluster Info is stored
     JobConf job = new JobConf(KMeansUtil.class);
     Path clusterPath = new Path(clusterPathStr + "/*");
     List<Path> result = new ArrayList<Path>();
-
+    
     // filter out the files
     PathFilter clusterFileFilter = new PathFilter() {
       @Override
@@ -57,17 +55,17 @@
         return path.getName().startsWith("part");
       }
     };
-
+    
     try {
       // get all filtered file names in result list
       FileSystem fs = clusterPath.getFileSystem(job);
-      FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(
-          clusterPath, clusterFileFilter)), clusterFileFilter);
-
+      FileStatus[] matches = fs.listStatus(
+        FileUtil.stat2Paths(fs.globStatus(clusterPath, clusterFileFilter)), clusterFileFilter);
+      
       for (FileStatus match : matches) {
         result.add(fs.makeQualified(match.getPath()));
       }
-
+      
       // iterate thru the result path list
       for (Path path : result) {
         SequenceFile.Reader reader = null;
@@ -77,11 +75,11 @@
           Writable key;
           try {
             key = (Writable) reader.getKeyClass().newInstance();
-          } catch (InstantiationException e) {//Should not be possible
-            log.error("Exception", e);
+          } catch (InstantiationException e) { // Should not be possible
+            KMeansUtil.log.error("Exception", e);
             throw new IllegalStateException(e);
           } catch (IllegalAccessException e) {
-            log.error("Exception", e);
+            KMeansUtil.log.error("Exception", e);
             throw new IllegalStateException(e);
           }
           if (valueClass.equals(Cluster.class)) {
@@ -104,11 +102,11 @@
           IOUtils.quietClose(reader);
         }
       }
-
+      
     } catch (IOException e) {
-      log.info("Exception occurred in loading clusters:", e);
+      KMeansUtil.log.info("Exception occurred in loading clusters:", e);
       throw new IllegalStateException(e);
     }
   }
-
+  
 }
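
[KMeansUtil.configureWithClusterInfo fills a List<Cluster> from the part-* files under a cluster directory. Because the class is declared package-private (final class KMeansUtil), a caller has to sit in the same package; the sketch below assumes that, and the path is a placeholder.]

    package org.apache.mahout.clustering.kmeans;

    import java.util.ArrayList;
    import java.util.List;

    public class KMeansUtilExample {
      public static void main(String[] args) {
        List<Cluster> clusters = new ArrayList<Cluster>();
        // Reads every file named part-* under the directory and adds one Cluster per record
        KMeansUtil.configureWithClusterInfo("/path/to/kmeans-output/clusters-0", clusters);
        System.out.println("Loaded " + clusters.size() + " clusters");
      }
    }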

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java Sat Feb 13 21:07:53 2010
@@ -17,6 +17,11 @@
 
 package org.apache.mahout.clustering.kmeans;
 
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
@@ -29,28 +34,23 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-
-
 /**
- * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, randomly select k vectors and write them
- * to the output file as a {@link org.apache.mahout.clustering.kmeans.Cluster} representing the initial centroid to use.
+ * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, randomly select k vectors and
+ * write them to the output file as a {@link org.apache.mahout.clustering.kmeans.Cluster} representing the
+ * initial centroid to use.
  * <p/>
  */
 public final class RandomSeedGenerator {
-
+  
   private static final Logger log = LoggerFactory.getLogger(RandomSeedGenerator.class);
-
+  
   public static final String K = "k";
-
-  private RandomSeedGenerator() {
-  }
-
-  public static Path buildRandom(String input, String output,
-                                 int k) throws IOException, IllegalAccessException, InstantiationException {
+  
+  private RandomSeedGenerator() { }
+  
+  public static Path buildRandom(String input, String output, int k) throws IOException,
+                                                                    IllegalAccessException,
+                                                                    InstantiationException {
     // delete the output directory
     JobConf conf = new JobConf(RandomSeedGenerator.class);
     Path outPath = new Path(output);
@@ -61,7 +61,7 @@
     fs.mkdirs(outPath);
     Path outFile = new Path(outPath, "part-randomSeed");
     if (fs.exists(outFile)) {
-      log.warn("Deleting {}", outFile);
+      RandomSeedGenerator.log.warn("Deleting {}", outFile);
       fs.delete(outFile, false);
     }
     boolean newFile = fs.createNewFile(outFile);
@@ -83,7 +83,9 @@
       int nextClusterId = 0;
       
       for (FileStatus fileStatus : inputFiles) {
-        if(fileStatus.isDir() == true) continue; // select only the top level files
+        if (fileStatus.isDir() == true) {
+          continue; // select only the top level files
+        }
         SequenceFile.Reader reader = new SequenceFile.Reader(fs, fileStatus.getPath(), conf);
         Writable key = (Writable) reader.getKeyClass().newInstance();
         VectorWritable value = (VectorWritable) reader.getValueClass().newInstance();
@@ -109,10 +111,10 @@
       for (int i = 0; i < k; i++) {
         writer.append(chosenTexts.get(i), chosenClusters.get(i));
       }
-      log.info("Wrote {} vectors to {}", k, outFile);
+      RandomSeedGenerator.log.info("Wrote {} vectors to {}", k, outFile);
       writer.close();
     }
-
+    
     return outFile;
   }
 }
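
[RandomSeedGenerator.buildRandom is what the k-means driver calls when -k is supplied: it samples k vectors from the input SequenceFile and writes them to <output>/part-randomSeed as initial Clusters. A minimal sketch with placeholder paths and an arbitrary k.]

    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;

    public class RandomSeedExample {
      public static void main(String[] args) throws Exception {
        // Pick 10 vectors at random from the input as initial centroids
        Path seedFile = RandomSeedGenerator.buildRandom("/path/to/vectors",
                                                        "/path/to/initial-clusters",
                                                        10);
        System.out.println("Initial clusters written to " + seedFile); // .../part-randomSeed
      }
    }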

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java Sat Feb 13 21:07:53 2010
@@ -26,50 +26,49 @@
 import org.apache.hadoop.io.WritableComparator;
 
 /**
-* Saves two ints, x and y.
-*/
+ * Saves two ints, x and y.
+ */
 public class IntPairWritable implements WritableComparable<IntPairWritable> {
-
+  
   private int x;
   private int y;
-
+  
   /** For serialization purposes only */
-  public IntPairWritable() {
-  }
-
+  public IntPairWritable() { }
+  
   public IntPairWritable(int x, int y) {
     this.x = x;
     this.y = y;
   }
-
+  
   public void setX(int x) {
     this.x = x;
   }
-
+  
   public int getX() {
     return x;
   }
-
+  
   public void setY(int y) {
     this.y = y;
   }
-
+  
   public int getY() {
     return y;
   }
-
+  
   @Override
   public void write(DataOutput dataOutput) throws IOException {
     dataOutput.writeInt(x);
     dataOutput.writeInt(y);
   }
-
+  
   @Override
   public void readFields(DataInput dataInput) throws IOException {
     x = dataInput.readInt();
     y = dataInput.readInt();
   }
-
+  
   @Override
   public int compareTo(IntPairWritable that) {
     if (this.x < that.getX()) {
@@ -80,51 +79,52 @@
       return this.y < that.getY() ? -1 : this.y > that.getY() ? 1 : 0;
     }
   }
-
+  
+  @Override
   public boolean equals(Object o) {
-    if (this == o) { 
+    if (this == o) {
       return true;
     } else if (!(o instanceof IntPairWritable)) {
       return false;
     }
-
+    
     IntPairWritable that = (IntPairWritable) o;
-
-    return that.getX() == this.x && this.y == that.getY();
+    
+    return (that.getX() == this.x) && (this.y == that.getY());
   }
-
+  
   @Override
   public int hashCode() {
     return 43 * x + y;
   }
-
+  
   @Override
   public String toString() {
     return "(" + x + ", " + y + ')';
   }
-
+  
   static {
     WritableComparator.define(IntPairWritable.class, new Comparator());
   }
-
+  
   public static class Comparator extends WritableComparator implements Serializable {
     public Comparator() {
       super(IntPairWritable.class);
     }
-
+    
     @Override
     public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
       if (l1 != 8) {
         throw new IllegalArgumentException();
       }
-      int int11 = readInt(b1, s1);
-      int int21 = readInt(b2, s2);
+      int int11 = WritableComparator.readInt(b1, s1);
+      int int21 = WritableComparator.readInt(b2, s2);
       if (int11 != int21) {
         return int11 - int21;
       }
-
-      int int12 = readInt(b1, s1 + 4);
-      int int22 = readInt(b2, s2 + 4);
+      
+      int int12 = WritableComparator.readInt(b1, s1 + 4);
+      int int22 = WritableComparator.readInt(b2, s2 + 4);
       return int12 - int22;
     }
   }
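
[IntPairWritable orders pairs by x first and then by y, both in compareTo() and in the raw-byte Comparator registered in the static block. A tiny sketch of that ordering and of equals()/hashCode(); the values are arbitrary.]

    import org.apache.mahout.clustering.lda.IntPairWritable;

    public class IntPairWritableExample {
      public static void main(String[] args) {
        IntPairWritable a = new IntPairWritable(1, 5);
        IntPairWritable b = new IntPairWritable(1, 7);
        IntPairWritable c = new IntPairWritable(2, 0);

        System.out.println(a.compareTo(b)); // negative: same x, smaller y
        System.out.println(c.compareTo(a)); // positive: larger x wins regardless of y
        System.out.println(a.equals(new IntPairWritable(1, 5))); // true
        System.out.println(a.hashCode());   // 43 * x + y = 48
      }
    }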

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Sat Feb 13 21:07:53 2010
@@ -39,117 +39,114 @@
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 /**
- * Estimates an LDA model from a corpus of documents,
- * which are SparseVectors of word counts. At each
- * phase, it outputs a matrix of log probabilities of
- * each topic.
+ * Estimates an LDA model from a corpus of documents, which are SparseVectors of word counts. At each phase,
+ * it outputs a matrix of log probabilities of each topic.
  */
 public final class LDADriver {
-
+  
   static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
-
+  
   static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
   static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
-
+  
   static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
-
+  
   static final int LOG_LIKELIHOOD_KEY = -2;
   static final int TOPIC_SUM_KEY = -1;
-
+  
   static final double OVERALL_CONVERGENCE = 1.0E-5;
-
+  
   private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
-
-  private LDADriver() {
-  }
-
-  public static void main(String[] args) throws ClassNotFoundException,
-          IOException, InterruptedException {
-
+  
+  private LDADriver() { }
+  
+  public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
+    
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
     ArgumentBuilder abuilder = new ArgumentBuilder();
     GroupBuilder gbuilder = new GroupBuilder();
-
+    
     Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
-            abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
-            "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
-
+      abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+    
     Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
-            abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
-            "The Output Working Directory").withShortName("o").create();
-
+      abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The Output Working Directory").withShortName("o").create();
+    
     Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
-            "If set, overwrite the output directory").withShortName("w").create();
-
+      "If set, overwrite the output directory").withShortName("w").create();
+    
     Option topicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(
-            abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
-            "The number of topics").withShortName("k").create();
-
+      abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The number of topics").withShortName("k").create();
+    
     Option wordsOpt = obuilder.withLongName("numWords").withRequired(true).withArgument(
-            abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
-            "The total number of words in the corpus").withShortName("v").create();
-
-    Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(abuilder
-            .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
-            "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
-
+      abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+      "The total number of words in the corpus").withShortName("v").create();
+    
+    Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(
+      abuilder.withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create())
+        .withDescription("Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
+    
     Option maxIterOpt = obuilder.withLongName("maxIter").withRequired(false).withArgument(
-            abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
-            "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
-
+      abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
+      "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
+    
     Option numReducOpt = obuilder.withLongName("numReducers").withRequired(false).withArgument(
-            abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
-            "Max iterations to run (or until convergence). Default 10").create();
-
-    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
-
+      abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create())
+        .withDescription("Max iterations to run (or until convergence). Default 10").create();
+    
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+    
     Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
-            topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
-            numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
+      topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(numReducOpt)
+        .withOption(overwriteOutput).withOption(helpOpt).create();
     try {
       Parser parser = new Parser();
       parser.setGroup(group);
       CommandLine cmdLine = parser.parse(args);
-
+      
       if (cmdLine.hasOption(helpOpt)) {
         CommandLineUtil.printHelp(group);
         return;
       }
       String input = cmdLine.getValue(inputOpt).toString();
       String output = cmdLine.getValue(outputOpt).toString();
-
+      
       int maxIterations = -1;
       if (cmdLine.hasOption(maxIterOpt)) {
         maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
       }
-
+      
       int numReduceTasks = 2;
       if (cmdLine.hasOption(numReducOpt)) {
         numReduceTasks = Integer.parseInt(cmdLine.getValue(numReducOpt).toString());
       }
-
+      
       int numTopics = 20;
       if (cmdLine.hasOption(topicsOpt)) {
         numTopics = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
       }
-
+      
       int numWords = 20;
       if (cmdLine.hasOption(wordsOpt)) {
         numWords = Integer.parseInt(cmdLine.getValue(wordsOpt).toString());
       }
-
+      
       if (cmdLine.hasOption(overwriteOutput)) {
         HadoopUtil.overwriteOutput(output);
       }
-
+      
       double topicSmoothing = -1.0;
       if (cmdLine.hasOption(topicSmOpt)) {
         topicSmoothing = Double.parseDouble(cmdLine.getValue(maxIterOpt).toString());
@@ -157,72 +154,81 @@
       if (topicSmoothing < 1) {
         topicSmoothing = 50.0 / numTopics;
       }
-
-      runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations,
-              numReduceTasks);
-
+      
+      LDADriver.runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations, numReduceTasks);
+      
     } catch (OptionException e) {
-      log.error("Exception", e);
+      LDADriver.log.error("Exception", e);
       CommandLineUtil.printHelp(group);
     }
   }
-
+  
   /**
    * Run the job using supplied arguments
-   *
-   * @param input          the directory pathname for input points
-   * @param output         the directory pathname for output points
-   * @param numTopics      the number of topics
-   * @param numWords       the number of words
-   * @param topicSmoothing pseudocounts for each topic, typically small &lt; .5
-   * @param maxIterations  the maximum number of iterations
-   * @param numReducers    the number of Reducers desired
+   * 
+   * @param input
+   *          the directory pathname for input points
+   * @param output
+   *          the directory pathname for output points
+   * @param numTopics
+   *          the number of topics
+   * @param numWords
+   *          the number of words
+   * @param topicSmoothing
+   *          pseudocounts for each topic, typically small &lt; .5
+   * @param maxIterations
+   *          the maximum number of iterations
+   * @param numReducers
+   *          the number of Reducers desired
    * @throws IOException
    */
-  public static void runJob(String input, String output, int numTopics,
-                            int numWords, double topicSmoothing, int maxIterations, int numReducers)
-          throws IOException, InterruptedException, ClassNotFoundException {
-
+  public static void runJob(String input,
+                            String output,
+                            int numTopics,
+                            int numWords,
+                            double topicSmoothing,
+                            int maxIterations,
+                            int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+    
     String stateIn = output + "/state-0";
-    writeInitialState(stateIn, numTopics, numWords);
+    LDADriver.writeInitialState(stateIn, numTopics, numWords);
     double oldLL = Double.NEGATIVE_INFINITY;
     boolean converged = false;
-
-    for (int iteration = 0; (maxIterations < 1 || iteration < maxIterations) && !converged; iteration++) {
-      log.info("Iteration {}", iteration);
+    
+    for (int iteration = 0; ((maxIterations < 1) || (iteration < maxIterations)) && !converged; iteration++) {
+      LDADriver.log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration
       String stateOut = output + "/state-" + (iteration + 1);
-      double ll = runIteration(input, stateIn, stateOut, numTopics,
-              numWords, topicSmoothing, numReducers);
+      double ll = LDADriver.runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing,
+        numReducers);
       double relChange = (oldLL - ll) / oldLL;
-
+      
       // now point the input to the old output directory
-      log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
-      log.info("(Old LL: {})", oldLL);
-      log.info("(Rel Change: {})", relChange);
-
-      converged = iteration > 2 && relChange < OVERALL_CONVERGENCE;
+      LDADriver.log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
+      LDADriver.log.info("(Old LL: {})", oldLL);
+      LDADriver.log.info("(Rel Change: {})", relChange);
+      
+      converged = (iteration > 2) && (relChange < LDADriver.OVERALL_CONVERGENCE);
       stateIn = stateOut;
       oldLL = ll;
     }
   }
-
-  private static void writeInitialState(String statePath,
-                                        int numTopics, int numWords) throws IOException {
+  
+  private static void writeInitialState(String statePath, int numTopics, int numWords) throws IOException {
     Path dir = new Path(statePath);
     Configuration job = new Configuration();
     FileSystem fs = dir.getFileSystem(job);
-
+    
     IntPairWritable kw = new IntPairWritable();
     DoubleWritable v = new DoubleWritable();
-
+    
     Random random = RandomUtils.getRandom();
-
+    
     for (int k = 0; k < numTopics; ++k) {
       Path path = new Path(dir, "part-" + k);
-      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
-              IntPairWritable.class, DoubleWritable.class);
-
+      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class,
+          DoubleWritable.class);
+      
       kw.setX(k);
       double total = 0.0; // total number of pseudo counts we made
       for (int w = 0; w < numWords; ++w) {
@@ -233,64 +239,75 @@
         v.set(Math.log(pseudocount));
         writer.append(kw, v);
       }
-
-      kw.setY(TOPIC_SUM_KEY);
+      
+      kw.setY(LDADriver.TOPIC_SUM_KEY);
       v.set(Math.log(total));
       writer.append(kw, v);
-
+      
       writer.close();
     }
   }
-
+  
   private static double findLL(String statePath, Configuration job) throws IOException {
     Path dir = new Path(statePath);
     FileSystem fs = dir.getFileSystem(job);
-
+    
     double ll = 0.0;
-
+    
     IntPairWritable key = new IntPairWritable();
     DoubleWritable value = new DoubleWritable();
     for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
       Path path = status.getPath();
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
       while (reader.next(key, value)) {
-        if (key.getX() == LOG_LIKELIHOOD_KEY) {
+        if (key.getX() == LDADriver.LOG_LIKELIHOOD_KEY) {
           ll = value.get();
           break;
         }
       }
       reader.close();
     }
-
+    
     return ll;
   }
-
+  
   /**
    * Run the job using supplied arguments
-   *
-   * @param input       the directory pathname for input points
-   * @param stateIn     the directory pathname for input state
-   * @param stateOut    the directory pathname for output state
-   * @param numTopics   the number of clusters
-   * @param numReducers the number of Reducers desired
+   * 
+   * @param input
+   *          the directory pathname for input points
+   * @param stateIn
+   *          the directory pathname for input state
+   * @param stateOut
+   *          the directory pathname for output state
+   * @param numTopics
+   *          the number of clusters
+   * @param numReducers
+   *          the number of Reducers desired
    */
-  public static double runIteration(String input, String stateIn,
-                                    String stateOut, int numTopics, int numWords, double topicSmoothing,
-                                    int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+  public static double runIteration(String input,
+                                    String stateIn,
+                                    String stateOut,
+                                    int numTopics,
+                                    int numWords,
+                                    double topicSmoothing,
+                                    int numReducers) throws IOException,
+                                                    InterruptedException,
+                                                    ClassNotFoundException {
     Configuration conf = new Configuration();
-    conf.set(STATE_IN_KEY, stateIn);
-    conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
-    conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
-    conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
-
+    conf.set(LDADriver.STATE_IN_KEY, stateIn);
+    conf.set(LDADriver.NUM_TOPICS_KEY, Integer.toString(numTopics));
+    conf.set(LDADriver.NUM_WORDS_KEY, Integer.toString(numWords));
+    conf.set(LDADriver.TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
+    
     Job job = new Job(conf);
-
+    
     job.setOutputKeyClass(IntPairWritable.class);
     job.setOutputValueClass(DoubleWritable.class);
     FileInputFormat.addInputPaths(job, input);
     Path outPath = new Path(stateOut);
     FileOutputFormat.setOutputPath(job, outPath);
-
+    
     job.setMapperClass(LDAMapper.class);
     job.setReducerClass(LDAReducer.class);
     job.setCombinerClass(LDAReducer.class);
@@ -298,24 +315,24 @@
     job.setOutputFormatClass(SequenceFileOutputFormat.class);
     job.setInputFormatClass(SequenceFileInputFormat.class);
     job.setJarByClass(LDADriver.class);
-
+    
     job.waitForCompletion(true);
-    return findLL(stateOut, conf);
+    return LDADriver.findLL(stateOut, conf);
   }
-
+  
   static LDAState createState(Configuration job) throws IOException {
-    String statePath = job.get(STATE_IN_KEY);
-    int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
-    int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
-    double topicSmoothing = Double.parseDouble(job.get(TOPIC_SMOOTHING_KEY));
-
+    String statePath = job.get(LDADriver.STATE_IN_KEY);
+    int numTopics = Integer.parseInt(job.get(LDADriver.NUM_TOPICS_KEY));
+    int numWords = Integer.parseInt(job.get(LDADriver.NUM_WORDS_KEY));
+    double topicSmoothing = Double.parseDouble(job.get(LDADriver.TOPIC_SMOOTHING_KEY));
+    
     Path dir = new Path(statePath);
     FileSystem fs = dir.getFileSystem(job);
-
+    
     DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
     double[] logTotals = new double[numTopics];
     double ll = 0.0;
-
+    
     IntPairWritable key = new IntPairWritable();
     DoubleWritable value = new DoubleWritable();
     for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
@@ -324,15 +341,15 @@
       while (reader.next(key, value)) {
         int topic = key.getX();
         int word = key.getY();
-        if (word == TOPIC_SUM_KEY) {
+        if (word == LDADriver.TOPIC_SUM_KEY) {
           logTotals[topic] = value.get();
           if (Double.isInfinite(value.get())) {
             throw new IllegalArgumentException();
           }
-        } else if (topic == LOG_LIKELIHOOD_KEY) {
+        } else if (topic == LDADriver.LOG_LIKELIHOOD_KEY) {
           ll = value.get();
         } else {
-          if (!(topic >= 0 && word >= 0)) {
+          if (!((topic >= 0) && (word >= 0))) {
             throw new IllegalArgumentException(topic + " " + word);
           }
           if (pWgT.getQuick(topic, word) != 0.0) {
@@ -346,8 +363,7 @@
       }
       reader.close();
     }
-
-    return new LDAState(numTopics, numWords, topicSmoothing,
-            pWgT, logTotals, ll);
+    
+    return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
   }
 }

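For orientation, here is a minimal sketch of an outer loop around runIteration() as defined above; the state-directory layout, the iteration cap, and the 1.0E-4 convergence threshold are assumptions for illustration only, not part of this commit:

    // Hypothetical driver loop (sketch only); paths and threshold are illustrative.
    static double runSketch(String input, String stateDir, int numTopics, int numWords,
                            double topicSmoothing, int numReducers)
        throws IOException, InterruptedException, ClassNotFoundException {
      double oldLL = 1.0; // positive sentinel, as in the inference code
      for (int i = 0; i < 40; i++) {
        String stateIn = stateDir + "/state-" + i;
        String stateOut = stateDir + "/state-" + (i + 1);
        // one MapReduce pass; returns the log-likelihood read back from stateOut
        double ll = LDADriver.runIteration(input, stateIn, stateOut,
            numTopics, numWords, topicSmoothing, numReducers);
        if (oldLL < 0 && (oldLL - ll) / oldLL < 1.0E-4) {
          return ll; // relative improvement is small enough: stop
        }
        oldLL = ll;
      }
      return oldLL;
    }
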
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Sat Feb 13 21:07:53 2010
@@ -22,157 +22,151 @@
 import java.util.Map;
 
 import org.apache.commons.math.special.Gamma;
-import org.apache.mahout.math.function.BinaryFunction;
 import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.BinaryFunction;
 
 /**
- * Class for performing infererence on a document, which involves
- * computing (an approximation to) p(word|topic) for each word and
- * topic, and a prior distribution p(topic) for each topic.
+ * Class for performing inference on a document, which involves computing (an approximation to)
+ * p(word|topic) for each word and topic, and a prior distribution p(topic) for each topic.
  */
 public class LDAInference {
-
+  
   private static final double E_STEP_CONVERGENCE = 1.0E-6;
   private static final int MAX_ITER = 20;
-
+  
   public LDAInference(LDAState state) {
     this.state = state;
   }
-
+  
   /**
-  * An estimate of the probabilitys for each document.
-  * Gamma(k) is the probability of seeing topic k in
-  * the document, phi(k,w) is the probability of
-  * topic k generating w in this document.
-  */
+   * An estimate of the probabilities for each document. Gamma(k) is the probability of seeing topic k in the
+   * document, phi(k,w) is the probability of topic k generating w in this document.
+   */
   public static class InferredDocument {
-
+    
     private final Vector wordCounts;
     private final Vector gamma; // p(topic)
     private final Matrix mphi; // log p(columnMap(w)|t)
-    private final Map<Integer, Integer> columnMap; // maps words into the matrix's column map
+    private final Map<Integer,Integer> columnMap; // maps words into the matrix's column map
     public final double logLikelihood;
-
+    
     public double phi(int k, int w) {
       return mphi.getQuick(k, columnMap.get(w));
     }
-
-    InferredDocument(Vector wordCounts, Vector gamma,
-                     Map<Integer, Integer> columnMap, Matrix phi,
-                     double ll) {
+    
+    InferredDocument(Vector wordCounts, Vector gamma, Map<Integer,Integer> columnMap, Matrix phi, double ll) {
       this.wordCounts = wordCounts;
       this.gamma = gamma;
       this.mphi = phi;
       this.columnMap = columnMap;
       this.logLikelihood = ll;
     }
-
+    
     public Vector getWordCounts() {
       return wordCounts;
     }
-
+    
     public Vector getGamma() {
       return gamma;
     }
   }
-
+  
   /**
-  * Performs inference on the given document, returning
-  * an InferredDocument.
-  */
+   * Performs inference on the given document, returning an InferredDocument.
+   */
   public InferredDocument infer(Vector wordCounts) {
     double docTotal = wordCounts.zSum();
     int docLength = wordCounts.size();
-
+    
     // initialize variational approximation to p(z|doc)
     Vector gamma = new DenseVector(state.numTopics);
     gamma.assign(state.topicSmoothing + docTotal / state.numTopics);
     Vector nextGamma = new DenseVector(state.numTopics);
-
+    
     DenseMatrix phi = new DenseMatrix(state.numTopics, docLength);
-
+    
     // digamma is expensive, precompute
-    Vector digammaGamma = digamma(gamma);
+    Vector digammaGamma = LDAInference.digamma(gamma);
     // and log normalize:
-    double digammaSumGamma = digamma(gamma.zSum());
+    double digammaSumGamma = LDAInference.digamma(gamma.zSum());
     digammaGamma = digammaGamma.plus(-digammaSumGamma);
-
-    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
-
+    
+    Map<Integer,Integer> columnMap = new HashMap<Integer,Integer>();
+    
     int iteration = 0;
-
+    
     boolean converged = false;
     double oldLL = 1;
-    while (!converged && iteration < MAX_ITER) {
+    while (!converged && (iteration < LDAInference.MAX_ITER)) {
       nextGamma.assign(state.topicSmoothing); // nG := alpha, for all topics
-
+      
       int mapping = 0;
-      for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
-          iter.hasNext();) {
-      Vector.Element e = iter.next();
+      for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
+        Vector.Element e = iter.next();
         int word = e.index();
         Vector phiW = eStepForWord(word, digammaGamma);
         phi.assignColumn(mapping, phiW);
         if (iteration == 0) { // first iteration
           columnMap.put(word, mapping);
         }
-
+        
         for (int k = 0; k < nextGamma.size(); ++k) {
           double g = nextGamma.getQuick(k);
           nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.get(k)));
         }
-
+        
         mapping++;
       }
-
+      
       Vector tempG = gamma;
       gamma = nextGamma;
       nextGamma = tempG;
-
+      
       // digamma is expensive, precompute
-      digammaGamma = digamma(gamma);
+      digammaGamma = LDAInference.digamma(gamma);
       // and log normalize:
-      digammaSumGamma = digamma(gamma.zSum());
+      digammaSumGamma = LDAInference.digamma(gamma.zSum());
       digammaGamma = digammaGamma.plus(-digammaSumGamma);
-
+      
       double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
       assert !Double.isNaN(ll);
-      converged = oldLL < 0 && ((oldLL - ll) / oldLL < E_STEP_CONVERGENCE);
-
+      converged = (oldLL < 0) && ((oldLL - ll) / oldLL < LDAInference.E_STEP_CONVERGENCE);
+      
       oldLL = ll;
       iteration++;
     }
-
+    
     return new InferredDocument(wordCounts, gamma, columnMap, phi, oldLL);
   }
-
+  
   private final LDAState state;
-
-  private double computeLikelihood(Vector wordCounts, Map<Integer, Integer> columnMap,
-      Matrix phi, Vector gamma, Vector digammaGamma) {
+  
+  private double computeLikelihood(Vector wordCounts,
+                                   Map<Integer,Integer> columnMap,
+                                   Matrix phi,
+                                   Vector gamma,
+                                   Vector digammaGamma) {
     double ll = 0.0;
-
+    
     // log normalizer for q(gamma);
     ll += Gamma.logGamma(state.topicSmoothing * state.numTopics);
     ll -= state.numTopics * Gamma.logGamma(state.topicSmoothing);
     assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
-
+    
     // now for the rest of q(gamma);
     for (int k = 0; k < state.numTopics; ++k) {
       ll += (state.topicSmoothing - gamma.get(k)) * digammaGamma.get(k);
       ll += Gamma.logGamma(gamma.get(k));
-
+      
     }
     ll -= Gamma.logGamma(gamma.zSum());
     assert !Double.isNaN(ll);
-
-
+    
     // for each word
-    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
-        iter.hasNext();) {
+    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
       Vector.Element e = iter.next();
       int w = e.index();
       double n = e.get();
@@ -181,19 +175,18 @@
       for (int k = 0; k < state.numTopics; k++) {
         double llPart = 0.0;
         llPart += Math.exp(phi.get(k, mapping))
-          * (digammaGamma.get(k) - phi.get(k, mapping)
-             + state.logProbWordGivenTopic(w, k));
-
+                  * (digammaGamma.get(k) - phi.get(k, mapping) + state.logProbWordGivenTopic(w, k));
+        
         ll += llPart * n;
-
-        assert state.logProbWordGivenTopic(w, k)  < 0;
+        
+        assert state.logProbWordGivenTopic(w, k) < 0;
         assert !Double.isNaN(llPart);
       }
     }
     assert ll <= 0;
     return ll;
   }
-
+  
   /**
    * Compute log q(k|w,doc) for each topic k, for a given word.
    */
@@ -203,7 +196,7 @@
     for (int k = 0; k < state.numTopics; ++k) { // update q(k|w)'s param phi
       phi.set(k, state.logProbWordGivenTopic(word, k) + digammaGamma.get(k));
       phiTotal = LDAUtil.logSum(phiTotal, phi.get(k));
-
+      
       assert !Double.isNaN(phiTotal);
       assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
       assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
@@ -211,57 +204,53 @@
     }
     return phi.plus(-phiTotal); // log normalize
   }
-
-
+  
   private static Vector digamma(Vector v) {
     Vector digammaGamma = new DenseVector(v.size());
     digammaGamma.assign(v, new BinaryFunction() {
       @Override
       public double apply(double unused, double g) {
-        return digamma(g);
+        return LDAInference.digamma(g);
       }
     });
     return digammaGamma;
   }
-
+  
   /**
    * Approximation to the digamma function, from Radford Neal.
-   *
-   * Original License:
-   * Copyright (c) 1995-2003 by Radford M. Neal
-   *
-   * Permission is granted for anyone to copy, use, modify, or distribute this
-   * program and accompanying programs and documents for any purpose, provided
-   * this copyright notice is retained and prominently displayed, along with
-   * a note saying that the original programs are available from Radford Neal's
-   * web page, and note is made of any changes made to the programs.  The
-   * programs and documents are distributed without any warranty, express or
-   * implied.  As the programs were written for research purposes only, they have
-   * not been tested to the degree that would be advisable in any important
-   * application.  All use of these programs is entirely at the user's own risk.
-   *
-   *
+   * 
+   * Original License: Copyright (c) 1995-2003 by Radford M. Neal
+   * 
+   * Permission is granted for anyone to copy, use, modify, or distribute this program and accompanying
+   * programs and documents for any purpose, provided this copyright notice is retained and prominently
+   * displayed, along with a note saying that the original programs are available from Radford Neal's web
+   * page, and note is made of any changes made to the programs. The programs and documents are distributed
+   * without any warranty, express or implied. As the programs were written for research purposes only, they
+   * have not been tested to the degree that would be advisable in any important application. All use of these
+   * programs is entirely at the user's own risk.
+   * 
+   * 
    * Ported to Java for Mahout.
-   *
+   * 
    */
   private static double digamma(double x) {
     double r = 0.0;
-
+    
     while (x <= 5) {
       r -= 1 / x;
       x += 1;
     }
-
+    
     double f = 1.0 / (x * x);
-    double t = f * (-1 / 12.0
-        + f * (1 / 120.0
-        + f * (-1 / 252.0
-        + f * (1 / 240.0 
-        + f * (-1 / 132.0 
-        + f * (691 / 32760.0 
-        + f * (-1 / 12.0 
-        + f * 3617.0 / 8160.0)))))));
+    double t = f
+               * (-1 / 12.0 + f
+                              * (1 / 120.0 + f
+                                             * (-1 / 252.0 + f
+                                                             * (1 / 240.0 + f
+                                                                            * (-1 / 132.0 + f
+                                                                                            * (691 / 32760.0 + f
+                                                                                                               * (-1 / 12.0 + f * 3617.0 / 8160.0)))))));
     return r + Math.log(x) - 0.5 / x + t;
   }
-
+  
 }

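The nested series at the end of digamma(double) above is the standard asymptotic expansion of the digamma function; the recurrence psi(x) = psi(x+1) - 1/x is applied until the argument exceeds 5, after which (writing y for the shifted argument, with f = 1/y^2 collecting the even powers):

    \psi(y) \approx \ln y - \frac{1}{2y}
        - \frac{1}{12 y^2} + \frac{1}{120 y^4} - \frac{1}{252 y^6} + \frac{1}{240 y^8}
        - \frac{1}{132 y^{10}} + \frac{691}{32760 y^{12}} - \frac{1}{12 y^{14}} + \frac{3617}{8160 y^{16}}

which is exactly the set of coefficients folded into the variable t in the code.
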
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java Sat Feb 13 21:07:53 2010
@@ -29,39 +29,35 @@
 import org.apache.mahout.math.VectorWritable;
 
 /**
-* Runs inference on the input documents (which are
-* sparse vectors of word counts) and outputs
-* the sufficient statistics for the word-topic
-* assignments.
-*/
-public class LDAMapper extends 
-    Mapper<WritableComparable<?>, VectorWritable, IntPairWritable, DoubleWritable> {
-
+ * Runs inference on the input documents (which are sparse vectors of word counts) and outputs the sufficient
+ * statistics for the word-topic assignments.
+ */
+public class LDAMapper extends Mapper<WritableComparable<?>,VectorWritable,IntPairWritable,DoubleWritable> {
+  
   private LDAState state;
   private LDAInference infer;
-
+  
   @Override
-  public void map(WritableComparable<?> key, VectorWritable wordCountsWritable, Context context)
-      throws IOException, InterruptedException {
+  public void map(WritableComparable<?> key, VectorWritable wordCountsWritable, Context context) throws IOException,
+                                                                                                InterruptedException {
     Vector wordCounts = wordCountsWritable.get();
     LDAInference.InferredDocument doc = infer.infer(wordCounts);
-
+    
     double[] logTotals = new double[state.numTopics];
     Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
-
+    
     // Output sufficient statistics for each word. == pseudo-log counts.
     IntPairWritable kw = new IntPairWritable();
     DoubleWritable v = new DoubleWritable();
-    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
-        iter.hasNext();) {
+    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
       Vector.Element e = iter.next();
       int w = e.index();
       kw.setY(w);
       for (int k = 0; k < state.numTopics; ++k) {
         v.set(doc.phi(k, w) + Math.log(e.get()));
-
+        
         kw.setX(k);
-
+        
         // output (topic, word)'s logProb contribution
         context.write(kw, v);
         logTotals[k] = LDAUtil.logSum(logTotals[k], v.get());
@@ -77,19 +73,19 @@
       assert !Double.isNaN(v.get());
       context.write(kw, v);
     }
-
+    
     // Output log-likelihoods.
     kw.setX(LDADriver.LOG_LIKELIHOOD_KEY);
     kw.setY(LDADriver.LOG_LIKELIHOOD_KEY);
     v.set(doc.logLikelihood);
     context.write(kw, v);
   }
-
+  
   public void configure(LDAState myState) {
     this.state = myState;
     this.infer = new LDAInference(state);
   }
-
+  
   public void configure(Configuration job) {
     try {
       LDAState myState = LDADriver.createState(job);
@@ -98,11 +94,10 @@
       throw new IllegalStateException("Error creating LDA State!", e);
     }
   }
-
+  
   @Override
   protected void setup(Context context) {
     configure(context.getConfiguration());
   }
-
-
+  
 }

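Read as formulas, the statistics the mapper emits for a document are, for each word w with count c_w and each topic k (this is only a restatement of the code above):

    \log \tilde{n}_{k,w} = \phi_{k,w} + \log c_w
    \qquad
    \log \tilde{n}_{k} = \operatorname{logSum}_{w}\, \log \tilde{n}_{k,w}

i.e. expected pseudo-counts of word-topic assignments kept in log space, plus the per-topic totals and the document log-likelihood, so that the reducer only has to logSum (or plainly sum, for the likelihood) whatever it receives.
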
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java Sat Feb 13 21:07:53 2010
@@ -19,20 +19,16 @@
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.mapreduce.Reducer;
 
-
 /**
-* A very simple reducer which simply logSums the
-* input doubles and outputs a new double for sufficient
-* statistics, and sums log likelihoods.
-*/
-public class LDAReducer extends
-    Reducer<IntPairWritable, DoubleWritable, IntPairWritable, DoubleWritable> {
-
+ * A simple reducer which logSums the input doubles and outputs a new double for the sufficient
+ * statistics, and sums log likelihoods.
+ */
+public class LDAReducer extends Reducer<IntPairWritable,DoubleWritable,IntPairWritable,DoubleWritable> {
+  
   @Override
-  public void reduce(IntPairWritable topicWord, Iterable<DoubleWritable> values,
-      Context context) 
-      throws java.io.IOException, InterruptedException {
-
+  public void reduce(IntPairWritable topicWord, Iterable<DoubleWritable> values, Context context) throws java.io.IOException,
+                                                                                                 InterruptedException {
+    
     // sum likelihoods
     if (topicWord.getY() == LDADriver.LOG_LIKELIHOOD_KEY) {
       double accum = 0.0;
@@ -58,7 +54,7 @@
       }
       context.write(topicWord, new DoubleWritable(accum));
     }
-
+    
   }
-
+  
 }

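Reusing LDAReducer as the combiner (see the job setup in runIteration above) is valid because both of its branches are associative and commutative reductions; for the sufficient statistics,

    \operatorname{logSum}(\operatorname{logSum}(a, b), c) = \log(e^a + e^b + e^c) = \operatorname{logSum}(a, \operatorname{logSum}(b, c))

and the log-likelihood branch is an ordinary sum, which has the same property.
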
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java Sat Feb 13 21:07:53 2010
@@ -19,15 +19,19 @@
 import org.apache.mahout.math.Matrix;
 
 public class LDAState {
-  public final int numTopics; 
-  public final int numWords; 
+  public final int numTopics;
+  public final int numWords;
   public final double topicSmoothing;
   private final Matrix topicWordProbabilities; // log p(w|t) for topic=1..nTopics
   private final double[] logTotals; // log \sum p(w|t) for topic=1..nTopics
   public final double logLikelihood; // log likelihood of the corpus
-
-  public LDAState(int numTopics, int numWords, double topicSmoothing,
-      Matrix topicWordProbabilities, double[] logTotals, double ll) {
+  
+  public LDAState(int numTopics,
+                  int numWords,
+                  double topicSmoothing,
+                  Matrix topicWordProbabilities,
+                  double[] logTotals,
+                  double ll) {
     this.numWords = numWords;
     this.numTopics = numTopics;
     this.topicSmoothing = topicSmoothing;
@@ -35,10 +39,9 @@
     this.logTotals = logTotals;
     this.logLikelihood = ll;
   }
-
+  
   public double logProbWordGivenTopic(int word, int topic) {
     double logProb = topicWordProbabilities.getQuick(topic, word);
-    return logProb == Double.NEGATIVE_INFINITY ? -100.0 
-      : logProb - logTotals[topic];
+    return logProb == Double.NEGATIVE_INFINITY ? -100.0 : logProb - logTotals[topic];
   }
 }

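In logProbWordGivenTopic above, topicWordProbabilities holds unnormalized log pseudo-counts and logTotals[topic] holds their logSum, so the returned value is just a normalized log-probability,

    \log p(w \mid t) = \log \tilde{n}_{t,w} - \log \sum_{w'} \tilde{n}_{t,w'}

with the -100.0 floor applied when the stored value is negative infinity (presumably a word with no mass in that topic).
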
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java Sat Feb 13 21:07:53 2010
@@ -20,17 +20,14 @@
  * Various utility methods for doing LDA inference.
  */
 final class LDAUtil {
-  private LDAUtil() {
-  } // no creation
-
+  private LDAUtil() { } // no creation
+  
   /**
    * @return log(exp(a) + exp(b))
    */
   static double logSum(double a, double b) {
-    return (a == Double.NEGATIVE_INFINITY) ? b
-      : (b == Double.NEGATIVE_INFINITY) ? a
-      : (a < b) ? b + Math.log(1 + Math.exp(a - b))
-      : a + Math.log(1 + Math.exp(b - a));    
+    return a == Double.NEGATIVE_INFINITY ? b : b == Double.NEGATIVE_INFINITY ? a
+        : a < b ? b + Math.log(1 + Math.exp(a - b)) : a + Math.log(1 + Math.exp(b - a));
   }
-
+  
 }

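The ternary in logSum above is the usual numerically stable evaluation of log(exp(a) + exp(b)): the larger argument is factored out, giving

    \operatorname{logSum}(a, b) = \max(a, b) + \log\bigl(1 + e^{-|a - b|}\bigr)

For example, logSum(-1000, -1001) = -1000 + log(1 + e^{-1}) which is about -999.69, whereas computing exp(-1000) directly underflows to zero in double precision; the NEGATIVE_INFINITY special cases simply treat exp(-infinity) as 0.
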
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=909914&r1=909913&r2=909914&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Sat Feb 13 21:07:53 2010
@@ -17,68 +17,63 @@
 
 package org.apache.mahout.clustering.meanshift;
 
-import com.google.gson.Gson;
-import com.google.gson.GsonBuilder;
-import com.google.gson.reflect.TypeToken;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.mahout.clustering.ClusterBase;
 import org.apache.mahout.math.CardinalityException;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.JsonVectorAdapter;
-import static org.apache.mahout.math.function.Functions.*;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
 
-import java.io.DataInput;
-import java.io.DataOutput;
-import java.io.IOException;
-import java.lang.reflect.Type;
-import java.util.ArrayList;
-import java.util.List;
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.reflect.TypeToken;
 
 /**
- * This class models a canopy as a center point, the number of points that are contained within it according to the
- * application of some distance metric, and a point total which is the sum of all the points and is used to compute the
- * centroid when needed.
+ * This class models a canopy as a center point, the number of points that are contained within it according
+ * to the application of some distance metric, and a point total which is the sum of all the points and is
+ * used to compute the centroid when needed.
  */
 public class MeanShiftCanopy extends ClusterBase {
-
-   // TODO: this is problematic, but how else to encode membership?
+  
+  // TODO: this is problematic, but how else to encode membership?
   private List<Vector> boundPoints = new ArrayList<Vector>();
-
+  
   private boolean converged = false;
-
+  
   public MeanShiftCanopy() {
     super();
   }
-
+  
   /** Create a new Canopy with the given canopyId */
   /*
-  public MeanShiftCanopy(String id) {
-    this.setId(Integer.parseInt(id.substring(1)));
-    this.setCenter(null);
-    this.setPointTotal(null);
-    this.setNumPoints(0);
-  }
-  */
+   * public MeanShiftCanopy(String id) { this.setId(Integer.parseInt(id.substring(1))); this.setCenter(null);
+   * this.setPointTotal(null); this.setNumPoints(0); }
+   */
 
   /**
    * Create a new Canopy containing the given point
-   *
-   * @param point a Vector
+   * 
+   * @param point
+   *          a Vector
    */
   /*
-  public MeanShiftCanopy(Vector point) {
-    this.setCenter(point);
-    this.setPointTotal(point.clone());
-    this.setNumPoints(1);
-    this.boundPoints.add(point);
-  }
-  */
+   * public MeanShiftCanopy(Vector point) { this.setCenter(point); this.setPointTotal(point.clone());
+   * this.setNumPoints(1); this.boundPoints.add(point); }
+   */
 
   /**
    * Create a new Canopy containing the given point
-   *
-   * @param point a Vector
+   * 
+   * @param point
+   *          a Vector
    */
   public MeanShiftCanopy(Vector point, int id) {
     this.setId(id);
@@ -90,14 +85,17 @@
   
   /**
    * Create a new Canopy containing the given point, id and bound points
-   *
-   * @param point       a Vector
-   * @param id          an int identifying the canopy local to this process only
-   * @param boundPoints a List<Vector> containing points bound to the canopy
-   * @param converged   true if the canopy has converged
+   * 
+   * @param point
+   *          a Vector
+   * @param id
+   *          an int identifying the canopy local to this process only
+   * @param boundPoints
+   *          a List<Vector> containing points bound to the canopy
+   * @param converged
+   *          true if the canopy has converged
    */
-  MeanShiftCanopy(Vector point, int id, List<Vector> boundPoints,
-                  boolean converged) {
+  MeanShiftCanopy(Vector point, int id, List<Vector> boundPoints, boolean converged) {
     this.setId(id);
     this.setCenter(point);
     this.setPointTotal(point.clone());
@@ -105,36 +103,39 @@
     this.boundPoints = boundPoints;
     this.converged = converged;
   }
-
+  
   /**
    * Add a point to the canopy some number of times
-   *
-   * @param point   a Vector to add
-   * @param nPoints the number of times to add the point
-   * @throws CardinalityException if the cardinalities disagree
+   * 
+   * @param point
+   *          a Vector to add
+   * @param nPoints
+   *          the number of times to add the point
+   * @throws CardinalityException
+   *           if the cardinalities disagree
    */
   void addPoints(Vector point, int nPoints) {
     setNumPoints(getNumPoints() + nPoints);
-    Vector subTotal = (nPoints == 1) ? point.clone() : point.times(nPoints);
-    setPointTotal((getPointTotal() == null) ? subTotal : getPointTotal().plus(subTotal));
+    Vector subTotal = nPoints == 1 ? point.clone() : point.times(nPoints);
+    setPointTotal(getPointTotal() == null ? subTotal : getPointTotal().plus(subTotal));
   }
-
+  
   /**
    * Compute the bound centroid by averaging the bound points
-   *
+   * 
    * @return a Vector which is the new bound centroid
    */
   public Vector computeBoundCentroid() {
     Vector result = new DenseVector(getCenter().size());
     for (Vector v : boundPoints) {
-      result.assign(v, plus);
+      result.assign(v, Functions.plus);
     }
     return result.divide(boundPoints.size());
   }
-
+  
   /**
    * Compute the centroid by normalizing the pointTotal
-   *
+   * 
    * @return a Vector which is the new centroid
    */
   @Override
@@ -145,55 +146,57 @@
       return getPointTotal().divide(getNumPoints());
     }
   }
-
+  
   public List<Vector> getBoundPoints() {
     return boundPoints;
   }
-
+  
   public int getCanopyId() {
     return getId();
   }
-
+  
   @Override
   public String getIdentifier() {
     return (converged ? "V" : "C") + getId();
   }
-
+  
   void init(MeanShiftCanopy canopy) {
     setId(canopy.getId());
     setCenter(canopy.getCenter());
     addPoints(getCenter(), 1);
     boundPoints.addAll(canopy.getBoundPoints());
   }
-
+  
   public boolean isConverged() {
     return converged;
   }
-
+  
   /**
    * The receiver overlaps the given canopy. Touch it and add my bound points to it.
-   *
-   * @param canopy an existing MeanShiftCanopy
+   * 
+   * @param canopy
+   *          an existing MeanShiftCanopy
    */
   void merge(MeanShiftCanopy canopy) {
     boundPoints.addAll(canopy.boundPoints);
   }
-
+  
   @Override
   public String toString() {
-    return formatCanopy(this);
+    return MeanShiftCanopy.formatCanopy(this);
   }
-
+  
   /**
    * The receiver touches the given canopy. Add respective centers.
-   *
-   * @param canopy an existing MeanShiftCanopy
+   * 
+   * @param canopy
+   *          an existing MeanShiftCanopy
    */
   void touch(MeanShiftCanopy canopy) {
     canopy.addPoints(getCenter(), boundPoints.size());
     addPoints(canopy.getCenter(), canopy.boundPoints.size());
   }
-
+  
   @Override
   public void readFields(DataInput in) throws IOException {
     super.readFields(in);
@@ -207,7 +210,7 @@
       this.boundPoints.add(temp.get());
     }
   }
-
+  
   @Override
   public void write(DataOutput out) throws IOException {
     super.write(out);
@@ -217,7 +220,7 @@
       VectorWritable.writeVector(out, v);
     }
   }
-
+  
   public MeanShiftCanopy shallowCopy() {
     MeanShiftCanopy result = new MeanShiftCanopy();
     result.setId(this.getId());
@@ -227,43 +230,42 @@
     result.boundPoints = this.boundPoints;
     return result;
   }
-
+  
   @Override
   public String asFormatString() {
-    return formatCanopy(this);
+    return MeanShiftCanopy.formatCanopy(this);
   }
   
   public void setBoundPoints(List<Vector> boundPoints) {
     this.boundPoints = boundPoints;
   }
-
+  
   public void setConverged(boolean converged) {
     this.converged = converged;
   }
-
+  
   /** Format the canopy for output */
   public static String formatCanopy(MeanShiftCanopy canopy) {
-    Type vectorType = new TypeToken<Vector>() {
-    }.getType();
+    Type vectorType = new TypeToken<Vector>() { }.getType();
     GsonBuilder gBuilder = new GsonBuilder();
     gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
     Gson gson = gBuilder.create();
     return gson.toJson(canopy, MeanShiftCanopy.class);
   }
-
+  
   /**
    * Decodes and returns a Canopy from the formattedString
-   *
-   * @param formattedString a String produced by formatCanopy
+   * 
+   * @param formattedString
+   *          a String produced by formatCanopy
    * @return a new Canopy
    */
   public static MeanShiftCanopy decodeCanopy(String formattedString) {
-    Type vectorType = new TypeToken<Vector>() {
-    }.getType();
+    Type vectorType = new TypeToken<Vector>() { }.getType();
     GsonBuilder gBuilder = new GsonBuilder();
     gBuilder.registerTypeAdapter(vectorType, new JsonVectorAdapter());
     Gson gson = gBuilder.create();
     return gson.fromJson(formattedString, MeanShiftCanopy.class);
   }
-
+  
 }