You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2010/01/10 18:48:13 UTC

svn commit: r897668 - /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java

Author: gsingers
Date: Sun Jan 10 17:48:13 2010
New Revision: 897668

URL: http://svn.apache.org/viewvc?rev=897668&view=rev
Log:
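Small clarification: reword the numWords option description to "The total number of words in the corpus" and normalize indentation and trailing whitespace.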

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=897668&r1=897667&r2=897668&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Sun Jan 10 17:48:13 2010
@@ -47,11 +47,11 @@
 import org.slf4j.LoggerFactory;
 
 /**
-* Estimates an LDA model from a corpus of documents,
-* which are SparseVectors of word counts. At each
-* phase, it outputs a matrix of log probabilities of
-* each topic. 
-*/
+ * Estimates an LDA model from a corpus of documents,
+ * which are SparseVectors of word counts. At each
+ * phase, it outputs a matrix of log probabilities of
+ * each topic.
+ */
 public final class LDADriver {
 
   static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
@@ -72,47 +72,47 @@
   }
 
   public static void main(String[] args) throws ClassNotFoundException,
-      IOException, InterruptedException {
+          IOException, InterruptedException {
 
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
     ArgumentBuilder abuilder = new ArgumentBuilder();
     GroupBuilder gbuilder = new GroupBuilder();
 
     Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
-        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+            abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+            "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
 
     Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
-        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Output Working Directory").withShortName("o").create();
+            abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+            "The Output Working Directory").withShortName("o").create();
 
     Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
-        "If set, overwrite the output directory").withShortName("w").create();
+            "If set, overwrite the output directory").withShortName("w").create();
 
     Option topicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(
-        abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The number of topics").withShortName("k").create();
+            abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
+            "The number of topics").withShortName("k").create();
 
     Option wordsOpt = obuilder.withLongName("numWords").withRequired(true).withArgument(
-        abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
-        "The number of words in the corpus").withShortName("v").create();
+            abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+            "The total number of words in the corpus").withShortName("v").create();
 
     Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(abuilder
-        .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
-        "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
+            .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
+            "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
 
     Option maxIterOpt = obuilder.withLongName("maxIter").withRequired(false).withArgument(
-        abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
-        "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
+            abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
+            "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
 
     Option numReducOpt = obuilder.withLongName("numReducers").withRequired(false).withArgument(
-        abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
-        "Max iterations to run (or until convergence). Default 10").create();
+            abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
+            "Max iterations to run (or until convergence). Default 10").create();
 
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
 
     Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
-        topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
+            topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
             numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
     try {
       Parser parser = new Parser();
@@ -154,12 +154,12 @@
       if (cmdLine.hasOption(topicSmOpt)) {
         topicSmoothing = Double.parseDouble(cmdLine.getValue(maxIterOpt).toString());
       }
-      if(topicSmoothing < 1) {
+      if (topicSmoothing < 1) {
         topicSmoothing = 50.0 / numTopics;
       }
 
       runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations,
-          numReduceTasks);
+              numReduceTasks);
 
     } catch (OptionException e) {
       log.error("Exception", e);
@@ -170,18 +170,18 @@
   /**
    * Run the job using supplied arguments
    *
-   * @param input           the directory pathname for input points
-   * @param output          the directory pathname for output points
-   * @param numTopics       the number of topics
-   * @param numWords        the number of words
-   * @param topicSmoothing  pseudocounts for each topic, typically small &lt; .5
-   * @param maxIterations   the maximum number of iterations
-   * @param numReducers     the number of Reducers desired
-   * @throws IOException 
+   * @param input          the directory pathname for input points
+   * @param output         the directory pathname for output points
+   * @param numTopics      the number of topics
+   * @param numWords       the number of words
+   * @param topicSmoothing pseudocounts for each topic, typically small &lt; .5
+   * @param maxIterations  the maximum number of iterations
+   * @param numReducers    the number of Reducers desired
+   * @throws IOException
    */
-  public static void runJob(String input, String output, int numTopics, 
-      int numWords, double topicSmoothing, int maxIterations, int numReducers)
-      throws IOException, InterruptedException, ClassNotFoundException {
+  public static void runJob(String input, String output, int numTopics,
+                            int numWords, double topicSmoothing, int maxIterations, int numReducers)
+          throws IOException, InterruptedException, ClassNotFoundException {
 
     String stateIn = output + "/state-0";
     writeInitialState(stateIn, numTopics, numWords);
@@ -193,7 +193,7 @@
       // point the output to a new directory per iteration
       String stateOut = output + "/state-" + (iteration + 1);
       double ll = runIteration(input, stateIn, stateOut, numTopics,
-          numWords, topicSmoothing, numReducers);
+              numWords, topicSmoothing, numReducers);
       double relChange = (oldLL - ll) / oldLL;
 
       // now point the input to the old output directory
@@ -207,8 +207,8 @@
     }
   }
 
-  private static void writeInitialState(String statePath,  
-      int numTopics, int numWords) throws IOException {
+  private static void writeInitialState(String statePath,
+                                        int numTopics, int numWords) throws IOException {
     Path dir = new Path(statePath);
     Configuration job = new Configuration();
     FileSystem fs = dir.getFileSystem(job);
@@ -221,7 +221,7 @@
     for (int k = 0; k < numTopics; ++k) {
       Path path = new Path(dir, "part-" + k);
       SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
-          IntPairWritable.class, DoubleWritable.class);
+              IntPairWritable.class, DoubleWritable.class);
 
       kw.setX(k);
       double total = 0.0; // total number of pseudo counts we made
@@ -250,7 +250,7 @@
 
     IntPairWritable key = new IntPairWritable();
     DoubleWritable value = new DoubleWritable();
-    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) { 
+    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
       Path path = status.getPath();
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
       while (reader.next(key, value)) {
@@ -268,21 +268,21 @@
   /**
    * Run the job using supplied arguments
    *
-   * @param input         the directory pathname for input points
-   * @param stateIn       the directory pathname for input state
-   * @param stateOut      the directory pathname for output state
+   * @param input       the directory pathname for input points
+   * @param stateIn     the directory pathname for input state
+   * @param stateOut    the directory pathname for output state
    * @param numTopics   the number of clusters
-   * @param numReducers   the number of Reducers desired
+   * @param numReducers the number of Reducers desired
    */
   public static double runIteration(String input, String stateIn,
-      String stateOut, int numTopics, int numWords, double topicSmoothing,
-      int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+                                    String stateOut, int numTopics, int numWords, double topicSmoothing,
+                                    int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     conf.set(STATE_IN_KEY, stateIn);
     conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
     conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
     conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
-    
+
     Job job = new Job(conf);
 
     job.setOutputKeyClass(IntPairWritable.class);
@@ -298,7 +298,7 @@
     job.setOutputFormatClass(SequenceFileOutputFormat.class);
     job.setInputFormatClass(SequenceFileInputFormat.class);
     job.setJarByClass(LDADriver.class);
-    
+
     job.waitForCompletion(true);
     return findLL(stateOut, conf);
   }
@@ -318,7 +318,7 @@
 
     IntPairWritable key = new IntPairWritable();
     DoubleWritable value = new DoubleWritable();
-    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) { 
+    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
       Path path = status.getPath();
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
       while (reader.next(key, value)) {
@@ -349,6 +349,6 @@
     }
 
     return new LDAState(numTopics, numWords, topicSmoothing,
-        pWgT, logTotals, ll);
+            pWgT, logTotals, ll);
   }
 }