Posted to commits@mahout.apache.org by gs...@apache.org on 2009/08/17 15:33:19 UTC

svn commit: r804979 - in /lucene/mahout/trunk: core/ core/src/main/java/org/apache/mahout/clustering/lda/ core/src/test/java/org/apache/mahout/clustering/lda/ examples/ examples/bin/

Author: gsingers
Date: Mon Aug 17 13:33:18 2009
New Revision: 804979

URL: http://svn.apache.org/viewvc?rev=804979&view=rev
Log:
MAHOUT-123: Latent Dirichlet Allocation

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java   (with props)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java   (with props)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java   (with props)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java   (with props)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java   (with props)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java   (with props)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java   (with props)
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java   (with props)
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java   (with props)
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java   (with props)
    lucene/mahout/trunk/examples/bin/
    lucene/mahout/trunk/examples/bin/build-reuters.sh   (with props)
    lucene/mahout/trunk/examples/bin/lda.algorithm
Modified:
    lucene/mahout/trunk/core/pom.xml
    lucene/mahout/trunk/examples/pom.xml

Modified: lucene/mahout/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/pom.xml?rev=804979&r1=804978&r2=804979&view=diff
==============================================================================
--- lucene/mahout/trunk/core/pom.xml (original)
+++ lucene/mahout/trunk/core/pom.xml Mon Aug 17 13:33:18 2009
@@ -454,6 +454,9 @@
       <version>1.1.1</version>
     </dependency>
 
+    
+
+
     <dependency>
       <groupId>commons-httpclient</groupId>
       <artifactId>commons-httpclient</artifactId>
@@ -548,12 +551,25 @@
       <version>2.0-mahout</version>
     </dependency>
     <dependency>
+      <groupId>commons-math</groupId>
+      <artifactId>commons-math</artifactId>
+      <version>1.2</version>
+    </dependency>
+
+    <dependency>
       <groupId>junit</groupId>
       <artifactId>junit</artifactId>
       <version>3.8.2</version>
       <scope>test</scope>
     </dependency>
     
+    <dependency>
+      <groupId>org.easymock</groupId>
+      <artifactId>easymockclassextension</artifactId>
+      <version>2.2</version>
+      <scope>test</scope>
+    </dependency>
+    
     <!--  Gson: Java to Json conversion -->
     <dependency>
       <groupId>com.google.code.gson</groupId>

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+
+/**
+* Saves two ints, x and y.
+*/
+public class IntPairWritable implements WritableComparable<IntPairWritable> {
+
+  private int x;
+  private int y;
+
+  /** For serialization purposes only */
+  public IntPairWritable() {
+  }
+
+  public IntPairWritable(int x, int y) {
+    this.x = x;
+    this.y = y;
+  }
+
+  public void setX(int x) {
+    this.x = x;
+  }
+
+  public int getX() {
+    return x;
+  }
+
+  public void setY(int y) {
+    this.y = y;
+  }
+
+  public int getY() {
+    return y;
+  }
+
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    dataOutput.writeInt(x);
+    dataOutput.writeInt(y);
+  }
+
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    x = dataInput.readInt();
+    y = dataInput.readInt();
+  }
+
+  @Override
+  public int compareTo(IntPairWritable that) {
+    // order by x, then y; explicit comparisons avoid int-subtraction overflow
+    if (this.x != that.x) {
+      return this.x < that.x ? -1 : 1;
+    }
+    return this.y < that.y ? -1 : (this.y > that.y ? 1 : 0);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) { 
+      return true;
+    } else if (!(o instanceof IntPairWritable)) {
+      return false;
+    }
+
+    IntPairWritable that = (IntPairWritable) o;
+
+    return that.x == this.x && this.y == that.y;
+  }
+
+  @Override
+  public int hashCode() {
+    return 43 * x + y;
+  }
+
+  @Override
+  public String toString() {
+    return "(" + x + ", " + y + ")";
+  }
+
+  static {
+    WritableComparator.define(IntPairWritable.class, new Comparator());
+  }
+
+  public static class Comparator extends WritableComparator {
+    public Comparator() {
+      super(IntPairWritable.class);
+    }
+
+    @Override
+    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
+      assert l1 == 8 && l2 == 8;
+      int x1 = readInt(b1, s1);
+      int x2 = readInt(b2, s2);
+      if (x1 != x2) {
+        return x1 < x2 ? -1 : 1;
+      }
+      int y1 = readInt(b1, s1 + 4);
+      int y2 = readInt(b2, s2 + 4);
+      return y1 < y2 ? -1 : (y1 > y2 ? 1 : 0);
+    }
+  }
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
------------------------------------------------------------------------------
    svn:eol-style = native

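A minimal sketch (not part of this commit) of how the registered raw-byte
Comparator lines up with compareTo(). It sorts serialized records without
deserializing them; DataOutputBuffer is Hadoop's resizable output buffer,
and the class name here is hypothetical:

    import java.io.IOException;
    import org.apache.hadoop.io.DataOutputBuffer;
    import org.apache.mahout.clustering.lda.IntPairWritable;

    public class IntPairWritableCheck {
      public static void main(String[] args) throws IOException {
        IntPairWritable a = new IntPairWritable(1, 5);
        IntPairWritable b = new IntPairWritable(2, 3);

        // write() emits exactly two 4-byte ints per pair
        DataOutputBuffer bufA = new DataOutputBuffer();
        a.write(bufA);
        DataOutputBuffer bufB = new DataOutputBuffer();
        b.write(bufB);

        // the raw comparison must agree with the object comparison
        IntPairWritable.Comparator cmp = new IntPairWritable.Comparator();
        int raw = cmp.compare(bufA.getData(), 0, bufA.getLength(),
                              bufB.getData(), 0, bufB.getLength());
        assert Integer.signum(raw) == Integer.signum(a.compareTo(b));
      }
    }
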
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,349 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.utils.CommandLineUtil;
+import org.apache.mahout.utils.HadoopUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+* Estimates an LDA model from a corpus of documents,
+* which are SparseVectors of word counts. At each
+* iteration, it outputs a matrix of unnormalized log
+* probabilities for each (topic, word) pair.
+*/
+public final class LDADriver {
+
+  static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
+
+  static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
+  static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
+
+  static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
+
+  static final int LOG_LIKELIHOOD_KEY = -2;
+  static final int TOPIC_SUM_KEY = -1;
+
+  static final double OVERALL_CONVERGENCE = 1E-5;
+
+  private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
+
+  private LDADriver() {
+  }
+
+  public static void main(String[] args) throws InstantiationException,
+      IllegalAccessException, ClassNotFoundException,
+      IOException, InterruptedException {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
+    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The Output Working Directory").withShortName("o").create();
+
+    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
+        "If set, overwrite the output directory").withShortName("w").create();
+
+    Option topicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(
+        abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The number of topics").withShortName("k").create();
+
+    Option wordsOpt = obuilder.withLongName("numWords").withRequired(true).withArgument(
+        abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The number of words in the corpus").withShortName("v").create();
+
+    Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(abuilder
+        .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
+        "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
+
+    Option maxIterOpt = obuilder.withLongName("maxIter").withRequired(false).withArgument(
+        abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
+        "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
+
+    Option numReducOpt = obuilder.withLongName("numReducers").withRequired(false).withArgument(
+        abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
+        "Max iterations to run (or until convergence). Default 10").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
+
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
+        topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
+            numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      String input = cmdLine.getValue(inputOpt).toString();
+      String output = cmdLine.getValue(outputOpt).toString();
+
+      int maxIterations = -1;
+      if (cmdLine.hasOption(maxIterOpt)) {
+        maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
+      }
+
+      int numReduceTasks = 10;
+      if (cmdLine.hasOption(numReducOpt)) {
+        numReduceTasks = Integer.parseInt(cmdLine.getValue(numReducOpt).toString());
+      }
+
+      int numTopics = 20;
+      if (cmdLine.hasOption(topicsOpt)) {
+        numTopics = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
+      }
+
+      int numWords = 20;
+      if (cmdLine.hasOption(wordsOpt)) {
+        numWords = Integer.parseInt(cmdLine.getValue(wordsOpt).toString());
+      }
+
+      if (cmdLine.hasOption(overwriteOutput)) {
+        HadoopUtil.overwriteOutput(output);
+      }
+
+      double topicSmoothing = -1.0;
+      if (cmdLine.hasOption(topicSmOpt)) {
+        topicSmoothing = Double.parseDouble(cmdLine.getValue(topicSmOpt).toString());
+      }
+      if (topicSmoothing < 0) {
+        topicSmoothing = 50.0 / numTopics;
+      }
+
+      runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations,
+          numReduceTasks);
+
+    } catch (OptionException e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+    }
+  }
+
+  /**
+   * Run the job using supplied arguments
+   *
+   * @param input           the directory pathname for input points
+   * @param output          the directory pathname for output points
+   * @param numTopics       the number of topics
+   * @param numWords        the number of words
+   * @param topicSmoothing  pseudocounts for each topic, typically small &lt; .5
+   * @param maxIterations   the maximum number of iterations
+   * @param numReducers     the number of Reducers desired
+   * @throws IOException 
+   */
+  public static void runJob(String input, String output, int numTopics, 
+      int numWords, double topicSmoothing, int maxIterations, int numReducers)
+      throws IOException, InterruptedException, ClassNotFoundException {
+
+    String stateIn = output + "/state-0";
+    writeInitialState(stateIn, numTopics, numWords);
+    double oldLL = Double.NEGATIVE_INFINITY;
+    boolean converged = false;
+
+    for (int iteration = 0; (maxIterations < 1 || iteration < maxIterations) && !converged; iteration++) {
+      log.info("Iteration {}", iteration);
+      // point the output to a new directory per iteration
+      String stateOut = output + "/state-" + (iteration + 1);
+      double ll = runIteration(input, stateIn, stateOut, numTopics,
+          numWords, topicSmoothing, numReducers);
+      double relChange = (oldLL - ll) / oldLL;
+
+      // now point the input to the old output directory
+      log.info("Iteration {} finished. Log Likelihood: {}", iteration, ll);
+      log.info("(Old LL: {})", oldLL);
+      log.info("(Rel Change: {})", relChange);
+
+      converged = iteration > 2 && relChange < OVERALL_CONVERGENCE;
+      stateIn = stateOut;
+      oldLL = ll;
+    }
+  }
+
+  private static void writeInitialState(String statePath,  
+      int numTopics, int numWords) throws IOException {
+    Path dir = new Path(statePath);
+    Configuration job = new Configuration();
+    FileSystem fs = dir.getFileSystem(job);
+
+    IntPairWritable kw = new IntPairWritable();
+    DoubleWritable v = new DoubleWritable();
+
+    Random random = new Random();
+
+    for (int k = 0; k < numTopics; ++k) {
+      Path path = new Path(dir, "part-" + k);
+      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
+          IntPairWritable.class, DoubleWritable.class);
+
+      double total = 0.0; // total number of pseudo counts we made
+
+      kw.setX(k);
+      for (int w = 0; w < numWords; ++w) {
+        kw.setY(w);
+        // Random noise, floored at 1E-8 so the log below stays finite.
+        double pseudocount = random.nextDouble() + 1E-8;
+        total += pseudocount;
+        v.set(Math.log(pseudocount));
+        writer.append(kw, v);
+      }
+
+      kw.setY(TOPIC_SUM_KEY);
+      v.set(Math.log(total));
+      writer.append(kw, v);
+
+      writer.close();
+    }
+  }
+
+  private static double findLL(String statePath, Configuration job) throws IOException {
+    Path dir = new Path(statePath);
+    FileSystem fs = dir.getFileSystem(job);
+
+    double ll = 0.0;
+
+    IntPairWritable key = new IntPairWritable();
+    DoubleWritable value = new DoubleWritable();
+    for (FileStatus status : fs.globStatus(new Path(dir, "*"))) { 
+      Path path = status.getPath();
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
+      while (reader.next(key, value)) {
+        if (key.getX() == LOG_LIKELIHOOD_KEY) {
+          ll = value.get();
+          break;
+        }
+      }
+      reader.close();
+    }
+
+    return ll;
+  }
+
+  /**
+   * Run one iteration of the LDA estimation job using the supplied arguments
+   *
+   * @param input           the directory pathname for input points
+   * @param stateIn         the directory pathname for input state
+   * @param stateOut        the directory pathname for output state
+   * @param numTopics       the number of topics
+   * @param numWords        the number of words in the corpus
+   * @param topicSmoothing  the topic smoothing parameter
+   * @param numReducers     the number of Reducers desired
+   */
+  public static double runIteration(String input, String stateIn,
+      String stateOut, int numTopics, int numWords, double topicSmoothing,
+      int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+    Configuration conf = new Configuration();
+    conf.set(STATE_IN_KEY, stateIn);
+    conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
+    conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
+    conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
+
+    Job job = new Job(conf);
+
+    job.setOutputKeyClass(IntPairWritable.class);
+    job.setOutputValueClass(DoubleWritable.class);
+
+    FileInputFormat.addInputPaths(job, input);
+    Path outPath = new Path(stateOut);
+    FileOutputFormat.setOutputPath(job, outPath);
+
+    job.setMapperClass(LDAMapper.class);
+    job.setReducerClass(LDAReducer.class);
+    job.setCombinerClass(LDAReducer.class);
+    job.setNumReduceTasks(numReducers);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+
+    job.waitForCompletion(true);
+    return findLL(stateOut, conf);
+  }
+
+  static LDAState createState(Configuration job) throws IOException {
+    String statePath = job.get(LDADriver.STATE_IN_KEY);
+    int numTopics = Integer.parseInt(job.get(LDADriver.NUM_TOPICS_KEY));
+    int numWords = Integer.parseInt(job.get(LDADriver.NUM_WORDS_KEY));
+    double topicSmoothing = Double.parseDouble(job.get(LDADriver.TOPIC_SMOOTHING_KEY));
+
+    Path dir = new Path(statePath);
+    FileSystem fs = dir.getFileSystem(job);
+
+    DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
+    double[] logTotals = new double[numTopics];
+    double ll = 0.0;
+
+    IntPairWritable key = new IntPairWritable();
+    DoubleWritable value = new DoubleWritable();
+    for (FileStatus status : fs.globStatus(new Path(dir, "*"))) { 
+      Path path = status.getPath();
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
+      while (reader.next(key, value)) {
+        int topic = key.getX();
+        int word = key.getY();
+        if (word == TOPIC_SUM_KEY) {
+          logTotals[topic] = value.get();
+          assert !Double.isInfinite(value.get());
+        } else if (topic == LOG_LIKELIHOOD_KEY) {
+          ll = value.get();
+        } else {
+          assert topic >= 0 && word >= 0 : topic + " " + word;
+          assert pWgT.getQuick(topic, word) == 0.0;
+          pWgT.setQuick(topic, word, value.get());
+          assert !Double.isInfinite(pWgT.getQuick(topic, word));
+        }
+      }
+      reader.close();
+    }
+
+    return new LDAState(numTopics, numWords, topicSmoothing,
+        pWgT, logTotals, ll);
+  }
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
------------------------------------------------------------------------------
    svn:eol-style = native

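For reference, a hedged sketch (not part of this commit) of driving the EM
loop from code rather than the command line; the paths and vocabulary size
are hypothetical, and the smoothing value mirrors the driver's own default:

    import org.apache.mahout.clustering.lda.LDADriver;

    public class RunLDA {
      public static void main(String[] args) throws Exception {
        String input = "reuters-vectors"; // SequenceFile of (Writable, Vector)
        String output = "lda-output";     // the driver writes state-0, state-1, ... here
        int numTopics = 20;
        int numWords = 10000;             // vocabulary size of the input vectors
        double topicSmoothing = 50.0 / numTopics;
        // maxIterations = -1 runs until the relative log-likelihood change
        // drops below OVERALL_CONVERGENCE; 10 reduce tasks
        LDADriver.runJob(input, output, numTopics, numWords,
            topicSmoothing, -1, 10);
      }
    }
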
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,257 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+import org.apache.commons.math.special.Gamma;
+import org.apache.mahout.matrix.BinaryFunction;
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Matrix;
+import org.apache.mahout.matrix.Vector;
+
+
+/**
+* Class for performing inference on a document, which involves
+* computing (an approximation to) p(word|topic) for each word and
+* topic, and a posterior distribution p(topic|document) over topics.
+*/
+public class LDAInference {
+  public LDAInference(LDAState state) {
+    this.state = state;
+  }
+
+  /**
+  * An estimate of the probabilities for each document.
+  * Gamma(k) is the probability of seeing topic k in
+  * the document, phi(k,w) is the probability of
+  * topic k generating w in this document.
+  */
+  public class InferredDocument {
+
+    public final Vector wordCounts;
+    public final Vector gamma; // p(topic)
+    private final Matrix mphi; // log p(columnMap(w)|t)
+    private final Map<Integer, Integer> columnMap; // maps word ids to matrix columns
+    public final double logLikelihood;
+
+    public double phi(int k, int w) {
+      return mphi.getQuick(k, columnMap.get(w));
+    }
+
+    InferredDocument(Vector wordCounts, Vector gamma,
+                     Map<Integer, Integer> columnMap, Matrix phi,
+                     double ll) {
+      this.wordCounts = wordCounts;
+      this.gamma = gamma;
+      this.mphi = phi;
+      this.columnMap = columnMap;
+      this.logLikelihood = ll;
+    }
+  }
+
+  /**
+  * Performs inference on the given document, returning
+  * an InferredDocument.
+  */
+  public InferredDocument infer(Vector wordCounts) {
+    double docTotal = wordCounts.zSum();
+    int docLength = wordCounts.size();
+
+    // initialize variational approximation to p(z|doc)
+    Vector gamma = new DenseVector(state.numTopics);
+    gamma.assign(state.topicSmoothing + docTotal / state.numTopics);
+    Vector nextGamma = new DenseVector(state.numTopics);
+
+    DenseMatrix phi = new DenseMatrix(state.numTopics, docLength);
+
+    boolean converged = false;
+    double oldLL = 1;
+    // digamma is expensive, precompute
+    Vector digammaGamma = digamma(gamma);
+    // and log normalize:
+    double digammaSumGamma = digamma(gamma.zSum());
+    digammaGamma = digammaGamma.plus(-digammaSumGamma);
+
+    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
+
+    int iteration = 0;
+    final int MAX_ITER = 20;
+
+    while (!converged && iteration < MAX_ITER) {
+      nextGamma.assign(state.topicSmoothing); // nG := alpha, for all topics
+
+      int mapping = 0;
+      for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
+          iter.hasNext();) {
+        Vector.Element e = iter.next();
+        int word = e.index();
+        Vector phiW = eStepForWord(word, digammaGamma);
+        phi.assignColumn(mapping, phiW);
+        if (iteration == 0) { // first iteration
+          columnMap.put(word, mapping);
+        }
+
+        for (int k = 0; k < nextGamma.size(); ++k) {
+          double g = nextGamma.getQuick(k);
+          nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.get(k)));
+        }
+
+        mapping++;
+      }
+
+      Vector tempG = gamma;
+      gamma = nextGamma;
+      nextGamma = tempG;
+
+      // digamma is expensive, precompute
+      digammaGamma = digamma(gamma);
+      // and log normalize:
+      digammaSumGamma = digamma(gamma.zSum());
+      digammaGamma = digammaGamma.plus(-digammaSumGamma);
+
+      double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
+      converged = oldLL < 0 && ((oldLL - ll) / oldLL < E_STEP_CONVERGENCE);
+      assert !Double.isNaN(ll);
+
+      oldLL = ll;
+      iteration++;
+    }
+
+    return new InferredDocument(wordCounts, gamma, columnMap, phi, oldLL);
+  }
+
+  private LDAState state;
+
+  private double computeLikelihood(Vector wordCounts, Map<Integer, Integer> columnMap,
+      Matrix phi, Vector gamma, Vector digammaGamma) {
+    double ll = 0.0;
+
+    // log normalizer for q(gamma);
+    ll += Gamma.logGamma(state.topicSmoothing * state.numTopics);
+    ll -= state.numTopics * Gamma.logGamma(state.topicSmoothing);
+    assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
+
+    // now for the rest of q(gamma):
+    for (int k = 0; k < state.numTopics; ++k) {
+      ll += (state.topicSmoothing - gamma.get(k)) * digammaGamma.get(k);
+      ll += Gamma.logGamma(gamma.get(k));
+
+    }
+    ll -= Gamma.logGamma(gamma.zSum());
+    assert !Double.isNaN(ll);
+
+
+    // for each word
+    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
+        iter.hasNext();) {
+      Vector.Element e = iter.next();
+      int w = e.index();
+      double n = e.get();
+      int mapping = columnMap.get(w);
+      // now for each topic:
+      for (int k = 0; k < state.numTopics; k++) {
+        double llPart = 0.0;
+        llPart += Math.exp(phi.get(k, mapping))
+          * (digammaGamma.get(k) - phi.get(k, mapping)
+             + state.logProbWordGivenTopic(w, k));
+
+        ll += llPart * n;
+
+        assert state.logProbWordGivenTopic(w, k)  < 0;
+        assert !Double.isNaN(llPart);
+      }
+    }
+    assert ll <= 0;
+    return ll;
+  }
+
+  /**
+   * Compute log q(k|w,doc) for each topic k, for a given word.
+   */
+  private Vector eStepForWord(int word, Vector digammaGamma) {
+    Vector phi = new DenseVector(state.numTopics); // log q(k|w), for each k
+    double phiTotal = Double.NEGATIVE_INFINITY; // log Normalizer
+    for (int k = 0; k < state.numTopics; ++k) { // update q(k|w)'s param phi
+      phi.set(k, state.logProbWordGivenTopic(word, k) + digammaGamma.get(k));
+      phiTotal = LDAUtil.logSum(phiTotal, phi.get(k));
+
+      assert !Double.isNaN(phiTotal);
+      assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
+      assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
+      assert !Double.isNaN(digammaGamma.get(k));
+    }
+    return phi.plus(-phiTotal); // log normalize
+  }
+
+
+  private static Vector digamma(Vector v) {
+    Vector digammaGamma = new DenseVector(v.size());
+    digammaGamma.assign(v, new BinaryFunction() {
+      public double apply(double unused, double g) {
+        return digamma(g);
+      }
+    });
+    return digammaGamma;
+  }
+
+  /**
+   * Approximation to the digamma function, from Radford Neal.
+   *
+   * Original License:
+   * Copyright (c) 1995-2003 by Radford M. Neal
+   *
+   * Permission is granted for anyone to copy, use, modify, or distribute this
+   * program and accompanying programs and documents for any purpose, provided
+   * this copyright notice is retained and prominently displayed, along with
+   * a note saying that the original programs are available from Radford Neal's
+   * web page, and note is made of any changes made to the programs.  The
+   * programs and documents are distributed without any warranty, express or
+   * implied.  As the programs were written for research purposes only, they have
+   * not been tested to the degree that would be advisable in any important
+   * application.  All use of these programs is entirely at the user's own risk.
+   *
+   *
+   * Ported to Java for Mahout.
+   *
+   */
+  private static double digamma(double x) {
+    double r = 0.0;
+
+    while (x <= 5) {
+      r -= 1 / x;
+      x += 1;
+    }
+
+    double f = 1. / (x * x);
+    double t = f * (-1 / 12.0
+        + f * (1 / 120.0
+        + f * (-1 / 252.0
+        + f * (1 / 240.0 
+        + f * (-1 / 132.0 
+        + f * (691 / 32760.0 
+        + f * (-1 / 12.0 
+        + f * 3617.0 / 8160.0)))))));
+    return r + Math.log(x) - 0.5 / x + t;
+  }
+
+  private static final double E_STEP_CONVERGENCE = 1E-6;
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
------------------------------------------------------------------------------
    svn:eol-style = native

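A short usage sketch (not part of this commit): given an LDAState, e.g. from
LDADriver.createState(conf), infer() returns the variational parameters, and
normalizing gamma gives the mean of the inferred topic distribution. The
five-word toy document is hypothetical, and the DenseVector(double[])
constructor is assumed:

    import org.apache.mahout.clustering.lda.LDAInference;
    import org.apache.mahout.clustering.lda.LDAState;
    import org.apache.mahout.matrix.DenseVector;
    import org.apache.mahout.matrix.Vector;

    public class InferenceSketch {
      // state might come from LDADriver.createState(conf) or a test fixture;
      // this toy document assumes state.numWords == 5
      static void printTopicPosterior(LDAState state) {
        LDAInference lda = new LDAInference(state);
        Vector wordCounts = new DenseVector(new double[] {2, 0, 1, 0, 3});
        LDAInference.InferredDocument doc = lda.infer(wordCounts);

        // gamma holds Dirichlet parameters; normalizing gives the mean
        // posterior topic distribution for this document
        double gammaSum = doc.gamma.zSum();
        for (int k = 0; k < doc.gamma.size(); ++k) {
          System.out.printf("p(topic %d | doc) ~= %.4f%n",
              k, doc.gamma.get(k) / gammaSum);
        }
      }
    }
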
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.matrix.Vector;
+
+/**
+* Runs inference on the input documents (which are
+* sparse vectors of word counts) and outputs
+* the sufficient statistics for the word-topic
+* assignments.
+*/
+public class LDAMapper extends 
+    Mapper<WritableComparable<?>, Vector, IntPairWritable, DoubleWritable> {
+
+  private LDAState state;
+  private LDAInference infer;
+
+  @Override
+  public void map(WritableComparable<?> key, Vector wordCounts, Context context)
+      throws IOException, InterruptedException {
+    LDAInference.InferredDocument doc = infer.infer(wordCounts);
+
+    double[] logTotals = new double[state.numTopics];
+    Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
+
+    // Output sufficient statistics for each word. == pseudo-log counts.
+    IntPairWritable kw = new IntPairWritable();
+    DoubleWritable v = new DoubleWritable();
+    for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero();
+        iter.hasNext();) {
+      Vector.Element e = iter.next();
+      int w = e.index();
+      kw.setY(w);
+      for (int k = 0; k < state.numTopics; ++k) {
+        v.set(doc.phi(k, w) + Math.log(e.get()));
+
+        kw.setX(k);
+
+        // output (topic, word)'s logProb contribution
+        context.write(kw, v);
+        logTotals[k] = LDAUtil.logSum(logTotals[k], v.get());
+      }
+    }
+    
+    // Output the totals for the statistics. This is to make
+    // normalizing a lot easier.
+    kw.setY(LDADriver.TOPIC_SUM_KEY);
+    for (int k = 0; k < state.numTopics; ++k) {
+      kw.setX(k);
+      v.set(logTotals[k]);
+      assert !Double.isNaN(v.get());
+      context.write(kw, v);
+    }
+
+    // Output log-likelihoods.
+    kw.setX(LDADriver.LOG_LIKELIHOOD_KEY);
+    kw.setY(LDADriver.LOG_LIKELIHOOD_KEY);
+    v.set(doc.logLikelihood);
+    context.write(kw, v);
+  }
+
+  public void configure(LDAState myState) {
+    this.state = myState;
+    this.infer = new LDAInference(state);
+  }
+
+  public void configure(Configuration job) {
+    try {
+      LDAState myState = LDADriver.createState(job);
+      configure(myState);
+    } catch (IOException e) {
+      throw new RuntimeException("Error creating LDA State!", e);
+    }
+  }
+
+  @Override
+  protected void setup(Context context) {
+    configure(context.getConfiguration());
+  }
+
+
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
------------------------------------------------------------------------------
    svn:eol-style = native

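To make the emitted statistic concrete: for each word the mapper writes
phi(k, w) + log(count), the log of the expected number of times word w was
generated by topic k in this document. A tiny worked check (not part of
this commit), with hypothetical numbers:

    public class EmittedStatSketch {
      public static void main(String[] args) {
        // suppose word w occurs 3 times and inference assigned it to topic k
        // with posterior q(k|w) = 0.25; phi stores log q(k|w)
        double count = 3.0;
        double phiKW = Math.log(0.25);
        double emitted = phiKW + Math.log(count); // what map() writes for (k, w)
        // 0.25 * 3 = 0.75 expected occurrences of w under topic k
        assert Math.abs(Math.exp(emitted) - 0.75) < 1e-12;
      }
    }
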
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,219 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.PriorityQueue;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.utils.CommandLineUtil;
+
+/**
+ * Class to print out the top K words for each topic.
+ */
+public class LDAPrintTopics {
+  private LDAPrintTopics() {
+  }
+
+  private static class StringDoublePair implements Comparable<StringDoublePair> {
+    StringDoublePair(double score, String word) {
+      this.score = score;
+      this.word = word;
+    }
+    
+    public int compareTo(StringDoublePair other) {
+      return Double.compare(score,other.score);
+    }
+
+    double score;
+    String word;
+  }
+
+  public static List<List<String>> topWordsForTopics(String dir, Configuration job,
+      List<String> wordList, int numWordsToPrint) throws IOException {
+    FileSystem fs = new Path(dir).getFileSystem(job);
+
+    List<PriorityQueue<StringDoublePair>> queues = new ArrayList<PriorityQueue<StringDoublePair>>();
+
+    IntPairWritable key = new IntPairWritable();
+    DoubleWritable value = new DoubleWritable();
+    for (FileStatus status : fs.globStatus(new Path(dir, "*"))) { 
+      Path path = status.getPath();
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
+      while (reader.next(key, value)) {
+        int topic = key.getX();
+        int word = key.getY();
+
+        ensureQueueSize(queues,topic);
+        if (word >= 0 && topic >= 0) {
+          double score = value.get();
+          String realWord = wordList.get(word);
+          maybeEnqueue(queues.get(topic), realWord, score, numWordsToPrint);
+        }
+      }
+      reader.close();
+    }
+
+    List<List<String>> result = new ArrayList<List<String>>();
+    for (int i = 0; i < queues.size(); ++i) {
+      result.add(i,new LinkedList<String>());
+      for (StringDoublePair sdp: queues.get(i)) {
+        result.get(i).add(0,sdp.word); // prepend
+      }
+    }
+
+    return result;
+  }
+
+  // Expands the queue list to have a Queue for topic K
+  private static void ensureQueueSize(List<PriorityQueue<StringDoublePair>> queues, int k) {
+    for (int i = queues.size(); i <= k; ++i) {
+      queues.add(new PriorityQueue<StringDoublePair>());
+    }
+  }
+
+  // Adds the word if the queue is below capacity, or the score is high enough
+  private static void maybeEnqueue(PriorityQueue<StringDoublePair> q, String word, 
+      double score, int numWordsToPrint) {
+    if (q.size() >= numWordsToPrint && score > q.peek().score) {
+      q.poll();
+    }
+    if (q.size() < numWordsToPrint) {
+      q.add(new StringDoublePair(score,word));
+    } 
+  }
+
+  // Reads in the dictionary created by the vector Driver in utils
+  private static List<String> readDictionary(File path) throws IOException {
+    BufferedReader rdr = new BufferedReader(new FileReader(path));
+
+    List<String> result = new ArrayList<String>();
+
+    // skip 2 lines
+    rdr.readLine();
+    rdr.readLine();
+    String line = null;
+    while ( (line = rdr.readLine()) != null) {
+      String[] parts = line.split("\t");
+      String word = parts[0];
+      int index = Integer.parseInt(parts[2]);
+      assert index == result.size();
+      result.add(word);
+    }
+    rdr.close();
+
+    return result;
+  }
+
+  public static void main(String[] args) {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Path to an LDA output (a state)").withShortName("i").create();
+
+    Option dictOpt = obuilder.withLongName("dict").withRequired(true).withArgument(
+        abuilder.withName("dict").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Dictionary to read in, created by utils.vector.Driver").withShortName("d").create();
+
+    Option outOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Output directory to write top words").withShortName("o").create();
+
+    Option wordOpt = obuilder.withLongName("words").withRequired(false).withArgument(
+        abuilder.withName("words").withMinimum(0).withMaximum(1).withDefault("20").create()).withDescription(
+        "Number of words to print").withShortName("w").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
+
+    Group group = gbuilder.withName("Options").withOption(dictOpt).withOption(outOpt).withOption(
+        wordOpt).withOption(inputOpt).withOption(helpOpt).create();
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+
+      String input   = cmdLine.getValue(inputOpt).toString();
+      File output = new File(cmdLine.getValue(outOpt).toString());
+      File dict   = new File(cmdLine.getValue(dictOpt).toString());
+      int numWords = 20;
+      if (cmdLine.hasOption(wordOpt)) {
+        numWords = Integer.parseInt(cmdLine.getValue(wordOpt).toString());
+      }
+
+      List<String> wordList = readDictionary(dict);
+
+      Configuration config = new Configuration();
+      List<List<String>> topWords = topWordsForTopics(input, config, wordList, numWords);
+
+      if (!output.exists()) {
+        if (!output.mkdirs()) {
+          throw new IOException("Could not create directory: " + output);
+        }
+      }
+
+      for (int i = 0; i < topWords.size(); ++i) {
+        List<String> topK = topWords.get(i);
+        File out = new File(output,"topic-"+i);
+        PrintWriter writer = new PrintWriter(new FileWriter(out));
+        writer.println("Topic " + i);
+        writer.println("===========");
+        for (String word: topK) {
+          writer.println(word);
+        }
+        writer.close();
+      }
+
+    } catch (OptionException e) {
+      System.err.println("Exception: " + e);
+      CommandLineUtil.printHelp(group);
+    } catch (IOException e) {
+      System.err.println("Exception:" + e);
+      e.printStackTrace();
+    }
+  }
+
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
------------------------------------------------------------------------------
    svn:eol-style = native

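A hedged sketch (not part of this commit) of calling the public helper
directly; the state directory and the four-word dictionary are hypothetical:

    import java.util.Arrays;
    import java.util.List;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.mahout.clustering.lda.LDAPrintTopics;

    public class TopWordsSketch {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        // the dictionary maps vector index -> word, in index order
        List<String> dict = Arrays.asList("cocoa", "wheat", "oil", "gold");
        List<List<String>> top = LDAPrintTopics.topWordsForTopics(
            "lda-output/state-20", conf, dict, 2);
        for (int k = 0; k < top.size(); ++k) {
          System.out.println("Topic " + k + ": " + top.get(k));
        }
      }
    }
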
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+
+
+/**
+* A very simple reducer: it log-sums the input doubles
+* for the sufficient statistics, plainly sums the log
+* likelihood entries, and outputs one double per key.
+*/
+public class LDAReducer extends
+    Reducer<IntPairWritable, DoubleWritable, IntPairWritable, DoubleWritable> {
+
+  @Override
+  public void reduce(IntPairWritable topicWord, Iterable<DoubleWritable> values,
+      Context context) 
+      throws java.io.IOException, InterruptedException {
+
+    // sum likelihoods
+    if (topicWord.getY() == LDADriver.LOG_LIKELIHOOD_KEY) {
+      double accum = 0.0;
+      for (DoubleWritable vw : values) {
+        double v = vw.get();
+        assert !Double.isNaN(v) : topicWord.getX() + " " + topicWord.getY();
+        accum += v;
+      }
+      context.write(topicWord, new DoubleWritable(accum));
+    } else { // log sum sufficient statistics.
+      double accum = Double.NEGATIVE_INFINITY;
+      for (DoubleWritable vw : values) {
+        double v = vw.get();
+        assert !Double.isNaN(v) : topicWord.getX() + " " + topicWord.getY();
+        accum = LDAUtil.logSum(accum, v);
+        assert !Double.isNaN(accum) : topicWord.getX() + " " + topicWord.getY();
+      }
+      context.write(topicWord, new DoubleWritable(accum));
+    }
+
+  }
+
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
------------------------------------------------------------------------------
    svn:eol-style = native

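The driver registers this class as both combiner and reducer; that is safe
because both accumulators are associative and commutative (plain addition
for the log-likelihood key, logSum for the sufficient statistics). A small
check (not part of this commit), placed in the same package so it can see
the package-private LDAUtil:

    package org.apache.mahout.clustering.lda;

    public class CombinerSafetySketch {
      public static void main(String[] args) {
        double a = Math.log(1.0);
        double b = Math.log(2.0);
        double c = Math.log(4.0);
        // grouping must not matter: a combiner may pre-aggregate any subset
        double left = LDAUtil.logSum(LDAUtil.logSum(a, b), c);
        double right = LDAUtil.logSum(a, LDAUtil.logSum(b, c));
        assert Math.abs(left - right) < 1e-12; // both equal log(1 + 2 + 4)
      }
    }
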
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+import org.apache.mahout.matrix.Matrix;
+
+public class LDAState {
+  public final int numTopics; 
+  public final int numWords; 
+  public final double topicSmoothing;
+  private final Matrix topicWordProbabilities; // log p(w|t) for topic=1..nTopics
+  private final double[] logTotals; // log \sum p(w|t) for topic=1..nTopics
+  public final double logLikelihood; // log likelihood of the data under this state
+
+  public LDAState(int numTopics, int numWords, double topicSmoothing,
+      Matrix topicWordProbabilities, double[] logTotals, double ll) {
+    this.numWords = numWords;
+    this.numTopics = numTopics;
+    this.topicSmoothing = topicSmoothing;
+    this.topicWordProbabilities = topicWordProbabilities;
+    this.logTotals = logTotals;
+    this.logLikelihood = ll;
+  }
+
+  public double logProbWordGivenTopic(int word, int topic) {
+    final double logProb = topicWordProbabilities.getQuick(topic, word);
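+    // clamp: a word with no mass in this topic would otherwise yield -Infinity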
+    return logProb == Double.NEGATIVE_INFINITY ? -100.0 
+      : logProb - logTotals[topic];
+  }
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
------------------------------------------------------------------------------
    svn:eol-style = native

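How the stored statistics become probabilities: logProbWordGivenTopic()
subtracts the per-topic total, which log-normalizes each topic's row. A toy
check (not part of this commit) with hypothetical counts:

    public class NormalizationSketch {
      public static void main(String[] args) {
        // unnormalized log statistics for one topic over a 3-word vocabulary
        double[] stat = { Math.log(2.0), Math.log(3.0), Math.log(5.0) };
        double logTotal = Math.log(2.0 + 3.0 + 5.0); // the TOPIC_SUM_KEY entry
        double sum = 0.0;
        for (double s : stat) {
          sum += Math.exp(s - logTotal); // p(word | topic)
        }
        assert Math.abs(sum - 1.0) < 1e-12; // each topic's row sums to one
      }
    }
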
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+/**
+ * Various utility methods for doing LDA inference.
+ */
+final class LDAUtil {
+  private LDAUtil() {
+  } // no creation
+
+  /**
+   * @return log(exp(a) + exp(b))
+   */
+  static double logSum(double a, double b) {
+    return (a == Double.NEGATIVE_INFINITY) ? b
+      : (b == Double.NEGATIVE_INFINITY) ? a
+      : (a < b) ? b + Math.log(1 + Math.exp(a - b))
+      : a + Math.log(1 + Math.exp(b - a));    
+  }
+
+}

Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java
------------------------------------------------------------------------------
    svn:eol-style = native

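Why logSum instead of log(exp(a) + exp(b)): with the very negative log
values this job passes around, exp() underflows to zero and the naive form
collapses to -Infinity. A small demonstration (not part of this commit),
in the same package since logSum is package-private:

    package org.apache.mahout.clustering.lda;

    public class LogSumSketch {
      public static void main(String[] args) {
        double a = -800.0;
        double b = -801.0;
        double naive = Math.log(Math.exp(a) + Math.exp(b)); // exp underflows: -Infinity
        double stable = LDAUtil.logSum(a, b);
        assert Double.isInfinite(naive);
        assert Math.abs(stable + 799.6867) < 1e-3; // a + log(1 + exp(b - a))
      }
    }
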
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,122 @@
+package org.apache.mahout.clustering.lda;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.Iterator;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.math.distribution.PoissonDistribution;
+import org.apache.commons.math.distribution.PoissonDistributionImpl;
+
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Matrix;
+import org.apache.mahout.matrix.Vector;
+
+public class TestLDAInference extends TestCase {
+
+  private Random random;
+
+  private static final int NUM_TOPICS = 20;
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    random = new Random();
+  }
+
+  /**
+   * Generate a random document vector
+   * @param numWords  number of words in the vocabulary
+   * @param sparsity  E[count] for each word
+   */
+  private Vector generateRandomDoc(int numWords, double sparsity) {
+    Vector v = new DenseVector(numWords);
+    try {
+      PoissonDistribution dist = new PoissonDistributionImpl(sparsity);
+      for (int i = 0; i < numWords; i++) {
+        // random integer
+        v.setQuick(i, dist.inverseCumulativeProbability(random.nextDouble()) + 1);
+      }
+    } catch (Exception e) {
+      e.printStackTrace();
+      fail("Caught " + e.toString());
+    }
+    return v;
+  }
+
+  private LDAState generateRandomState(int numWords, int numTopics) {
+    double topicSmoothing = 50.0 / numTopics; // whatever
+    Matrix m = new DenseMatrix(numTopics, numWords);
+    double[] logTotals = new double[numTopics];
+    double ll = Double.NEGATIVE_INFINITY;
+
+    for (int k = 0; k < numTopics; ++k) {
+      double total = 0.0; // total number of pseudo counts we made
+      for (int w = 0; w < numWords; ++w) {
+        // Random noise, floored at 1E-10 so the log below stays finite.
+        double pseudocount = random.nextDouble() + 1E-10;
+        total += pseudocount;
+        m.setQuick(k, w, Math.log(pseudocount));
+      }
+
+      logTotals[k] = Math.log(total);
+    }
+
+    return new LDAState(numTopics, numWords, topicSmoothing, m, logTotals, ll);
+  }
+
+
+  private void runTest(int numWords, double sparsity, int numTests) {
+    LDAState state = generateRandomState(numWords, NUM_TOPICS);
+    LDAInference lda = new LDAInference(state);
+    for (int t = 0; t < numTests; ++t) {
+      Vector v = generateRandomDoc(numWords, sparsity);
+      LDAInference.InferredDocument doc = lda.infer(v);
+
+      assertEquals("wordCounts", doc.wordCounts, v);
+      assertNotNull("gamma", doc.gamma);
+      for (Iterator<Vector.Element> iter = v.iterateNonZero();
+          iter.hasNext(); ) {
+        int w = iter.next().index();
+        for (int k = 0; k < NUM_TOPICS; ++k) {
+          double logProb = doc.phi(k, w);
+          assertTrue(k + " " + w + " logProb " + logProb, logProb <= 0.0); 
+        }
+      }
+      assertTrue("log likelihood", doc.logLikelihood <= 1E-10);
+    }
+  }
+
+  public void testLDAEasy() {
+    runTest(10, 1, 5); // 1 word per doc in expectation
+  }
+
+  public void testLDASparse() {
+    runTest(100, 0.4, 5); // 40 words per doc in expectation
+  }
+
+  public void testLDADense() {
+    runTest(100, 3, 5); // 300 words per doc in expectation
+  }
+}

Propchange: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java
------------------------------------------------------------------------------
    svn:eol-style = native

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+import java.io.File;
+import java.util.Iterator;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.math.distribution.PoissonDistribution;
+import org.apache.commons.math.distribution.PoissonDistributionImpl;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.matrix.Matrix;
+import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.matrix.Vector;
+
+import static org.easymock.classextension.EasyMock.*;
+
+public class TestMapReduce extends TestCase {
+
+  private Random random;
+
+  /**
+   * Generate a random document vector.
+   * @param numWords number of words in the vocabulary
+   * @param sparsity E[count] for each word
+   */
+  private SparseVector generateRandomDoc(int numWords, double sparsity) {
+    SparseVector v = new SparseVector(numWords, (int) (numWords * sparsity));
+    try {
+      PoissonDistribution dist = new PoissonDistributionImpl(sparsity);
+      for (int i = 0; i < numWords; i++) {
+        // Draw a Poisson-distributed count by inverse transform sampling,
+        // shifted up by one.
+        v.set(i, dist.inverseCumulativeProbability(random.nextDouble()) + 1);
+      }
+    } catch (Exception e) {
+      e.printStackTrace();
+      fail("Caught " + e.toString());
+    }
+    return v;
+  }
+
+  private LDAState generateRandomState(int numWords, int numTopics) {
+    double topicSmoothing = 50.0 / numTopics; // common heuristic: alpha = 50/K
+    Matrix m = new DenseMatrix(numTopics, numWords);
+    double[] logTotals = new double[numTopics];
+    double ll = Double.NEGATIVE_INFINITY;
+    for (int k = 0; k < numTopics; ++k) {
+      double total = 0.0; // total number of pseudo counts we made
+      for (int w = 0; w < numWords; ++w) {
+        // A small amount of random noise, minimized by having a floor.
+        double pseudocount = random.nextDouble() + 1E-10;
+        total += pseudocount;
+        m.setQuick(k, w, Math.log(pseudocount));
+      }
+
+      logTotals[k] = Math.log(total);
+    }
+
+    return new LDAState(numTopics, numWords, topicSmoothing, m, logTotals, ll);
+  }
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    random = new Random();
+    new File("input").mkdir();
+  }
+
+  private static final int NUM_TESTS = 10;
+  private static final int NUM_TOPICS = 10;
+
+  /**
+   * Test the basic Mapper.
+   */
+  public void testMapper() throws Exception {
+    LDAState state = generateRandomState(100, NUM_TOPICS);
+    LDAMapper mapper = new LDAMapper();
+    mapper.configure(state);
+
+    for (int i = 0; i < NUM_TESTS; ++i) {
+      SparseVector v = generateRandomDoc(100, 0.3);
+      int myNumWords = numNonZero(v);
+      LDAMapper.Context mock = createMock(LDAMapper.Context.class);
+
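+      // Expected writes: one per (topic, word) pair over the document's
+      // distinct words, plus one normalizer total per topic, plus one
+      // log-likelihood value, hence myNumWords * NUM_TOPICS + NUM_TOPICS + 1.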
+      mock.write(isA(IntPairWritable.class), isA(DoubleWritable.class));
+      expectLastCall().times(myNumWords * NUM_TOPICS + NUM_TOPICS + 1);
+      replay(mock);
+
+      mapper.map(new Text("tstMapper"), v, mock);
+      verify(mock);
+    }
+  }
+
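+  /** Count the nonzero entries of a vector by walking its sparse iterator. */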
+  private int numNonZero(Vector v) {
+    int count = 0;
+    Iterator<Vector.Element> iter = v.iterateNonZero();
+    while (iter.hasNext()) {
+      iter.next();
+      count++;
+    }
+    return count;
+  }
+}

Propchange: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
------------------------------------------------------------------------------
    svn:eol-style = native

Added: lucene/mahout/trunk/examples/bin/build-reuters.sh
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/bin/build-reuters.sh?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/bin/build-reuters.sh (added)
+++ lucene/mahout/trunk/examples/bin/build-reuters.sh Mon Aug 17 13:33:18 2009
@@ -0,0 +1,60 @@
+#!/bin/sh
+#/**
+# * Licensed to the Apache Software Foundation (ASF) under one or more
+# * contributor license agreements.  See the NOTICE file distributed with
+# * this work for additional information regarding copyright ownership.
+# * The ASF licenses this file to You under the Apache License, Version 2.0
+# * (the "License"); you may not use this file except in compliance with
+# * the License.  You may obtain a copy of the License at
+# *
+# *     http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS,
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# * See the License for the specific language governing permissions and
+# * limitations under the License.
+# */
+
+#
+# Runs the LDA example on the Reuters-21578 collection.
+#
+# To run: change into the mahout/examples directory (the parent of the
+# directory containing this file) and type:
+#   bin/build-reuters.sh
+#
+mkdir -p work
+if [ ! -e work/reuters-out ]; then
+  if [ ! -e work/reuters-sgm ]; then
+    if [ ! -f work/reuters21578.tar.gz ]; then
+      echo "Downloading Reuters-21578"
+      curl http://kdd.ics.uci.edu/databases/reuters21578/reuters21578.tar.gz  -o work/reuters21578.tar.gz
+    fi
+    mkdir -p work/reuters-sgm
+    echo "Extracting..."
+    cd work/reuters-sgm && tar xzf ../reuters21578.tar.gz && cd .. && cd ..
+  fi
+  echo "Converting to plain text."
+  mvn -e -q exec:java  -Dexec.mainClass="org.apache.lucene.benchmark.utils.ExtractReuters" -Dexec.args="work/reuters-sgm work/reuters-out" || exit
+fi
+# Create index
+if [ ! -e work/index ]; then
+  echo "Creating index";
+  mvn -e exec:java -Dexec.classpathScope="test" -Dexec.mainClass="org.apache.lucene.benchmark.byTask.Benchmark" -Dexec.args="bin/lda.algorithm" || { rm -rf work/index; exit 1; }
+fi
+if [ ! -e work/vectors ]; then
+  echo "Creating vectors from index"
+  cd ../core
+  mvn -q install -DskipTests=true
+  cd ../utils/
+  mvn -q compile
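+  # --minDF 100 and --maxDFPercent 97 presumably prune terms appearing in
+  # fewer than 100 documents or in more than 97% of all documents.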
+  mvn -e exec:java -Dexec.mainClass="org.apache.mahout.utils.vectors.lucene.Driver" \
+    -Dexec.args="--dir ../examples/work/index/ --field body --dictOut ../examples/work/dict.txt \
+    --output ../examples/work/vectors --minDF 100 --maxDFPercent 97" || exit
+  cd ../core/
+fi
+echo "Running LDA"
+rm -rf ../examples/work/lda
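+# Flag meanings inferred from this invocation: -k number of topics,
+# -v vocabulary size, --maxIter maximum number of iterations.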
+MAVEN_OPTS="-Xmx2G -ea" mvn -e exec:java -Dexec.mainClass="org.apache.mahout.clustering.lda.LDADriver" \
+  -Dexec.args="-i ../examples/work/vectors -o ../examples/work/lda/ -k 20 -v 10000 --maxIter 40"
+echo "Writing top words for each topic to to examples/work/topics/"
+mvn -q exec:java -Dexec.mainClass="org.apache.mahout.clustering.lda.LDAPrintTopics" -Dexec.args="-i `ls -1dtr ../examples/work/lda/state-* | tail -1` -d ../examples/work/dict.txt -o ../examples/work/topics/ -w 100"

Propchange: lucene/mahout/trunk/examples/bin/build-reuters.sh
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: lucene/mahout/trunk/examples/bin/build-reuters.sh
------------------------------------------------------------------------------
    svn:executable = *

Added: lucene/mahout/trunk/examples/bin/lda.algorithm
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/bin/lda.algorithm?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/bin/lda.algorithm (added)
+++ lucene/mahout/trunk/examples/bin/lda.algorithm Mon Aug 17 13:33:18 2009
@@ -0,0 +1,30 @@
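+# Lucene benchmark .alg settings for indexing Reuters; values like
+# "mrg:10:20" presumably define per-round parameter series in this syntax.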
+merge.policy=org.apache.lucene.index.LogDocMergePolicy
+merge.factor=mrg:10:20
+max.buffered=buf:100:1000
+compound=true
+
+analyzer=org.apache.lucene.analysis.standard.StandardAnalyzer
+directory=FSDirectory
+
+doc.stored=true
+doc.term.vector=true
+doc.tokenized=true
+log.step=600
+
+content.source=org.apache.lucene.benchmark.byTask.feeds.ReutersContentSource
+content.source.forever=false
+doc.maker.forever=false
+query.maker=org.apache.lucene.benchmark.byTask.feeds.SimpleQueryMaker
+
+# tasks at this depth or less print a message when they start
+task.max.depth.log=2
+
+log.queries=false
+# --------- alg
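+# Create the index, add documents until the content source is exhausted
+# (the "*" repetition count), then close the index.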
+{ "BuildReuters"
+  CreateIndex 
+  { "AddDocs" AddDoc > : *
+#  Optimize
+  CloseIndex
+}
+

Modified: lucene/mahout/trunk/examples/pom.xml
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/pom.xml?rev=804979&r1=804978&r2=804979&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/pom.xml (original)
+++ lucene/mahout/trunk/examples/pom.xml Mon Aug 17 13:33:18 2009
@@ -180,6 +180,11 @@
       <artifactId>mahout-core</artifactId>
       <version>${project.version}</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.mahout</groupId>
+      <artifactId>mahout-utils</artifactId>
+      <version>${project.version}</version>
+    </dependency>
 
 
     <!-- A Lucene wikipedia tokenizer/analyzer -->
@@ -197,6 +202,17 @@
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.lucene</groupId>
+      <artifactId>lucene-benchmark</artifactId>
+      <version>2.9-SNAPSHOT</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-compress</artifactId>
+      <version>1.0</version>
+
+    </dependency>
 
     <dependency>
       <groupId>org.apache.openejb</groupId>