You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2009/08/17 15:33:19 UTC
svn commit: r804979 - in /lucene/mahout/trunk: core/
core/src/main/java/org/apache/mahout/clustering/lda/
core/src/test/java/org/apache/mahout/clustering/lda/ examples/ examples/bin/
Author: gsingers
Date: Mon Aug 17 13:33:18 2009
New Revision: 804979
URL: http://svn.apache.org/viewvc?rev=804979&view=rev
Log:
MAHOUT-123: Latent Dirichlet Allocation
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java (with props)
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java (with props)
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (with props)
lucene/mahout/trunk/examples/bin/
lucene/mahout/trunk/examples/bin/build-reuters.sh (with props)
lucene/mahout/trunk/examples/bin/lda.algorithm
Modified:
lucene/mahout/trunk/core/pom.xml
lucene/mahout/trunk/examples/pom.xml
Modified: lucene/mahout/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/pom.xml?rev=804979&r1=804978&r2=804979&view=diff
==============================================================================
--- lucene/mahout/trunk/core/pom.xml (original)
+++ lucene/mahout/trunk/core/pom.xml Mon Aug 17 13:33:18 2009
@@ -454,6 +454,9 @@
<version>1.1.1</version>
</dependency>
+
+
+
<dependency>
<groupId>commons-httpclient</groupId>
<artifactId>commons-httpclient</artifactId>
@@ -548,12 +551,25 @@
<version>2.0-mahout</version>
</dependency>
<dependency>
+ <groupId>commons-math</groupId>
+ <artifactId>commons-math</artifactId>
+ <version>1.2</version>
+ </dependency>
+
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>3.8.2</version>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymockclassextension</artifactId>
+ <version>2.2</version>
+ <scope>test</scope>
+ </dependency>
+
<!-- Gson: Java to Json conversion -->
<dependency>
<groupId>com.google.code.gson</groupId>
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+
+/**
+ * A Hadoop {@link WritableComparable} holding an ordered pair of ints, {@code x} and {@code y}.
+ *
+ * <p>Pairs are ordered by {@code x} first and by {@code y} second. The serialized form is
+ * exactly two 4-byte ints ({@code x} then {@code y}), which lets the registered raw
+ * {@link Comparator} order keys without deserializing them.</p>
+ */
+public class IntPairWritable implements WritableComparable<IntPairWritable> {
+
+  private int x;
+  private int y;
+
+  /** For serialization purposes only */
+  public IntPairWritable() {
+  }
+
+  public IntPairWritable(int x, int y) {
+    this.x = x;
+    this.y = y;
+  }
+
+  public void setX(int x) {
+    this.x = x;
+  }
+
+  public int getX() {
+    return x;
+  }
+
+  public void setY(int y) {
+    this.y = y;
+  }
+
+  public int getY() {
+    return y;
+  }
+
+  /** Writes {@code x} followed by {@code y}, each as a 4-byte int. */
+  @Override
+  public void write(DataOutput dataOutput) throws IOException {
+    dataOutput.writeInt(x);
+    dataOutput.writeInt(y);
+  }
+
+  /** Reads {@code x} followed by {@code y}, mirroring {@link #write(DataOutput)}. */
+  @Override
+  public void readFields(DataInput dataInput) throws IOException {
+    x = dataInput.readInt();
+    y = dataInput.readInt();
+  }
+
+  /**
+   * Orders by {@code x}, breaking ties on {@code y}.
+   *
+   * @return a negative value, zero, or a positive value as this pair is less than,
+   *         equal to, or greater than {@code that}
+   */
+  @Override
+  public int compareTo(IntPairWritable that) {
+    int byX = compareInts(this.x, that.x);
+    return (byX != 0) ? byX : compareInts(this.y, that.y);
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    } else if (!(o instanceof IntPairWritable)) {
+      return false;
+    }
+
+    IntPairWritable that = (IntPairWritable) o;
+
+    return that.x == this.x && this.y == that.y;
+  }
+
+  @Override
+  public int hashCode() {
+    return 43 * x + y;
+  }
+
+  @Override
+  public String toString() {
+    return "(" + x + ", " + y + ")";
+  }
+
+  /**
+   * Three-way int comparison that cannot overflow. Subtracting the operands is not safe:
+   * the difference of two ints that are far apart wraps around and flips the sign.
+   */
+  private static int compareInts(int a, int b) {
+    return (a < b) ? -1 : ((a == b) ? 0 : 1);
+  }
+
+  static {
+    // Register the raw comparator so keys of this type are compared in serialized form.
+    WritableComparator.define(IntPairWritable.class, new Comparator());
+  }
+
+  /** Raw-bytes comparator; must produce the same ordering as {@link IntPairWritable#compareTo}. */
+  public static class Comparator extends WritableComparator {
+    public Comparator() {
+      super(IntPairWritable.class);
+    }
+
+    @Override
+    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
+      assert l1 == 8;
+      // First int of each record is x.
+      int byX = compareInts(readInt(b1, s1), readInt(b2, s2));
+      if (byX != 0) {
+        return byX;
+      }
+
+      // Second int (offset 4) is y.
+      return compareInts(readInt(b1, s1 + 4), readInt(b2, s2 + 4));
+    }
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,349 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.utils.CommandLineUtil;
+import org.apache.mahout.utils.HadoopUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+* Estimates an LDA model from a corpus of documents,
+* which are SparseVectors of word counts. At each
+* phase, it outputs a matrix of log probabilities of
+* each topic.
+*/
+public final class LDADriver {
+
+ /** Configuration key under which the input state directory is passed to each iteration's job. */
+ static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
+
+ /** Configuration key for the number of topics. */
+ static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
+ /** Configuration key for the number of words. */
+ static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
+
+ /** Configuration key for the topic smoothing parameter. */
+ static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
+
+ /** Sentinel x value of a state key: the record's value is the log likelihood. */
+ static final int LOG_LIKELIHOOD_KEY = -2;
+ /** Sentinel y value of a state key: the record's value is the log of the topic's total count. */
+ static final int TOPIC_SUM_KEY = -1;
+
+ /** Relative change in log likelihood below which runJob stops iterating. */
+ static final double OVERALL_CONVERGENCE = 1E-5;
+
+ private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
+
+ /** Not instantiable: this class only exposes static entry points. */
+ private LDADriver() {
+ }
+
+  /**
+   * Command-line entry point: parses the options and hands them to
+   * {@link #runJob(String, String, int, int, double, int, int)}.
+   *
+   * <p>On a malformed command line the parse error is logged and the usage help is printed.</p>
+   *
+   * @param args command-line arguments; see the option descriptions below
+   */
+  public static void main(String[] args) throws InstantiationException,
+      IllegalAccessException, ClassNotFoundException,
+      IOException, InterruptedException {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+
+    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The Output Working Directory").withShortName("o").create();
+
+    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
+        "If set, overwrite the output directory").withShortName("w").create();
+
+    Option topicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(
+        abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The number of topics").withShortName("k").create();
+
+    Option wordsOpt = obuilder.withLongName("numWords").withRequired(true).withArgument(
+        abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The number of words in the corpus").withShortName("v").create();
+
+    Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(abuilder
+        .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
+        "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
+
+    Option maxIterOpt = obuilder.withLongName("maxIter").withRequired(false).withArgument(
+        abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
+        "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
+
+    // The description previously repeated the maxIter text (copy-paste slip).
+    // NOTE(review): the argument declares a default of 10, but the fallback used below when
+    // the option is absent is 2 -- confirm which default is intended.
+    Option numReducOpt = obuilder.withLongName("numReducers").withRequired(false).withArgument(
+        abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
+        "The number of reduce tasks to use. Default 10").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
+
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
+        topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
+        numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      String input = cmdLine.getValue(inputOpt).toString();
+      String output = cmdLine.getValue(outputOpt).toString();
+
+      int maxIterations = -1;
+      if (cmdLine.hasOption(maxIterOpt)) {
+        maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString());
+      }
+
+      int numReduceTasks = 2;
+      if (cmdLine.hasOption(numReducOpt)) {
+        numReduceTasks = Integer.parseInt(cmdLine.getValue(numReducOpt).toString());
+      }
+
+      int numTopics = 20;
+      if (cmdLine.hasOption(topicsOpt)) {
+        numTopics = Integer.parseInt(cmdLine.getValue(topicsOpt).toString());
+      }
+
+      int numWords = 20;
+      if (cmdLine.hasOption(wordsOpt)) {
+        numWords = Integer.parseInt(cmdLine.getValue(wordsOpt).toString());
+      }
+
+      if (cmdLine.hasOption(overwriteOutput)) {
+        HadoopUtil.overwriteOutput(output);
+      }
+
+      double topicSmoothing = -1.0;
+      if (cmdLine.hasOption(topicSmOpt)) {
+        // Read the smoothing option itself; this used to read the maxIter option's value.
+        topicSmoothing = Double.parseDouble(cmdLine.getValue(topicSmOpt).toString());
+      }
+      // Only a non-positive value (the -1 sentinel, or an unusable 0) selects the default.
+      // The previous "< 1" test also discarded every valid smoothing value below 1.
+      if (topicSmoothing <= 0) {
+        topicSmoothing = 50. / numTopics;
+      }
+
+      runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations,
+          numReduceTasks);
+
+    } catch (OptionException e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+    }
+  }
+
+  /**
+   * Estimates the model by running one map/reduce job per iteration.
+   *
+   * <p>A random initial state is written to {@code output/state-0}; iteration {@code i}
+   * reads {@code output/state-i} and writes {@code output/state-(i+1)}. The loop ends once
+   * {@code maxIterations} iterations have run (when {@code maxIterations >= 1}), or, after
+   * the first three iterations, when the relative change in log likelihood drops below
+   * {@link #OVERALL_CONVERGENCE}.</p>
+   *
+   * @param input          the directory pathname for input points
+   * @param output         the directory pathname for output points
+   * @param numTopics      the number of topics
+   * @param numWords       the number of words
+   * @param topicSmoothing pseudocounts for each topic, typically small < .5
+   * @param maxIterations  the maximum number of iterations; values below 1 mean no limit
+   * @param numReducers    the number of Reducers desired
+   * @throws IOException
+   */
+  public static void runJob(String input, String output, int numTopics,
+      int numWords, double topicSmoothing, int maxIterations, int numReducers)
+      throws IOException, InterruptedException, ClassNotFoundException {
+
+    String currentState = output + "/state-0";
+    writeInitialState(currentState, numTopics, numWords);
+
+    double previousLL = Double.NEGATIVE_INFINITY;
+    boolean done = false;
+    int iteration = 0;
+
+    while ((maxIterations < 1 || iteration < maxIterations) && !done) {
+      log.info("Iteration {}", iteration);
+
+      // every iteration writes into its own, freshly numbered state directory
+      String nextState = output + "/state-" + (iteration + 1);
+      double currentLL = runIteration(input, currentState, nextState, numTopics,
+          numWords, topicSmoothing, numReducers);
+      double relChange = (previousLL - currentLL) / previousLL;
+
+      log.info("Iteration {} finished. Log Likelihood: {}", iteration, currentLL);
+      log.info("(Old LL: {})", previousLL);
+      log.info("(Rel Change: {})", relChange);
+
+      done = iteration > 2 && relChange < OVERALL_CONVERGENCE;
+
+      // this iteration's output is the next iteration's input
+      currentState = nextState;
+      previousLL = currentLL;
+      iteration++;
+    }
+  }
+
+  /**
+   * Writes a random initial model, one sequence file ({@code part-k}) per topic.
+   *
+   * <p>Each file holds one record per word, keyed (topic, word), whose value is the log of
+   * a random pseudocount, followed by a (topic, {@link #TOPIC_SUM_KEY}) record holding the
+   * log of the sum of that topic's pseudocounts.</p>
+   *
+   * @param statePath the directory to write the state files into
+   * @param numTopics the number of topics
+   * @param numWords  the number of words
+   * @throws IOException if a state file cannot be written
+   */
+  private static void writeInitialState(String statePath,
+      int numTopics, int numWords) throws IOException {
+    Path dir = new Path(statePath);
+    Configuration job = new Configuration();
+    FileSystem fs = dir.getFileSystem(job);
+
+    IntPairWritable kw = new IntPairWritable();
+    DoubleWritable v = new DoubleWritable();
+
+    Random random = new Random();
+
+    for (int k = 0; k < numTopics; ++k) {
+      Path path = new Path(dir, "part-" + k);
+      SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
+          IntPairWritable.class, DoubleWritable.class);
+      try {
+        double total = 0.0; // total number of pseudo counts we made
+
+        kw.setX(k);
+        for (int w = 0; w < numWords; ++w) {
+          kw.setY(w);
+          // A small amount of random noise, minimized by having a floor.
+          double pseudocount = random.nextDouble() + 1E-8;
+          total += pseudocount;
+          v.set(Math.log(pseudocount));
+          writer.append(kw, v);
+        }
+
+        kw.setY(TOPIC_SUM_KEY);
+        v.set(Math.log(total));
+        writer.append(kw, v);
+      } finally {
+        // Release the file handle even when an append fails part-way through.
+        writer.close();
+      }
+    }
+  }
+
+  /**
+   * Scans the state files under {@code statePath} for the log likelihood record, i.e. the
+   * record whose key has x == {@link #LOG_LIKELIHOOD_KEY}.
+   *
+   * @param statePath the state directory to scan
+   * @param job       configuration used to resolve the file system and open the files
+   * @return the log likelihood found, or 0.0 when no file contains such a record
+   * @throws IOException if a state file cannot be read
+   */
+  private static double findLL(String statePath, Configuration job) throws IOException {
+    Path dir = new Path(statePath);
+    FileSystem fs = dir.getFileSystem(job);
+
+    double ll = 0.0;
+
+    IntPairWritable key = new IntPairWritable();
+    DoubleWritable value = new DoubleWritable();
+    // Only the "part-*" data files are sequence files; a job output directory can also
+    // contain entries such as _logs, which cannot be opened by SequenceFile.Reader.
+    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
+      Path path = status.getPath();
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
+      try {
+        while (reader.next(key, value)) {
+          if (key.getX() == LOG_LIKELIHOOD_KEY) {
+            ll = value.get();
+            break; // done with this file; the remaining files are still scanned
+          }
+        }
+      } finally {
+        // Release the file handle even when reading fails.
+        reader.close();
+      }
+    }
+
+    return ll;
+  }
+
+  /**
+   * Runs a single map/reduce iteration and returns the log likelihood it produced.
+   *
+   * @param input          the directory pathname for input points
+   * @param stateIn        the directory pathname for input state
+   * @param stateOut       the directory pathname for output state
+   * @param numTopics      the number of topics
+   * @param numWords       the number of words
+   * @param topicSmoothing the topic smoothing parameter passed to the tasks
+   * @param numReducers    the number of Reducers desired
+   * @return the log likelihood read back from {@code stateOut}
+   * @throws IllegalStateException if the map/reduce job does not complete successfully
+   */
+  public static double runIteration(String input, String stateIn,
+      String stateOut, int numTopics, int numWords, double topicSmoothing,
+      int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+    Configuration conf = new Configuration();
+    conf.set(STATE_IN_KEY, stateIn);
+    conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
+    conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
+    conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
+
+    Job job = new Job(conf);
+    // Identify the jar that contains the mapper/reducer so it is shipped to the task nodes.
+    job.setJarByClass(LDADriver.class);
+
+    job.setOutputKeyClass(IntPairWritable.class);
+    job.setOutputValueClass(DoubleWritable.class);
+
+    FileInputFormat.addInputPaths(job, input);
+    Path outPath = new Path(stateOut);
+    FileOutputFormat.setOutputPath(job, outPath);
+
+    job.setMapperClass(LDAMapper.class);
+    job.setReducerClass(LDAReducer.class);
+    job.setCombinerClass(LDAReducer.class);
+    job.setNumReduceTasks(numReducers);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+
+    // A failed job leaves missing or partial output; stop here instead of reading it as
+    // if it were a valid state.
+    if (!job.waitForCompletion(true)) {
+      throw new IllegalStateException("LDA iteration failed while processing state " + stateIn);
+    }
+    return findLL(stateOut, conf);
+  }
+
+  /**
+   * Rebuilds the model state from the state directory named by {@link #STATE_IN_KEY} in
+   * the given configuration.
+   *
+   * <p>Records keyed (topic, {@link #TOPIC_SUM_KEY}) fill the per-topic log totals, the
+   * record whose x is {@link #LOG_LIKELIHOOD_KEY} supplies the log likelihood, and every
+   * other (topic, word) record fills one cell of the topic-by-word matrix.</p>
+   *
+   * @param job configuration carrying the state path, topic and word counts, and smoothing
+   * @return the state assembled from the files
+   * @throws IOException if a state file cannot be read
+   */
+  static LDAState createState(Configuration job) throws IOException {
+    String statePath = job.get(LDADriver.STATE_IN_KEY);
+    int numTopics = Integer.parseInt(job.get(LDADriver.NUM_TOPICS_KEY));
+    int numWords = Integer.parseInt(job.get(LDADriver.NUM_WORDS_KEY));
+    double topicSmoothing = Double.parseDouble(job.get(LDADriver.TOPIC_SMOOTHING_KEY));
+
+    Path dir = new Path(statePath);
+    FileSystem fs = dir.getFileSystem(job);
+
+    DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
+    double[] logTotals = new double[numTopics];
+    double ll = 0.0;
+
+    IntPairWritable key = new IntPairWritable();
+    DoubleWritable value = new DoubleWritable();
+    // Only the "part-*" data files are sequence files; a job output directory can also
+    // contain entries such as _logs, which cannot be opened by SequenceFile.Reader.
+    for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
+      Path path = status.getPath();
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
+      try {
+        while (reader.next(key, value)) {
+          int topic = key.getX();
+          int word = key.getY();
+          if (word == TOPIC_SUM_KEY) {
+            logTotals[topic] = value.get();
+            assert !Double.isInfinite(value.get());
+          } else if (topic == LOG_LIKELIHOOD_KEY) {
+            ll = value.get();
+          } else {
+            assert topic >= 0 && word >= 0 : topic + " " + word;
+            // each (topic, word) cell is expected to be written at most once
+            assert pWgT.getQuick(topic, word) == 0.0;
+            pWgT.setQuick(topic, word, value.get());
+            assert !Double.isInfinite(pWgT.getQuick(topic, word));
+          }
+        }
+      } finally {
+        // Release the file handle even when reading fails.
+        reader.close();
+      }
+    }
+
+    return new LDAState(numTopics, numWords, topicSmoothing,
+        pWgT, logTotals, ll);
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,257 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+import org.apache.commons.math.special.Gamma;
+import org.apache.mahout.matrix.BinaryFunction;
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Matrix;
+import org.apache.mahout.matrix.Vector;
+
+
+/**
+* Class for performing inference on a document, which involves
+* computing (an approximation to) p(word|topic) for each word and
+* topic, and a prior distribution p(topic) for each topic.
+*/
+public class LDAInference {
+ /**
+ * Creates an inference engine that reads its model parameters from the given state.
+ *
+ * @param state the model state consulted during inference
+ */
+ public LDAInference(LDAState state) {
+ this.state = state;
+ }
+
+  /**
+   * The result of running inference on a single document.
+   *
+   * <p>As filled in by {@code infer}: {@code gamma} holds one value per topic, the topic
+   * smoothing plus the document's word counts weighted by {@code exp(phi)} (it is not
+   * normalized to sum to one); {@link #phi(int, int)} returns the log-normalized value
+   * computed for topic {@code k} and word {@code w} in this document.</p>
+   */
+  public class InferredDocument {
+
+    /** The word counts inference was run on. */
+    public final Vector wordCounts;
+    /** Per-topic weights for this document (see class comment). */
+    public final Vector gamma;
+    /** Log-space values; rows are topics, columns are the document's words via columnMap. */
+    private final Matrix mphi;
+    /** Maps a word index to its column in {@code mphi}. */
+    private final Map<Integer, Integer> columnMap;
+    /** The log likelihood computed in the last inference iteration. */
+    public final double logLikelihood;
+
+    /**
+     * Looks up the stored log-space value for topic {@code k} and word {@code w}.
+     *
+     * @param k topic index
+     * @param w word index; must be a word that was mapped for this document
+     * @return the stored value for (k, w)
+     * @throws IllegalArgumentException if {@code w} has no column in this document
+     */
+    public double phi(int k, int w) {
+      Integer column = columnMap.get(w);
+      if (column == null) {
+        // Previously this surfaced as a bare NullPointerException from auto-unboxing.
+        throw new IllegalArgumentException("Word " + w + " does not occur in this document");
+      }
+      return mphi.getQuick(k, column);
+    }
+
+    InferredDocument(Vector wordCounts, Vector gamma,
+        Map<Integer, Integer> columnMap, Matrix phi,
+        double ll) {
+      this.wordCounts = wordCounts;
+      this.gamma = gamma;
+      this.mphi = phi;
+      this.columnMap = columnMap;
+      this.logLikelihood = ll;
+    }
+  }
+
+  /**
+   * Runs the iterative inference procedure on one document.
+   *
+   * <p>Each pass recomputes the per-word topic values with {@code eStepForWord}, rebuilds
+   * gamma from them, and evaluates the likelihood. The loop stops after 20 passes, or as
+   * soon as the previous likelihood is negative and the relative change falls below
+   * {@code E_STEP_CONVERGENCE}.</p>
+   *
+   * @param wordCounts the document, as a vector of word counts
+   * @return the inferred gamma and phi values together with the last likelihood computed
+   */
+  public InferredDocument infer(Vector wordCounts) {
+    final int MAX_ITER = 20;
+
+    double docTotal = wordCounts.zSum();
+    int docLength = wordCounts.size();
+
+    // start gamma at smoothing + an even share of the document's total count
+    Vector gamma = new DenseVector(state.numTopics);
+    gamma.assign(state.topicSmoothing + docTotal / state.numTopics);
+    Vector nextGamma = new DenseVector(state.numTopics);
+
+    DenseMatrix phi = new DenseMatrix(state.numTopics, docLength);
+    Map<Integer, Integer> columnMap = new HashMap<Integer, Integer>();
+
+    // digamma is costly, so it is evaluated once per pass: digamma(gamma_k) - digamma(sum)
+    Vector digammaGamma = digamma(gamma).plus(-digamma(gamma.zSum()));
+
+    double previousLL = 1;
+    boolean converged = false;
+
+    for (int iteration = 0; !converged && iteration < MAX_ITER; iteration++) {
+      nextGamma.assign(state.topicSmoothing); // every topic restarts from the smoothing value
+
+      int column = 0;
+      Iterator<Vector.Element> nonZeros = wordCounts.iterateNonZero();
+      while (nonZeros.hasNext()) {
+        Vector.Element entry = nonZeros.next();
+        int word = entry.index();
+
+        Vector phiW = eStepForWord(word, digammaGamma);
+        phi.assignColumn(column, phiW);
+        if (iteration == 0) { // the word-to-column mapping is recorded on the first pass only
+          columnMap.put(word, column);
+        }
+
+        double count = entry.get();
+        for (int k = 0; k < nextGamma.size(); ++k) {
+          nextGamma.setQuick(k, nextGamma.getQuick(k) + count * Math.exp(phiW.get(k)));
+        }
+
+        column++;
+      }
+
+      // swap the two buffers: the freshly built vector becomes the current gamma
+      Vector previousGamma = gamma;
+      gamma = nextGamma;
+      nextGamma = previousGamma;
+
+      digammaGamma = digamma(gamma).plus(-digamma(gamma.zSum()));
+
+      double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
+      assert !Double.isNaN(ll);
+      converged = previousLL < 0 && ((previousLL - ll) / previousLL < E_STEP_CONVERGENCE);
+
+      previousLL = ll;
+    }
+
+    return new InferredDocument(wordCounts, gamma, columnMap, phi, previousLL);
+  }
+
+ /** Model state read during inference (number of topics, topic smoothing, word log probabilities). */
+ private LDAState state;
+
+  /**
+   * Evaluates the likelihood value that {@code infer} uses as its convergence measure.
+   *
+   * @param wordCounts   the document's word counts
+   * @param columnMap    maps each word index to its column in {@code phi}
+   * @param phi          log-space per-word topic values, topics by columns
+   * @param gamma        current per-topic weights
+   * @param digammaGamma digamma of each gamma entry minus digamma of their sum
+   * @return the likelihood value (asserted to be non-positive)
+   */
+  private double computeLikelihood(Vector wordCounts, Map<Integer, Integer> columnMap,
+      Matrix phi, Vector gamma, Vector digammaGamma) {
+    double ll = 0.0;
+
+    // terms that depend only on the smoothing parameter and the number of topics
+    ll += Gamma.logGamma(state.topicSmoothing * state.numTopics);
+    ll -= state.numTopics * Gamma.logGamma(state.topicSmoothing);
+    assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
+
+    // per-topic terms that depend on gamma
+    for (int topic = 0; topic < state.numTopics; ++topic) {
+      double gammaK = gamma.get(topic);
+      ll += (state.topicSmoothing - gammaK) * digammaGamma.get(topic);
+      ll += Gamma.logGamma(gammaK);
+    }
+    ll -= Gamma.logGamma(gamma.zSum());
+    assert !Double.isNaN(ll);
+
+    // per-word, per-topic terms, each weighted by the word's count
+    Iterator<Vector.Element> nonZeros = wordCounts.iterateNonZero();
+    while (nonZeros.hasNext()) {
+      Vector.Element entry = nonZeros.next();
+      int word = entry.index();
+      double count = entry.get();
+      int column = columnMap.get(word);
+
+      for (int topic = 0; topic < state.numTopics; topic++) {
+        double logPhi = phi.get(topic, column);
+        double llPart = 0.0;
+        llPart += Math.exp(logPhi)
+            * (digammaGamma.get(topic) - logPhi
+            + state.logProbWordGivenTopic(word, topic));
+
+        ll += llPart * count;
+
+        assert state.logProbWordGivenTopic(word, topic) < 0;
+        assert !Double.isNaN(llPart);
+      }
+    }
+    assert ll <= 0;
+    return ll;
+  }
+
+  /**
+   * Computes log q(k|w,doc) for every topic k, for one word.
+   *
+   * @param word         the word index
+   * @param digammaGamma digamma of each gamma entry minus digamma of their sum
+   * @return a vector with one log-normalized entry per topic
+   */
+  private Vector eStepForWord(int word, Vector digammaGamma) {
+    Vector logQ = new DenseVector(state.numTopics);
+    double logNormalizer = Double.NEGATIVE_INFINITY; // running log-sum of the entries
+
+    int k = 0;
+    while (k < state.numTopics) {
+      double logWordGivenTopic = state.logProbWordGivenTopic(word, k);
+      double digammaK = digammaGamma.get(k);
+
+      logQ.set(k, logWordGivenTopic + digammaK);
+      logNormalizer = LDAUtil.logSum(logNormalizer, logQ.get(k));
+
+      assert !Double.isNaN(logNormalizer);
+      assert !Double.isNaN(logWordGivenTopic);
+      assert !Double.isInfinite(logWordGivenTopic);
+      assert !Double.isNaN(digammaK);
+
+      ++k;
+    }
+
+    // subtracting the log normalizer is a division in probability space
+    return logQ.plus(-logNormalizer);
+  }
+
+
+  /**
+   * Applies the scalar digamma approximation to every element of a vector.
+   *
+   * @param v the input vector; it is not modified
+   * @return a new dense vector of the same size holding digamma of each element
+   */
+  private static Vector digamma(Vector v) {
+    // Only the second argument (the element of v) is used; the first is the target's value.
+    BinaryFunction digammaOfSecond = new BinaryFunction() {
+      public double apply(double ignored, double element) {
+        return digamma(element);
+      }
+    };
+
+    Vector result = new DenseVector(v.size());
+    result.assign(v, digammaOfSecond);
+    return result;
+  }
+
+ /**
+ * Approximation to the digamma function, from Radford Neal.
+ *
+ * Original License:
+ * Copyright (c) 1995-2003 by Radford M. Neal
+ *
+ * Permission is granted for anyone to copy, use, modify, or distribute this
+ * program and accompanying programs and documents for any purpose, provided
+ * this copyright notice is retained and prominently displayed, along with
+ * a note saying that the original programs are available from Radford Neal's
+ * web page, and note is made of any changes made to the programs. The
+ * programs and documents are distributed without any warranty, express or
+ * implied. As the programs were written for research purposes only, they have
+ * not been tested to the degree that would be advisable in any important
+ * application. All use of these programs is entirely at the user's own risk.
+ *
+ *
+ * Ported to Java for Mahout.
+ *
+ * @param x the argument; callers in this class pass positive values
+ * @return the approximate value of digamma at x
+ */
+ private static double digamma(double x) {
+ double r = 0.0;
+
+ // Shift the argument above 5 one step at a time, accumulating the -1/x
+ // correction of each step in r.
+ while (x <= 5) {
+ r -= 1 / x;
+ x += 1;
+ }
+
+ // Polynomial in f = 1/x^2, evaluated in nested (Horner) form; together with
+ // log(x) - 0.5/x below it forms the series used for the shifted argument.
+ double f = 1. / (x * x);
+ double t = f * (-1 / 12.0
+ + f * (1 / 120.0
+ + f * (-1 / 252.0
+ + f * (1 / 240.0
+ + f * (-1 / 132.0
+ + f * (691 / 32760.0
+ + f * (-1 / 12.0
+ + f * 3617.0 / 8160.0)))))));
+ return r + Math.log(x) - 0.5 / x + t;
+ }
+
+ /** Relative change in likelihood below which the per-document loop in infer() stops. */
+ private static final double E_STEP_CONVERGENCE = 1E-6;
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.MapContext;
+import org.apache.mahout.matrix.AbstractVector;
+import org.apache.mahout.matrix.Vector;
+
+/**
+ * Mapper for one LDA pass. Each input value is a document given as a vector
+ * of word counts. Inference is run on the document and, for every topic and
+ * every word with a non-zero count, a log-space sufficient statistic is
+ * written. Afterwards one log total per topic and the document's
+ * log-likelihood are written under the special keys defined in LDADriver.
+ */
+public class LDAMapper extends
+    Mapper<WritableComparable<?>, Vector, IntPairWritable, DoubleWritable> {
+
+  private LDAState state;
+  private LDAInference infer;
+
+  /**
+   * Emits the statistics for a single document.
+   *
+   * @param key        input key; not read by this method
+   * @param wordCounts word counts of the document
+   * @param context    sink for the (topic, word) keyed values
+   */
+  @Override
+  public void map(WritableComparable<?> key, Vector wordCounts, Context context)
+      throws IOException, InterruptedException {
+    LDAInference.InferredDocument inferred = infer.infer(wordCounts);
+    final int numTopics = state.numTopics;
+
+    // Running log-sum of everything emitted per topic; starts at log(0).
+    double[] topicLogTotals = new double[numTopics];
+    Arrays.fill(topicLogTotals, Double.NEGATIVE_INFINITY);
+
+    // The same key/value instances are reused for every write.
+    IntPairWritable outKey = new IntPairWritable();
+    DoubleWritable outValue = new DoubleWritable();
+
+    // Per-word statistics: phi(topic, word) + log(count).
+    Iterator<Vector.Element> nonZero = wordCounts.iterateNonZero();
+    while (nonZero.hasNext()) {
+      Vector.Element element = nonZero.next();
+      int word = element.index();
+      double logCount = Math.log(element.get());
+      outKey.setY(word);
+      for (int topic = 0; topic < numTopics; ++topic) {
+        outValue.set(inferred.phi(topic, word) + logCount);
+        outKey.setX(topic);
+
+        // output (topic, word)'s logProb contribution
+        context.write(outKey, outValue);
+        topicLogTotals[topic] = LDAUtil.logSum(topicLogTotals[topic], outValue.get());
+      }
+    }
+
+    // Per-topic totals, written so that normalizing later is easier.
+    outKey.setY(LDADriver.TOPIC_SUM_KEY);
+    for (int topic = 0; topic < numTopics; ++topic) {
+      outKey.setX(topic);
+      outValue.set(topicLogTotals[topic]);
+      assert !Double.isNaN(outValue.get());
+      context.write(outKey, outValue);
+    }
+
+    // Document log-likelihood, keyed by the log-likelihood marker in both slots.
+    outKey.setX(LDADriver.LOG_LIKELIHOOD_KEY);
+    outKey.setY(LDADriver.LOG_LIKELIHOOD_KEY);
+    outValue.set(inferred.logLikelihood);
+    context.write(outKey, outValue);
+  }
+
+  /**
+   * Installs the given state and builds the inference engine on top of it.
+   *
+   * @param myState state to run inference against
+   */
+  public void configure(LDAState myState) {
+    this.state = myState;
+    this.infer = new LDAInference(state);
+  }
+
+  /**
+   * Creates the state from the job configuration via LDADriver and installs it.
+   *
+   * @param job job configuration handed to LDADriver.createState
+   * @throws RuntimeException wrapping the IOException if the state cannot be created
+   */
+  public void configure(Configuration job) {
+    LDAState myState;
+    try {
+      myState = LDADriver.createState(job);
+    } catch (IOException e) {
+      throw new RuntimeException("Error creating LDA State!", e);
+    }
+    configure(myState);
+  }
+
+  @Override
+  protected void setup(Context context) {
+    configure(context.getConfiguration());
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,219 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileReader;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.PriorityQueue;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.utils.CommandLineUtil;
+
+/**
+ * Class to print out the top K words for each topic.
+ */
+public class LDAPrintTopics {
+  private LDAPrintTopics() {
+  }
+
+  /** A word together with its score; the natural order is ascending by score. */
+  private static class StringDoublePair implements Comparable<StringDoublePair> {
+    final double score;
+    final String word;
+
+    StringDoublePair(double score, String word) {
+      this.score = score;
+      this.word = word;
+    }
+
+    public int compareTo(StringDoublePair other) {
+      return Double.compare(score, other.score);
+    }
+  }
+
+  /**
+   * Reads every file matched by {@code dir/*} as a SequenceFile of
+   * (IntPairWritable, DoubleWritable) records and collects, per topic, the
+   * words with the highest scores. Records whose topic or word index is
+   * negative are ignored.
+   *
+   * @param dir             directory holding the sequence files
+   * @param job             configuration used to open the file system and readers
+   * @param wordList        maps a word index to the word itself
+   * @param numWordsToPrint maximum number of words kept per topic
+   * @return one list per topic, indexed by topic, ordered from highest to lowest score
+   * @throws IOException if a file cannot be read
+   */
+  public static List<List<String>> topWordsForTopics(String dir, Configuration job,
+      List<String> wordList, int numWordsToPrint) throws IOException {
+    FileSystem fs = new Path(dir).getFileSystem(job);
+
+    List<PriorityQueue<StringDoublePair>> queues = new ArrayList<PriorityQueue<StringDoublePair>>();
+
+    IntPairWritable key = new IntPairWritable();
+    DoubleWritable value = new DoubleWritable();
+    FileStatus[] statuses = fs.globStatus(new Path(dir, "*"));
+    if (statuses != null) {
+      for (FileStatus status : statuses) {
+        SequenceFile.Reader reader = new SequenceFile.Reader(fs, status.getPath(), job);
+        try {
+          while (reader.next(key, value)) {
+            int topic = key.getX();
+            int word = key.getY();
+
+            ensureQueueSize(queues, topic);
+            if (word >= 0 && topic >= 0) {
+              maybeEnqueue(queues.get(topic), wordList.get(word), value.get(), numWordsToPrint);
+            }
+          }
+        } finally {
+          // Close the reader even when a record fails to parse.
+          reader.close();
+        }
+      }
+    }
+
+    List<List<String>> result = new ArrayList<List<String>>(queues.size());
+    for (PriorityQueue<StringDoublePair> queue : queues) {
+      LinkedList<String> words = new LinkedList<String>();
+      // A PriorityQueue's iterator has no defined order, so drain it with
+      // poll() (lowest score first) and prepend to end up highest-first.
+      StringDoublePair pair;
+      while ((pair = queue.poll()) != null) {
+        words.addFirst(pair.word);
+      }
+      result.add(words);
+    }
+
+    return result;
+  }
+
+  // Expands the queue list to have a Queue for topic K
+  private static void ensureQueueSize(List<PriorityQueue<StringDoublePair>> queues, int k) {
+    for (int i = queues.size(); i <= k; ++i) {
+      queues.add(new PriorityQueue<StringDoublePair>());
+    }
+  }
+
+  // Adds the word if the queue is below capacity, or the score is high enough
+  private static void maybeEnqueue(PriorityQueue<StringDoublePair> q, String word,
+      double score, int numWordsToPrint) {
+    if (numWordsToPrint <= 0) {
+      // Nothing can be kept; also avoids peeking into an empty queue.
+      return;
+    }
+    if (q.size() < numWordsToPrint) {
+      q.add(new StringDoublePair(score, word));
+    } else if (score > q.peek().score) {
+      // Queue is full: replace the current lowest-scoring entry.
+      q.poll();
+      q.add(new StringDoublePair(score, word));
+    }
+  }
+
+  // Reads dictionary in created by the vector Driver in util.
+  // The first two lines are skipped; every following line is split on tabs,
+  // with the word in the first field and its index in the third.
+  private static List<String> readDictionary(File path) throws IOException {
+    List<String> result = new ArrayList<String>();
+
+    BufferedReader rdr = new BufferedReader(new FileReader(path));
+    try {
+      // skip 2 lines
+      rdr.readLine();
+      rdr.readLine();
+      String line;
+      while ((line = rdr.readLine()) != null) {
+        String[] parts = line.split("\t");
+        if (parts.length < 3) {
+          throw new IOException("Malformed dictionary line in " + path + ": " + line);
+        }
+        int index = Integer.parseInt(parts[2]);
+        assert index == result.size();
+        result.add(parts[0]);
+      }
+    } finally {
+      rdr.close();
+    }
+
+    return result;
+  }
+
+  /**
+   * Command-line entry point: reads an LDA state and a dictionary and writes
+   * one {@code topic-N} file per topic into the output directory.
+   */
+  public static void main(String[] args) {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Path to an LDA output (a state)").withShortName("i").create();
+
+    Option dictOpt = obuilder.withLongName("dict").withRequired(true).withArgument(
+        abuilder.withName("dict").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Dictionary to read in, created by utils.vector.Driver").withShortName("d").create();
+
+    Option outOpt = obuilder.withLongName("output").withRequired(true).withArgument(
+        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Output directory to write top words").withShortName("o").create();
+
+    // Optional: falls back to 20 words when not given.
+    Option wordOpt = obuilder.withLongName("words").withRequired(false).withArgument(
+        abuilder.withName("words").withMinimum(0).withMaximum(1).withDefault("20").create()).withDescription(
+        "Number of words to print").withShortName("w").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
+
+    // helpOpt must be part of the group, otherwise the parser rejects --help.
+    Group group = gbuilder.withName("Options").withOption(dictOpt).withOption(outOpt).withOption(
+        wordOpt).withOption(inputOpt).withOption(helpOpt).create();
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+
+      String input = cmdLine.getValue(inputOpt).toString();
+      File output = new File(cmdLine.getValue(outOpt).toString());
+      File dict = new File(cmdLine.getValue(dictOpt).toString());
+      int numWords = 20;
+      if (cmdLine.hasOption(wordOpt)) {
+        numWords = Integer.parseInt(cmdLine.getValue(wordOpt).toString());
+      }
+
+      List<String> wordList = readDictionary(dict);
+
+      Configuration config = new Configuration();
+      List<List<String>> topWords = topWordsForTopics(input, config, wordList, numWords);
+
+      if (!output.exists() && !output.mkdirs()) {
+        throw new IOException("Could not create directory: " + output);
+      }
+
+      for (int i = 0; i < topWords.size(); ++i) {
+        File out = new File(output, "topic-" + i);
+        PrintWriter writer = new PrintWriter(new FileWriter(out));
+        try {
+          writer.println("Topic " + i);
+          writer.println("===========");
+          for (String word : topWords.get(i)) {
+            writer.println(word);
+          }
+        } finally {
+          writer.close();
+        }
+      }
+
+    } catch (OptionException e) {
+      System.err.println("Exception: " + e);
+      CommandLineUtil.printHelp(group);
+    } catch (NumberFormatException e) {
+      // A non-numeric --words value is a usage error, not a crash.
+      System.err.println("Exception: " + e);
+      CommandLineUtil.printHelp(group);
+    } catch (IOException e) {
+      System.err.println("Exception:" + e);
+      e.printStackTrace();
+    }
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+
+
+/**
+ * Combines all values that share a (topic, word) key into a single value.
+ * Keys whose second component is the log-likelihood marker are added up
+ * directly; every other key holds log-space sufficient statistics and is
+ * combined with LDAUtil.logSum.
+ */
+public class LDAReducer extends
+    Reducer<IntPairWritable, DoubleWritable, IntPairWritable, DoubleWritable> {
+
+  @Override
+  public void reduce(IntPairWritable topicWord, Iterable<DoubleWritable> values,
+      Context context)
+      throws java.io.IOException, InterruptedException {
+
+    final boolean isLogLikelihood = topicWord.getY() == LDADriver.LOG_LIKELIHOOD_KEY;
+
+    // Plain sums start at 0, log-space sums at log(0).
+    double total = isLogLikelihood ? 0.0 : Double.NEGATIVE_INFINITY;
+    for (DoubleWritable item : values) {
+      double current = item.get();
+      assert !Double.isNaN(current) : topicWord.getX() + " " + topicWord.getY();
+      if (isLogLikelihood) {
+        total += current;
+      } else {
+        total = LDAUtil.logSum(total, current);
+        assert !Double.isNaN(total) : topicWord.getX() + " " + topicWord.getY();
+      }
+    }
+
+    context.write(topicWord, new DoubleWritable(total));
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+import org.apache.mahout.matrix.Matrix;
+
+/**
+ * Immutable holder for the model parameters used by LDA inference:
+ * topic/vocabulary sizes, the smoothing parameter, the unnormalized
+ * log topic-word weights and their per-topic log totals.
+ */
+public class LDAState {
+  /** Number of topics. */
+  public final int numTopics;
+  /** Number of words in the vocabulary. */
+  public final int numWords;
+  /** Topic smoothing parameter handed to the constructor. */
+  public final double topicSmoothing;
+  /** Unnormalized log p(w|t), read as getQuick(topic, word). */
+  private final Matrix topicWordProbabilities;
+  /** log \sum_w p(w|t), one entry per topic. */
+  private final double[] logTotals;
+  /**
+   * Log-likelihood value handed to the constructor.
+   * NOTE(review): presumably the log-likelihood of the data under this
+   * state -- confirm against the code that builds the state.
+   */
+  public final double logLikelihood;
+
+  public LDAState(int numTopics, int numWords, double topicSmoothing,
+      Matrix topicWordProbabilities, double[] logTotals, double ll) {
+    this.numTopics = numTopics;
+    this.numWords = numWords;
+    this.topicSmoothing = topicSmoothing;
+    this.topicWordProbabilities = topicWordProbabilities;
+    this.logTotals = logTotals;
+    this.logLikelihood = ll;
+  }
+
+  /**
+   * @param word  word index
+   * @param topic topic index
+   * @return the stored log weight minus the topic's log total, or -100.0
+   *         when the stored log weight is negative infinity
+   */
+  public double logProbWordGivenTopic(int word, int topic) {
+    final double stored = topicWordProbabilities.getQuick(topic, word);
+    if (stored == Double.NEGATIVE_INFINITY) {
+      // Floor used in place of log(0).
+      return -100.0;
+    }
+    return stored - logTotals[topic];
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+/**
+ * Various utility methods for doing LDA inference.
+ */
+final class LDAUtil {
+  private LDAUtil() {
+  } // no creation
+
+  /**
+   * Adds two numbers that are both given in log space.
+   * Negative infinity stands for log(0) and acts as the neutral element.
+   *
+   * @param a log of the first summand
+   * @param b log of the second summand
+   * @return log(exp(a) + exp(b))
+   */
+  static double logSum(double a, double b) {
+    if (a == Double.NEGATIVE_INFINITY) {
+      return b;
+    }
+    if (b == Double.NEGATIVE_INFINITY) {
+      return a;
+    }
+    // Factor out the larger argument so exp() never overflows. log1p keeps
+    // the contribution of the smaller argument even when exp(diff) is far
+    // below machine epsilon, where log(1 + x) would round to 0.
+    return (a < b)
+        ? b + Math.log1p(Math.exp(a - b))
+        : a + Math.log1p(Math.exp(b - a));
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAUtil.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,122 @@
+package org.apache.mahout.clustering.lda;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.math.distribution.PoissonDistribution;
+import org.apache.commons.math.distribution.PoissonDistributionImpl;
+
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.matrix.DenseVector;
+import org.apache.mahout.matrix.Matrix;
+import org.apache.mahout.matrix.Vector;
+
+public class TestLDAInference extends TestCase {
+
+ private Random random;
+
+ private static int NUM_TOPICS = 20;
+
+ @Override
+ protected void setUp() throws Exception {
+ super.setUp();
+ random = new Random();
+ }
+
+ /**
+ * Generate random document vector
+ * @param numWords int number of words in the vocabulary
+ * @param numWords E[count] for each word
+ */
+ private Vector generateRandomDoc(int numWords, double sparsity) {
+ Vector v = new DenseVector(numWords);
+ try {
+ PoissonDistribution dist = new PoissonDistributionImpl(sparsity);
+ for (int i = 0; i < numWords; i++) {
+ // random integer
+ v.setQuick(i, dist.inverseCumulativeProbability(random.nextDouble()) + 1);
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail("Caught " + e.toString());
+ }
+ return v;
+ }
+
+ private LDAState generateRandomState(int numWords, int numTopics) {
+ double topicSmoothing = 50.0 / numTopics; // whatever
+ Matrix m = new DenseMatrix(numTopics, numWords);
+ double[] logTotals = new double[numTopics];
+ double ll = Double.NEGATIVE_INFINITY;
+
+ for (int k = 0; k < numTopics; ++k) {
+ double total = 0.0; // total number of pseudo counts we made
+ for (int w = 0; w < numWords; ++w) {
+ // A small amount of random noise, minimized by having a floor.
+ double pseudocount = random.nextDouble() + 1E-10;
+ total += pseudocount;
+ m.setQuick(k, w, Math.log(pseudocount));
+ }
+
+ logTotals[k] = Math.log(total);
+ }
+
+ return new LDAState(numTopics, numWords, topicSmoothing, m, logTotals, ll);
+ }
+
+
+ private void runTest(int numWords, double sparsity, int numTests) {
+ LDAState state = generateRandomState(numWords, NUM_TOPICS);
+ LDAInference lda = new LDAInference(state);
+ for (int t = 0; t < numTests; ++t) {
+ Vector v = generateRandomDoc(numWords, sparsity);
+ LDAInference.InferredDocument doc = lda.infer(v);
+
+ assertEquals("wordCounts", doc.wordCounts, v);
+ assertNotNull("gamma", doc.gamma);
+ for (Iterator<Vector.Element> iter = v.iterateNonZero();
+ iter.hasNext(); ) {
+ int w = iter.next().index();
+ for (int k = 0; k < NUM_TOPICS; ++k) {
+ double logProb = doc.phi(k, w);
+ assertTrue(k + " " + w + " logProb " + logProb, logProb <= 0.0);
+ }
+ }
+ assertTrue("log likelihood", doc.logLikelihood <= 1E-10);
+ }
+ }
+
+
+ public void testLDAEasy() {
+ runTest(10, 1, 5); // 1 word per doc in expectation
+ }
+
+ public void testLDASparse() {
+ runTest(100, 0.4, 5); // 40 words per doc in expectation
+ }
+
+ public void testLDADense() {
+ runTest(100, 3, 5); // 300 words per doc in expectation
+ }
+}
Propchange: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestLDAInference.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java Mon Aug 17 13:33:18 2009
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import junit.framework.TestCase;
+
+
+import org.apache.commons.math.distribution.PoissonDistribution;
+import org.apache.commons.math.distribution.PoissonDistributionImpl;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.matrix.DenseMatrix;
+import org.apache.mahout.matrix.Matrix;
+import org.apache.mahout.matrix.SparseVector;
+import org.apache.mahout.matrix.Vector;
+import org.apache.mahout.utils.DummyOutputCollector;
+
+import static org.easymock.classextension.EasyMock.*;
+
+/**
+ * Checks that LDAMapper writes the expected number of records per document.
+ */
+public class TestMapReduce extends TestCase {
+
+  private static final int NUM_TESTS = 10;
+  private static final int NUM_TOPICS = 10;
+
+  private Random random;
+
+  /**
+   * Generate random document vector
+   * @param numWords int number of words in the vocabulary
+   * @param sparsity mean of the Poisson distribution each word count is drawn from
+   */
+  private SparseVector generateRandomDoc(int numWords, double sparsity) {
+    SparseVector v = new SparseVector(numWords, (int) (numWords * sparsity));
+    try {
+      PoissonDistribution dist = new PoissonDistributionImpl(sparsity);
+      for (int i = 0; i < numWords; i++) {
+        // random integer
+        v.set(i, dist.inverseCumulativeProbability(random.nextDouble()) + 1);
+      }
+    } catch (Exception e) {
+      e.printStackTrace();
+      fail("Caught " + e.toString());
+    }
+    return v;
+  }
+
+  /**
+   * Builds a state whose topic-word entries are logs of random pseudo counts.
+   * @param numWords  number of words in the vocabulary
+   * @param numTopics number of topics
+   */
+  private LDAState generateRandomState(int numWords, int numTopics) {
+    double topicSmoothing = 50.0 / numTopics; // whatever
+    Matrix m = new DenseMatrix(numTopics, numWords);
+    double[] logTotals = new double[numTopics];
+    double ll = Double.NEGATIVE_INFINITY;
+    for (int k = 0; k < numTopics; ++k) {
+      double total = 0.0; // total number of pseudo counts we made
+      for (int w = 0; w < numWords; ++w) {
+        // A small amount of random noise, minimized by having a floor.
+        double pseudocount = random.nextDouble() + 1E-10;
+        total += pseudocount;
+        m.setQuick(k, w, Math.log(pseudocount));
+      }
+
+      logTotals[k] = Math.log(total);
+    }
+
+    return new LDAState(numTopics, numWords, topicSmoothing, m, logTotals, ll);
+  }
+
+  @Override
+  protected void setUp() throws Exception {
+    super.setUp();
+    File f = new File("input");
+    random = new Random();
+    f.mkdir();
+  }
+
+  /**
+   * Test the basic Mapper
+   *
+   * @throws Exception
+   */
+  public void testMapper() throws Exception {
+    LDAState state = generateRandomState(100, NUM_TOPICS);
+    LDAMapper mapper = new LDAMapper();
+    mapper.configure(state);
+
+    for (int i = 0; i < NUM_TESTS; ++i) {
+      SparseVector v = generateRandomDoc(100, 0.3);
+      int myNumWords = numNonZero(v);
+      LDAMapper.Context mock = createMock(LDAMapper.Context.class);
+
+      // One write per (non-zero word, topic), one total per topic,
+      // plus one log-likelihood record.
+      mock.write(isA(IntPairWritable.class), isA(DoubleWritable.class));
+      expectLastCall().times(myNumWords * NUM_TOPICS + NUM_TOPICS + 1);
+      replay(mock);
+
+      mapper.map(new Text("tstMapper"), v, mock);
+      verify(mock);
+    }
+  }
+
+  /** Counts the elements returned by the vector's non-zero iterator. */
+  private int numNonZero(Vector v) {
+    int count = 0;
+    Iterator<Vector.Element> iter = v.iterateNonZero();
+    while (iter.hasNext()) {
+      iter.next();
+      count++;
+    }
+    return count;
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/bin/build-reuters.sh
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/bin/build-reuters.sh?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/bin/build-reuters.sh (added)
+++ lucene/mahout/trunk/examples/bin/build-reuters.sh Mon Aug 17 13:33:18 2009
@@ -0,0 +1,60 @@
+#!/bin/bash
+#/**
+# * Licensed to the Apache Software Foundation (ASF) under one or more
+# * contributor license agreements. See the NOTICE file distributed with
+# * this work for additional information regarding copyright ownership.
+# * The ASF licenses this file to You under the Apache License, Version 2.0
+# * (the "License"); you may not use this file except in compliance with
+# * the License. You may obtain a copy of the License at
+# *
+# * http://www.apache.org/licenses/LICENSE-2.0
+# *
+# * Unless required by applicable law or agreed to in writing, software
+# * distributed under the License is distributed on an "AS IS" BASIS,
+# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# * See the License for the specific language governing permissions and
+# * limitations under the License.
+# */
+
+#
+# Runs the LDA examples using Reuters.
+#
+# To run: change into the mahout/examples directory (the parent of the one containing this file) and type:
+# bin/build-reuters.sh
+#
+# Steps: download and extract Reuters-21578, convert it to plain text, build a
+# Lucene index, create vectors from the index, run LDA, print the top words.
+# Each step is skipped when its output under work/ already exists.
+#
+mkdir -p work
+if [ ! -e work/reuters-out ]; then
+  if [ ! -e work/reuters-sgm ]; then
+    if [ ! -f work/reuters21578.tar.gz ]; then
+      echo "Downloading Reuters-21578"
+      # -f: fail on HTTP errors; remove the partial file so a rerun downloads again.
+      curl -f http://kdd.ics.uci.edu/databases/reuters21578/reuters21578.tar.gz -o work/reuters21578.tar.gz \
+        || { rm -f work/reuters21578.tar.gz; exit 1; }
+    fi
+    mkdir -p work/reuters-sgm
+    echo "Extracting..."
+    # Subshell keeps the working directory unchanged even if tar fails.
+    ( cd work/reuters-sgm && tar xzf ../reuters21578.tar.gz ) || { rm -rf work/reuters-sgm; exit 1; }
+  fi
+  echo "Converting to plain text."
+  mvn -e -q exec:java -Dexec.mainClass="org.apache.lucene.benchmark.utils.ExtractReuters" -Dexec.args="work/reuters-sgm work/reuters-out" || exit
+fi
+# Create index
+if [ ! -e work/index ]; then
+  echo "Creating index";
+  # Braces, not parentheses: an exit inside ( ... ) would only leave the subshell.
+  mvn -e exec:java -Dexec.classpathScope="test" -Dexec.mainClass="org.apache.lucene.benchmark.byTask.Benchmark" -Dexec.args="bin/lda.algorithm" || { rm -rf work/index; exit 1; }
+fi
+if [ ! -e work/vectors ]; then
+  echo "Creating vectors from index"
+  cd ../core
+  mvn -q install -DskipTests=true
+  cd ../utils/
+  mvn -q compile
+  mvn -e exec:java -Dexec.mainClass="org.apache.mahout.utils.vectors.lucene.Driver" \
+    -Dexec.args="--dir ../examples/work/index/ --field body --dictOut ../examples/work/dict.txt \
+    --output ../examples/work/vectors --minDF 100 --maxDFPercent 97" || exit
+  cd ../examples
+fi
+# The LDA steps run from core whether or not the vectors had to be rebuilt.
+cd ../core/ || exit 1
+echo "Running LDA"
+rm -rf ../examples/work/lda
+MAVEN_OPTS="-Xmx2G -ea" mvn -e exec:java -Dexec.mainClass=org.apache.mahout.clustering.lda.LDADriver -Dexec.args="-i ../examples/work/vectors -o ../examples/work/lda/\
+ -k 20 -v 10000 --maxIter 40"
+echo "Writing top words for each topic to to examples/work/topics/"
+mvn -q exec:java -Dexec.mainClass="org.apache.mahout.clustering.lda.LDAPrintTopics" -Dexec.args="-i `ls -1dtr ../examples/work/lda/state-* | tail -1` -d ../examples/work/dict.txt -o ../examples/work/topics/ -w 100"
Propchange: lucene/mahout/trunk/examples/bin/build-reuters.sh
------------------------------------------------------------------------------
svn:eol-style = native
Propchange: lucene/mahout/trunk/examples/bin/build-reuters.sh
------------------------------------------------------------------------------
svn:executable = *
Added: lucene/mahout/trunk/examples/bin/lda.algorithm
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/bin/lda.algorithm?rev=804979&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/bin/lda.algorithm (added)
+++ lucene/mahout/trunk/examples/bin/lda.algorithm Mon Aug 17 13:33:18 2009
@@ -0,0 +1,30 @@
+merge.policy=org.apache.lucene.index.LogDocMergePolicy
+merge.factor=mrg:10:20
+max.buffered=buf:100:1000
+compound=true
+
+analyzer=org.apache.lucene.analysis.standard.StandardAnalyzer
+directory=FSDirectory
+
+doc.stored=true
+doc.term.vector=true
+doc.tokenized=true
+log.step=600
+
+content.source=org.apache.lucene.benchmark.byTask.feeds.ReutersContentSource
+content.source.forever=false
+doc.maker.forever=false
+query.maker=org.apache.lucene.benchmark.byTask.feeds.SimpleQueryMaker
+
+# task at this depth or less would print when they start
+task.max.depth.log=2
+
+log.queries=false
+# --------- alg
+{ "BuildReuters"
+ CreateIndex
+ { "AddDocs" AddDoc > : *
+# Optimize
+ CloseIndex
+}
+
Modified: lucene/mahout/trunk/examples/pom.xml
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/pom.xml?rev=804979&r1=804978&r2=804979&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/pom.xml (original)
+++ lucene/mahout/trunk/examples/pom.xml Mon Aug 17 13:33:18 2009
@@ -180,6 +180,11 @@
<artifactId>mahout-core</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.mahout</groupId>
+ <artifactId>mahout-utils</artifactId>
+ <version>${project.version}</version>
+ </dependency>
<!-- A Lucene wikipedia tokenizer/analyzer -->
@@ -197,6 +202,17 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-benchmark</artifactId>
+ <version>2.9-SNAPSHOT</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-compress</artifactId>
+ <version>1.0</version>
+
+ </dependency>
<dependency>
<groupId>org.apache.openejb</groupId>