Posted to commits@mahout.apache.org by jm...@apache.org on 2011/05/07 02:18:52 UTC

svn commit: r1100421 - in /mahout/trunk: core/src/main/java/org/apache/mahout/clustering/lda/ core/src/test/java/org/apache/mahout/clustering/lda/ core/src/test/java/org/apache/mahout/math/hadoop/decomposer/ utils/src/main/java/org/apache/mahout/cluste...

Author: jmannix
Date: Sat May  7 00:18:51 2011
New Revision: 1100421

URL: http://svn.apache.org/viewvc?rev=1100421&view=rev
Log:
Fixes MAHOUT-683 and MAHOUT-682

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java
Removed:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java
    mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java?rev=1100421&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java Sat May  7 00:18:51 2011
@@ -0,0 +1,52 @@
+package org.apache.mahout.clustering.lda;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class LDADocumentTopicMapper extends Mapper<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
+
+  private LDAState state;
+  private LDAInference infer;
+
+  @Override
+  protected void map(WritableComparable<?> key,
+                     VectorWritable wordCountsWritable,
+                     Context context) throws IOException, InterruptedException {
+
+    Vector wordCounts = wordCountsWritable.get();
+    LDAInference.InferredDocument doc;
+    try {
+      doc = infer.infer(wordCounts);
+      context.write(key, new VectorWritable(doc.getGamma().normalize(1)));
+    } catch (ArrayIndexOutOfBoundsException e1) {
+      throw new IllegalStateException(
+         "This is probably because the --numWords argument is set too small.  \n"
+         + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+         + "\tlarger if some storage inefficiency can be tolerated.", e1);
+    }
+  }
+
+  public void configure(LDAState myState) {
+    this.state = myState;
+    this.infer = new LDAInference(state);
+  }
+
+  public void configure(Configuration job) {
+    try {
+      LDAState myState = LDADriver.createState(job);
+      configure(myState);
+    } catch (IOException e) {
+      throw new IllegalStateException("Error creating LDA State!", e);
+    }
+  }
+
+  @Override
+  protected void setup(Context context) {
+    configure(context.getConfiguration());
+  }
+}
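
The mapper above writes, for each document key, the inferred topic mixture
gamma normalized to sum to 1, i.e. p(topic|doc). A minimal sketch of reading
that output back (the part file name and the Text key type are assumptions,
not fixed by this commit; the classes are the same Hadoop/Mahout ones used
throughout the diff):

    Configuration conf = new Configuration();
    FileSystem fs = FileSystem.get(conf);
    Path part = new Path("lda-output/docTopics/part-m-00000"); // hypothetical path
    SequenceFile.Reader reader = new SequenceFile.Reader(fs, part, conf);
    Text key = new Text(); // assumed: whatever key type the corpus used
    VectorWritable value = new VectorWritable();
    while (reader.next(key, value)) {
      Vector pTopicGivenDoc = value.get(); // entries sum to 1 (gamma.normalize(1))
      System.out.println(key + " => " + pTopicGivenDoc);
    }
    reader.close();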

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Sat May  7 00:18:51 2011
@@ -17,21 +17,21 @@
 
 package org.apache.mahout.clustering.lda;
 
-import java.io.IOException;
-import java.util.Random;
-
+import com.google.common.base.Preconditions;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Job;
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
 import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ReflectionUtils;
 import org.apache.hadoop.util.ToolRunner;
 import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
@@ -43,10 +43,17 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator;
 import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
 
 /**
  * Estimates an LDA model from a corpus of documents, which are SparseVectors of word counts. At each phase,
@@ -57,6 +64,8 @@ public final class LDADriver extends Abs
   private static final String TOPIC_SMOOTHING_OPTION = "topicSmoothing";
   private static final String NUM_WORDS_OPTION = "numWords";
   private static final String NUM_TOPICS_OPTION = "numTopics";
+  // TODO: sequential iteration is not yet correct.
+  // private static final String SEQUENTIAL_OPTION = "sequential";
   static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
   static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
   static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
@@ -67,6 +76,14 @@ public final class LDADriver extends Abs
 
   private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
 
+  private LDAState state = null;
+
+  private LDAState newState = null;
+
+  private LDAInference inference = null;
+
+  private Iterable<Pair<Writable, VectorWritable>> trainingCorpus = null;
+
   private LDADriver() {
   }
 
@@ -74,7 +91,11 @@ public final class LDADriver extends Abs
     ToolRunner.run(new Configuration(), new LDADriver(), args);
   }
 
-  static LDAState createState(Configuration job) {
+  public static LDAState createState(Configuration job) throws IOException {
+    return createState(job, false);
+  }
+
+  public static LDAState createState(Configuration job, boolean empty) throws IOException {
     String statePath = job.get(STATE_IN_KEY);
     int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
     int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
@@ -82,10 +103,14 @@ public final class LDADriver extends Abs
 
     Path dir = new Path(statePath);
 
+    // TODO scalability bottleneck: numWords * numTopics * 8 bytes held in memory by the driver *and* the M/R classes
     DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
     double[] logTotals = new double[numTopics];
+    Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
     double ll = 0.0;
-
+    if(empty) {
+      return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
+    }
     for (Pair<IntPairWritable,DoubleWritable> record
          : new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(new Path(dir, "part-*"),
                                                                         PathType.GLOB,
@@ -126,6 +151,7 @@ public final class LDADriver extends Abs
               "v",
               "The total number of words in the corpus (can be approximate, needs to exceed the actual value)");
     addOption(TOPIC_SMOOTHING_OPTION, "a", "Topic smoothing parameter. Default is 50/numTopics.", "-1.0");
+    // addOption(SEQUENTIAL_OPTION, "seq", "Run sequentially (not Hadoop-based).  Default is false.", "false");
     addOption(DefaultOptionCreator.maxIterationsOption().withRequired(false).create());
 
     if (parseArguments(args) == null) {
@@ -144,30 +170,61 @@ public final class LDADriver extends Abs
     if (topicSmoothing < 1) {
       topicSmoothing = 50.0 / numTopics;
     }
+    boolean runSequential = false; // Boolean.parseBoolean(getOption(SEQUENTIAL_OPTION));
 
-    run(getConf(), input, output, numTopics, numWords, topicSmoothing, maxIterations);
+    run(getConf(), input, output, numTopics, numWords, topicSmoothing, maxIterations, runSequential);
 
     return 0;
   }
 
-  private static void run(Configuration conf,
+  private Path getLastKnownStatePath(Configuration conf, Path stateDir) throws IOException {
+    FileSystem fs = FileSystem.get(conf);
+    Path lastPath = null;
+    int maxIteration = Integer.MIN_VALUE;
+    for(FileStatus fstatus : fs.globStatus(new Path(stateDir, "state-*"))) {
+      try {
+        int iteration = Integer.parseInt(fstatus.getPath().getName().split("-")[1]);
+        if(iteration > maxIteration) {
+          maxIteration = iteration;
+          lastPath = fstatus.getPath();
+        }
+      } catch(NumberFormatException nfe) {
+        throw new IOException(nfe);
+      }
+    }
+    return lastPath;
+  }
+
+  private void run(Configuration conf,
                           Path input,
                           Path output,
                           int numTopics,
                           int numWords,
                           double topicSmoothing,
-                          int maxIterations)
+                          int maxIterations,
+                          boolean runSequential)
     throws IOException, InterruptedException, ClassNotFoundException {
-    Path stateIn = new Path(output, "state-0");
-    writeInitialState(stateIn, numTopics, numWords);
+    Path lastKnownState = getLastKnownStatePath(conf, output);
+    Path stateIn;
+    if(lastKnownState == null) {
+      stateIn = new Path(output, "state-0");
+      writeInitialState(stateIn, numTopics, numWords);
+    } else {
+      stateIn = lastKnownState;
+    }
+    conf.set(STATE_IN_KEY, stateIn.toString());
+    conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
+    conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
+    conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
     double oldLL = Double.NEGATIVE_INFINITY;
     boolean converged = false;
-
-    for (int iteration = 1; (maxIterations < 1 || iteration <= maxIterations) && !converged; iteration++) {
+    int iteration = Integer.parseInt(stateIn.getName().split("-")[1]) + 1;
+    for (; ((maxIterations < 1) || (iteration <= maxIterations)) && !converged; iteration++) {
       log.info("LDA Iteration {}", iteration);
+      conf.set(STATE_IN_KEY, stateIn.toString());
       // point the output to a new directory per iteration
       Path stateOut = new Path(output, "state-" + iteration);
-      double ll = runIteration(conf, input, stateIn, stateOut, numTopics, numWords, topicSmoothing);
+      double ll = runSequential ? runIterationSequential(conf, input, stateOut) : runIteration(conf, input, stateIn, stateOut);
       double relChange = (oldLL - ll) / oldLL;
 
       // now point the input to the old output directory
@@ -175,13 +232,18 @@ public final class LDADriver extends Abs
       log.info("(Old LL: {})", oldLL);
       log.info("(Rel Change: {})", relChange);
 
-      converged = iteration > 3 && relChange < OVERALL_CONVERGENCE;
+      converged = (iteration > 3) && (relChange < OVERALL_CONVERGENCE);
       stateIn = stateOut;
       oldLL = ll;
     }
+    if(runSequential) {
+      computeDocumentTopicProbabilitiesSequential(conf, input, new Path(output, "docTopics"));
+    } else {
+      computeDocumentTopicProbabilities(conf, input, stateIn, new Path(output, "docTopics"), numTopics, numWords, topicSmoothing);
+    }
   }
 
-  private static void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
+  private void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
     Configuration job = new Configuration();
     FileSystem fs = statePath.getFileSystem(job);
 
@@ -210,7 +272,33 @@ public final class LDADriver extends Abs
     }
   }
 
-  private static double findLL(Path statePath, Configuration job) throws IOException {
+  private void writeState(Configuration job, LDAState state, Path statePath) throws IOException {
+    FileSystem fs = statePath.getFileSystem(job);
+    DoubleWritable v = new DoubleWritable();
+
+    for (int k = 0; k < state.getNumTopics(); ++k) {
+      Path path = new Path(statePath, "part-" + k);
+      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
+
+      for (int w = 0; w < state.getNumWords(); ++w) {
+        Writable kw = new IntPairWritable(k, w);
+        v.set(state.logProbWordGivenTopic(w,k) + state.getLogTotal(k));
+        writer.append(kw, v);
+      }
+      Writable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
+      v.set(state.getLogTotal(k));
+      writer.append(kTsk, v);
+      writer.close();
+    }
+    Path path = new Path(statePath, "part-" + LOG_LIKELIHOOD_KEY);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
+    Writable kTsk = new IntPairWritable(LOG_LIKELIHOOD_KEY,LOG_LIKELIHOOD_KEY);
+    v.set(state.getLogLikelihood());
+    writer.append(kTsk, v);
+    writer.close();
+  }
+
+  private double findLL(Path statePath, Configuration job) throws IOException {
     FileSystem fs = statePath.getFileSystem(job);
     double ll = 0.0;
     for (FileStatus status : fs.globStatus(new Path(statePath, "part-*"))) {
@@ -229,6 +317,66 @@ public final class LDADriver extends Abs
     return ll;
   }
 
+  private double runIterationSequential(Configuration conf, Path input, Path stateOut)
+    throws IOException, InterruptedException {
+    if(state == null) {
+      state = createState(conf);
+    }
+    if(trainingCorpus == null) {
+      Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
+      List<Pair<Writable,VectorWritable>> corpus = new LinkedList<Pair<Writable, VectorWritable>>();
+      for(FileStatus fileStatus : FileSystem.get(conf).globStatus(new Path(input, "part-*"))) {
+        Path inputPart = fileStatus.getPath();
+        SequenceFile.Reader reader = new SequenceFile.Reader(FileSystem.get(conf), inputPart, conf);
+        Writable key = ReflectionUtils.newInstance(keyClass, conf);
+        VectorWritable value = new VectorWritable();
+        while(reader.next(key, value)) {
+          Writable nextKey = ReflectionUtils.newInstance(keyClass, conf);
+          VectorWritable nextValue = new VectorWritable();
+          corpus.add(new Pair<Writable,VectorWritable>(key, value));
+          key = nextKey;
+          value = nextValue;
+        }
+        reader.close(); // don't leak a reader per part file
+      }
+      trainingCorpus = corpus;
+    }
+    // recreate the inference engine each iteration so it sees the current state
+    inference = new LDAInference(state);
+    double ll = 0;
+    newState = createState(conf, true);
+    for(Pair<Writable, VectorWritable> slice : trainingCorpus) {
+      LDAInference.InferredDocument doc;
+      Vector wordCounts = slice.getSecond().get();
+      try {
+        doc = inference.infer(wordCounts);
+      } catch (ArrayIndexOutOfBoundsException e1) {
+        throw new IllegalStateException(
+         "This is probably because the --numWords argument is set too small.  \n"
+         + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+         + "\tlarger if some storage inefficiency can be tolerated.", e1);
+      }
+
+      for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
+        Vector.Element e = iter.next();
+        int w = e.index();
+
+        for (int k = 0; k < state.getNumTopics(); ++k) {
+          double vwUpdate = doc.phi(k, w) + Math.log(e.get());
+          newState.updateLogProbGivenTopic(w, k, vwUpdate); // update state.topicWordProbabilities[v,w]!
+          newState.updateLogTotals(k, vwUpdate);
+        }
+      }
+      ll += doc.getLogLikelihood(); // once per document, not once per nonzero word
+    }
+    newState.setLogLikelihood(ll);
+    writeState(conf, newState, stateOut);
+    state = newState;
+    newState = null;
+
+    return ll;
+  }
+
   /**
    * Run the job using supplied arguments
    * @param input
@@ -237,21 +385,13 @@ public final class LDADriver extends Abs
    *          the directory pathname for input state
    * @param stateOut
    *          the directory pathname for output state
-   * @param numTopics
-   *          the number of clusters
    */
-  private static double runIteration(Configuration conf,
+  private double runIteration(Configuration conf,
                                      Path input,
                                      Path stateIn,
-                                     Path stateOut,
-                                     int numTopics,
-                                     int numWords,
-                                     double topicSmoothing)
+                                     Path stateOut)
     throws IOException, InterruptedException, ClassNotFoundException {
     conf.set(STATE_IN_KEY, stateIn.toString());
-    conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
-    conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
-    conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
 
     Job job = new Job(conf, "LDA Driver running runIteration over stateIn: " + stateIn);
     job.setOutputKeyClass(IntPairWritable.class);
@@ -259,7 +399,7 @@ public final class LDADriver extends Abs
     FileInputFormat.addInputPaths(job, input.toString());
     FileOutputFormat.setOutputPath(job, stateOut);
 
-    job.setMapperClass(LDAMapper.class);
+    job.setMapperClass(LDAWordTopicMapper.class);
     job.setReducerClass(LDAReducer.class);
     job.setCombinerClass(LDAReducer.class);
     job.setOutputFormatClass(SequenceFileOutputFormat.class);
@@ -271,4 +411,69 @@ public final class LDADriver extends Abs
     }
     return findLL(stateOut, conf);
   }
+
+  private void computeDocumentTopicProbabilities(Configuration conf,
+                                     Path input,
+                                     Path stateIn,
+                                     Path outputPath,
+                                     int numTopics,
+                                     int numWords,
+                                     double topicSmoothing)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    conf.set(STATE_IN_KEY, stateIn.toString());
+    conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
+    conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
+    conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
+
+    Job job = new Job(conf, "LDA Driver computing p(topic|doc) for all docs/topics with stateIn: " + stateIn);
+    job.setOutputKeyClass(peekAtSequenceFileForKeyType(conf, input));
+    job.setOutputValueClass(VectorWritable.class);
+    FileInputFormat.addInputPaths(job, input.toString());
+    FileOutputFormat.setOutputPath(job, outputPath);
+
+    job.setMapperClass(LDADocumentTopicMapper.class);
+    job.setNumReduceTasks(0);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setJarByClass(LDADriver.class);
+
+    if (!job.waitForCompletion(true)) {
+      throw new InterruptedException("LDA failed to compute and output document topic probabilities with: " + stateIn);
+    }
+  }
+
+  private void computeDocumentTopicProbabilitiesSequential(Configuration conf, Path input, Path outputPath)
+    throws IOException, ClassNotFoundException {
+    FileSystem fs = input.getFileSystem(conf);
+    Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, outputPath, keyClass, VectorWritable.class);
+
+    VectorWritable vw = new VectorWritable();
+
+    for(Pair<Writable, VectorWritable> slice : trainingCorpus) {
+      LDAInference.InferredDocument doc;
+      Vector wordCounts = slice.getSecond().get();
+      try {
+        doc = inference.infer(wordCounts);
+      } catch (ArrayIndexOutOfBoundsException e1) {
+        throw new IllegalStateException(
+         "This is probably because the --numWords argument is set too small.  \n"
+         + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+         + "\tlarger if some storage inefficiency can be tolerated.", e1);
+      }
+      vw.set(doc.getGamma().normalize(1)); // p(topic|doc), as in LDADocumentTopicMapper
+      writer.append(slice.getFirst(), vw);
+    }
+
+    writer.close();
+  }
+
+  private static Class<? extends Writable> peekAtSequenceFileForKeyType(Configuration conf, Path input) {
+    try {
+      SequenceFile.Reader reader = new SequenceFile.Reader(FileSystem.get(conf), input, conf);
+      try {
+        return (Class<? extends Writable>) reader.getKeyClass();
+      } finally {
+        reader.close(); // don't leak the peeked reader
+      }
+    } catch(IOException ioe) {
+      return Text.class; // fall back if the input can't be peeked at
+    }
+  }
 }
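
With getLastKnownStatePath(...), the driver now resumes from the
highest-numbered state-* directory under the output path instead of always
rewriting state-0, so re-running an interrupted job continues training. A
hedged usage sketch (the --maxIter name comes from
DefaultOptionCreator.maxIterationsOption() and --input/--output from
AbstractJob; treat the exact flag names as assumptions):

    // A first run that reaches state-7 and dies can simply be re-run:
    // iteration restarts at 8 from the saved state.
    String[] args = {
        "--input", "corpus/vectors", // hypothetical paths
        "--output", "lda-output",
        "--numTopics", "20",
        "--numWords", "50000",
        "--maxIter", "40",
    };
    LDADriver.main(args);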

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Sat May  7 00:18:51 2011
@@ -43,8 +43,8 @@ public class LDAInference {
   }
   
   /**
-   * An estimate of the probabilitys for each document. Gamma(k) is the probability of seeing topic k in the
-   * document, phi(k,w) is the probability of topic k generating w in this document.
+   * An estimate of the probabilities for each document. Gamma(k) is the probability of seeing topic k in the
+   * document, phi(k,w) is the (log) probability of topic k generating w in this document.
    */
   public static class InferredDocument {
     
@@ -137,7 +137,11 @@ public class LDAInference {
     
     return new InferredDocument(wordCounts, gamma, map, phi, oldLL);
   }
-  
+
+  /**
+   * @param gamma the variational Dirichlet parameters for a document
+   * @return a vector whose kth entry is digamma(gamma.get(k)) - digamma(gamma.zSum())
+   */
   private Vector digammaGamma(Vector gamma) {
     // digamma is expensive, precompute
     Vector digammaGamma = digamma(gamma);
@@ -156,7 +160,21 @@ public class LDAInference {
       phi.assign(0);
     }
   }
-  
+
+  /**
+   * Computes the log likelihood of the wordCounts vector given phi, gamma, and
+   * digammaGamma, where diGamma(x) = gamma'(x)/gamma(x) and logGamma(x) = log(gamma(x)):
+   *
+   * ll = log(gamma(smooth*numTop) / gamma(smooth)^numTop) +
+   *   sum_{i < numTop} (smooth - g[i])*(digamma(g[i]) - digamma(|g|)) + log(gamma(g[i]))
+   *
+   * @param wordCounts word counts for this document
+   * @param map mapping of word indices onto the columns of phi
+   * @param phi log probabilities of each topic generating each word in this document
+   * @param gamma estimated topic mixture parameters for this document
+   * @param digammaGamma precomputed digamma(gamma[k]) - digamma(gamma.zSum())
+   * @return the log likelihood
+   */
   private double computeLikelihood(Vector wordCounts, int[] map, Matrix phi, Vector gamma, Vector digammaGamma) {
     double ll = 0.0;
     

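For reference, the digammaGamma(gamma) described above computes, per topic k,
digamma(gamma[k]) - digamma(|gamma|), which is E[log theta_k] under a
Dirichlet(gamma). A standalone sketch using the usual shift-plus-asymptotic-series
approximation of digamma (this mirrors what LDAInference is assumed to use,
not necessarily its exact code):

    // digamma(x) = d/dx log(gamma(x)): shift x upward, then apply the series.
    static double digamma(double x) {
      double result = 0.0;
      while (x < 6.0) {
        result -= 1.0 / x; // digamma(x) = digamma(x + 1) - 1/x
        x += 1.0;
      }
      double f = 1.0 / (x * x);
      return result + Math.log(x) - 0.5 / x
          - f * (1.0 / 12.0 - f * (1.0 / 120.0 - f / 252.0));
    }

    // digammaGamma(gamma)[k] = digamma(gamma[k]) - digamma(sum_j gamma[j])
    static double[] digammaGamma(double[] gamma) {
      double total = 0.0;
      for (double g : gamma) {
        total += g;
      }
      double digammaTotal = digamma(total);
      double[] result = new double[gamma.length];
      for (int k = 0; k < gamma.length; k++) {
        result[k] = digamma(gamma[k]) - digammaTotal;
      }
      return result;
    }
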
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java Sat May  7 00:18:51 2011
@@ -25,7 +25,7 @@ public class LDAState {
   private final double topicSmoothing;
   private final Matrix topicWordProbabilities; // log p(w|t) for topic=1..nTopics
   private final double[] logTotals; // log \sum p(w|t) for topic=1..nTopics
-  private final double logLikelihood; // log \sum p(w|t) for topic=1..nTopics
+  private double logLikelihood; // accumulated log likelihood of the corpus under this state
   
   public LDAState(int numTopics,
                   int numWords,
@@ -42,10 +42,14 @@ public class LDAState {
   }
   
   public double logProbWordGivenTopic(int word, int topic) {
-    double logProb = topicWordProbabilities.getQuick(topic, word);
+    double logProb = topicWordProbabilities.get(topic, word);
     return logProb == Double.NEGATIVE_INFINITY ? -100.0 : logProb - logTotals[topic];
   }
 
+  public double getLogTotal(int topic) {
+    return logTotals[topic];
+  }
+
   public int getNumTopics() {
     return numTopics;
   }
@@ -62,4 +66,16 @@ public class LDAState {
     return logLikelihood;
   }
 
+  public void updateLogProbGivenTopic(int word, int topic, double logProbGivenTopic) {
+    topicWordProbabilities.set(topic, word, LDAUtil.logSum(logProbGivenTopic, topicWordProbabilities.getQuick(topic, word)));
+  }
+
+  public void updateLogTotals(int topic, double logTotal) {
+    logTotals[topic] = LDAUtil.logSum(logTotals[topic], logTotal);
+  }
+
+  public void setLogLikelihood(double logLikelihood) {
+    this.logLikelihood = logLikelihood;
+  }
+
 }
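
The new updateLogProbGivenTopic and updateLogTotals accumulate sufficient
statistics in log space via LDAUtil.logSum, which is not part of this diff. A
minimal sketch of the numerically stable log-sum-exp identity it is assumed to
implement:

    // logSum(a, b) = log(exp(a) + exp(b)), computed without overflow.
    static double logSum(double logA, double logB) {
      if (logA == Double.NEGATIVE_INFINITY) {
        return logB; // adding zero probability mass is a no-op
      }
      if (logB == Double.NEGATIVE_INFINITY) {
        return logA;
      }
      double max = Math.max(logA, logB);
      return max + Math.log(Math.exp(logA - max) + Math.exp(logB - max));
    }

This is also why createState(job, true) in LDADriver above fills logTotals with
Double.NEGATIVE_INFINITY: log(0) is the additive identity for these log-space
accumulations.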

Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java?rev=1100421&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java Sat May  7 00:18:51 2011
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.IntPairWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Runs inference on the input documents (which are sparse vectors of word counts) and outputs the sufficient
+ * statistics for the word-topic assignments.
+ */
+public class LDAWordTopicMapper extends Mapper<WritableComparable<?>,VectorWritable,IntPairWritable,DoubleWritable> {
+  
+  private LDAState state;
+  private LDAInference infer;
+  
+  @Override
+  protected void map(WritableComparable<?> key,
+                     VectorWritable wordCountsWritable,
+                     Context context) throws IOException, InterruptedException {
+    Vector wordCounts = wordCountsWritable.get();
+    LDAInference.InferredDocument doc;
+    try {
+      doc = infer.infer(wordCounts);
+    } catch (ArrayIndexOutOfBoundsException e1) {
+      throw new IllegalStateException(
+         "This is probably because the --numWords argument is set too small.  \n"
+         + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+         + "\tlarger if some storage inefficiency can be tolerated.", e1);
+    }
+    
+    double[] logTotals = new double[state.getNumTopics()];
+    Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
+    
+    // Output sufficient statistics for each word: pseudo log counts.
+    DoubleWritable v = new DoubleWritable();
+    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
+      Vector.Element e = iter.next();
+      int w = e.index();
+      
+      for (int k = 0; k < state.getNumTopics(); ++k) {
+        v.set(doc.phi(k, w) + Math.log(e.get()));
+        
+        IntPairWritable kw = new IntPairWritable(k, w);
+        
+        // output (topic, word)'s logProb contribution
+        context.write(kw, v);
+        logTotals[k] = LDAUtil.logSum(logTotals[k], v.get());
+      }
+    }
+
+    // Output the totals for the statistics. This is to make
+    // normalizing a lot easier.
+    for (int k = 0; k < state.getNumTopics(); ++k) {
+      IntPairWritable kw = new IntPairWritable(k, LDADriver.TOPIC_SUM_KEY);
+      v.set(logTotals[k]);
+      assert !Double.isNaN(v.get());
+      context.write(kw, v);
+    }
+    IntPairWritable llk = new IntPairWritable(LDADriver.LOG_LIKELIHOOD_KEY, LDADriver.LOG_LIKELIHOOD_KEY);
+    // Output log-likelihoods.
+    v.set(doc.getLogLikelihood());
+    context.write(llk, v);
+  }
+  
+  public void configure(LDAState myState) {
+    this.state = myState;
+    this.infer = new LDAInference(state);
+  }
+  
+  public void configure(Configuration job) {
+    try {
+      LDAState myState = LDADriver.createState(job);
+      configure(myState);
+    } catch (IOException e) {
+      throw new IllegalStateException("Error creating LDA State!", e);
+    }
+  }
+  
+  @Override
+  protected void setup(Context context) {
+    configure(context.getConfiguration());
+  }
+  
+}
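
For each document, this mapper emits phi(k, w) + log(count_w) keyed by
(topic k, word w), a per-topic total under TOPIC_SUM_KEY, and the document's
log likelihood under LOG_LIKELIHOOD_KEY. LDAReducer (unchanged by this commit)
is then assumed to fold values that share a key together in log space, roughly:

    // Sketch of the assumed per-key reduction: summing pseudo-counts exp(v)
    // across documents while staying on the log scale.
    static double combine(Iterable<DoubleWritable> values) {
      double accumulator = Double.NEGATIVE_INFINITY; // log(0)
      for (DoubleWritable v : values) {
        accumulator = LDAUtil.logSum(accumulator, v.get());
      }
      return accumulator;
    }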

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java Sat May  7 00:18:51 2011
@@ -93,12 +93,12 @@ public final class TestMapReduce extends
   @Test
   public void testMapper() throws Exception {
     LDAState state = generateRandomState(100,NUM_TOPICS);
-    LDAMapper mapper = new LDAMapper();
+    LDAWordTopicMapper mapper = new LDAWordTopicMapper();
     mapper.configure(state);
     for(int i = 0; i < NUM_TESTS; ++i) {
       RandomAccessSparseVector v = generateRandomDoc(100,0.3);
       int myNumWords = numNonZero(v);
-      LDAMapper.Context mock = EasyMock.createMock(LDAMapper.Context.class);
+      LDAWordTopicMapper.Context mock = EasyMock.createMock(LDAWordTopicMapper.Context.class);
 
       mock.write(EasyMock.isA(IntPairWritable.class), EasyMock.isA(DoubleWritable.class));
       EasyMock.expectLastCall().times(myNumWords * NUM_TOPICS + NUM_TOPICS + 1);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java Sat May  7 00:18:51 2011
@@ -33,8 +33,8 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
+import java.util.Arrays;
 
 public final class TestDistributedLanczosSolverCLI extends MahoutTestCase {
   private static final Logger log = LoggerFactory.getLogger(TestDistributedLanczosSolverCLI.class);

Modified: mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java (original)
+++ mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java Sat May  7 00:18:51 2011
@@ -17,19 +17,6 @@
 
 package org.apache.mahout.clustering.lda;
 
-import java.io.File;
-import java.io.IOException;
-import java.io.Writer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.PriorityQueue;
-import java.util.Queue;
-
-import com.google.common.base.Charsets;
-import com.google.common.io.Files;
 import org.apache.commons.cli2.CommandLine;
 import org.apache.commons.cli2.Group;
 import org.apache.commons.cli2.Option;
@@ -48,6 +35,19 @@ import org.apache.mahout.common.iterator
 import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
 import org.apache.mahout.utils.vectors.VectorHelper;
 
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Queue;
+
 /**
  * Class to print out the top K words for each topic.
  */
@@ -56,7 +56,7 @@ public final class LDAPrintTopics {
   private LDAPrintTopics() { }  
   
   private static class StringDoublePair implements Comparable<StringDoublePair> {
-    private final double score;
+    private double score;
     private final String word;
     
     StringDoublePair(double score, String word) {
@@ -153,18 +153,16 @@ public final class LDAPrintTopics {
         throw new IllegalArgumentException("Invalid dictionary format");
       }
       
-      List<List<String>> topWords = topWordsForTopics(input, config, wordList, numWords);
-      
+      List<PriorityQueue<StringDoublePair>> topWords = topWordsForTopics(input, config, wordList, numWords);
+
+      File output = null;
       if (cmdLine.hasOption(outOpt)) {
-        File output = new File(cmdLine.getValue(outOpt).toString());
+        output = new File(cmdLine.getValue(outOpt).toString());
         if (!output.exists() && !output.mkdirs()) {
           throw new IOException("Could not create directory: " + output);
         }
-        writeTopWords(topWords, output);
-      } else {
-        printTopWords(topWords);
       }
-      
+      printTopWords(topWords, output);
     } catch (OptionException e) {
       CommandLineUtil.printHelp(group);
       throw e;
@@ -181,63 +179,60 @@ public final class LDAPrintTopics {
     }
   }
   
-  private static void printTopWords(List<List<String>> topWords) {
+  private static void printTopWords(List<PriorityQueue<StringDoublePair>> topWords, File outputDir)
+    throws IOException {
     for (int i = 0; i < topWords.size(); ++i) {
-      List<String> topK = topWords.get(i);
-      System.out.println("Topic " + i);
-      System.out.println("===========");
-      for (String word : topK) {
-        System.out.println(word);
+      PriorityQueue<StringDoublePair> topK = topWords.get(i);
+      PrintWriter out;
+      if(outputDir != null) {
+        out = new PrintWriter(new File(outputDir, "topic_" + i));
+      } else {
+        out = new PrintWriter(System.out);
+        out.println("Topic " + i);
+        out.println("===========");
+      }
+      List<StringDoublePair> topKasList = new ArrayList<StringDoublePair>(topK);
+      Collections.sort(topKasList, Collections.reverseOrder());
+      for(StringDoublePair wordWithScore : topKasList) {
+        out.println(wordWithScore.word + " [p(" + wordWithScore.word + "|topic_" + i + ") = "
+         + wordWithScore.score + ']');
+      }
+      out.flush();
+      if(outputDir != null) {
+        out.close(); // don't close the System.out-backed writer between topics
+      }
     }
   }
   
-  private static List<List<String>> topWordsForTopics(String dir,
+  private static List<PriorityQueue<StringDoublePair>> topWordsForTopics(String dir,
                                                       Configuration job,
                                                       List<String> wordList,
                                                       int numWordsToPrint) {
     List<PriorityQueue<StringDoublePair>> queues = new ArrayList<PriorityQueue<StringDoublePair>>();
-
+    Map<Integer,Double> expSums = new HashMap<Integer, Double>();
     for (Pair<IntPairWritable,DoubleWritable> record :
          new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(
              new Path(dir, "part-*"), PathType.GLOB, null, null, true, job)) {
       IntPairWritable key = record.getFirst();
       int topic = key.getFirst();
       int word = key.getSecond();
-
       ensureQueueSize(queues, topic);
       if (word >= 0 && topic >= 0) {
         double score = record.getSecond().get();
+        if(expSums.get(topic) == null) {
+          expSums.put(topic, 0d);
+        }
+        expSums.put(topic, expSums.get(topic) + Math.exp(score));
         String realWord = wordList.get(word);
         maybeEnqueue(queues.get(topic), realWord, score, numWordsToPrint);
       }
     }
-    
-    List<List<String>> result = new ArrayList<List<String>>();
-    for (int i = 0; i < queues.size(); ++i) {
-      result.add(i, new LinkedList<String>());
-      for (StringDoublePair sdp : queues.get(i)) {
-        result.get(i).add(0, sdp.word); // prepend
+    for(int i=0; i<queues.size(); i++) {
+      PriorityQueue<StringDoublePair> queue = queues.get(i);
+      for(StringDoublePair pair : queue) {
+        pair.score = Math.exp(pair.score) / expSums.get(i);
       }
     }
-    
-    return result;
+    return queues;
   }
-  
-  private static void writeTopWords(List<List<String>> topWords, File output) throws IOException {
-    for (int i = 0; i < topWords.size(); ++i) {
-      List<String> topK = topWords.get(i);
-      Writer writer = Files.newWriter(new File(output, "topic-" + i), Charsets.UTF_8);
-      try {
-        writer.write("Topic " + i + '\n');
-        writer.write("===========\n");
-        for (String word : topK) {
-          writer.write(word + '\n');
-        }
-      } finally {
-        writer.close();
-      }
-    }
-  }
-  
 }
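
A note on the reworked topWordsForTopics: expSums accumulates exp(score) over
all words in each topic, not just the top K kept in the queues, so the scores
printed above are genuine probabilities p(word|topic). Restated for a single
topic as a minimal sketch:

    // p(w|topic) = exp(logScore[w]) / sum over w' of exp(logScore[w'])
    static double[] normalizeScores(double[] logScores) {
      double expSum = 0.0;
      for (double s : logScores) {
        expSum += Math.exp(s);
      }
      double[] p = new double[logScores.length];
      for (int w = 0; w < logScores.length; w++) {
        p[w] = Math.exp(logScores[w]) / expSum;
      }
      return p;
    }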