You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by jm...@apache.org on 2011/05/07 02:18:52 UTC
svn commit: r1100421 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/lda/
core/src/test/java/org/apache/mahout/clustering/lda/
core/src/test/java/org/apache/mahout/math/hadoop/decomposer/
utils/src/main/java/org/apache/mahout/cluste...
Author: jmannix
Date: Sat May 7 00:18:51 2011
New Revision: 1100421
URL: http://svn.apache.org/viewvc?rev=1100421&view=rev
Log:
Fixes MAHOUT-683 and MAHOUT-682
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java
Removed:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java?rev=1100421&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADocumentTopicMapper.java Sat May 7 00:18:51 2011
@@ -0,0 +1,52 @@
+package org.apache.mahout.clustering.lda;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+/**
+ * Map-only step that runs LDA inference over each input document (a vector of word counts)
+ * and writes, under the document's original key, the inferred gamma vector after
+ * {@code normalize(1)}.
+ */
+public class LDADocumentTopicMapper
+    extends Mapper<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
+
+  private LDAState state;
+  private LDAInference infer;
+
+  /** Builds the model state from the job configuration before any record is mapped. */
+  @Override
+  protected void setup(Context context) {
+    configure(context.getConfiguration());
+  }
+
+  /**
+   * Loads the LDA state described by {@code job} and installs it on this mapper.
+   *
+   * @throws IllegalStateException if the state cannot be read
+   */
+  public void configure(Configuration job) {
+    LDAState loaded;
+    try {
+      loaded = LDADriver.createState(job);
+    } catch (IOException e) {
+      throw new IllegalStateException("Error creating LDA State!", e);
+    }
+    configure(loaded);
+  }
+
+  /** Installs an already-built state and creates the inference engine that uses it. */
+  public void configure(LDAState myState) {
+    this.state = myState;
+    this.infer = new LDAInference(state);
+  }
+
+  /**
+   * Infers the document and emits its normalized gamma vector.
+   *
+   * @throws IllegalStateException if inference or the write indexes outside the model's arrays
+   */
+  @Override
+  protected void map(WritableComparable<?> key,
+                     VectorWritable wordCountsWritable,
+                     Context context) throws IOException, InterruptedException {
+    Vector counts = wordCountsWritable.get();
+    try {
+      LDAInference.InferredDocument inferred = infer.infer(counts);
+      VectorWritable topicWeights = new VectorWritable(inferred.getGamma().normalize(1));
+      context.write(key, topicWeights);
+    } catch (ArrayIndexOutOfBoundsException e1) {
+      throw new IllegalStateException(
+          "This is probably because the --numWords argument is set too small.  \n"
+          + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+          + "\tlarger if some storage inefficiency can be tolerated.", e1);
+    }
+  }
+}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Sat May 7 00:18:51 2011
@@ -17,21 +17,21 @@
package org.apache.mahout.clustering.lda;
-import java.io.IOException;
-import java.util.Random;
-
+import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ReflectionUtils;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
@@ -43,10 +43,17 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator;
import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
/**
* Estimates an LDA model from a corpus of documents, which are SparseVectors of word counts. At each phase,
@@ -57,6 +64,8 @@ public final class LDADriver extends Abs
private static final String TOPIC_SMOOTHING_OPTION = "topicSmoothing";
private static final String NUM_WORDS_OPTION = "numWords";
private static final String NUM_TOPICS_OPTION = "numTopics";
+ // TODO: sequential iteration is not yet correct.
+ // private static final String SEQUENTIAL_OPTION = "sequential";
static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
@@ -67,6 +76,14 @@ public final class LDADriver extends Abs
private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
+ private LDAState state = null;
+
+ private LDAState newState = null;
+
+ private LDAInference inference = null;
+
+ private Iterable<Pair<Writable, VectorWritable>> trainingCorpus = null;
+
private LDADriver() {
}
@@ -74,7 +91,11 @@ public final class LDADriver extends Abs
ToolRunner.run(new Configuration(), new LDADriver(), args);
}
- static LDAState createState(Configuration job) {
+ /**
+  * Builds an {@link LDAState} from the job configuration, delegating to
+  * {@link #createState(Configuration, boolean)} with {@code empty == false}.
+  *
+  * @param job configuration carrying the state path, topic count and word count keys
+  * @return the state produced by the two-argument overload
+  * @throws IOException propagated from the two-argument overload
+  */
+ public static LDAState createState(Configuration job) throws IOException {
+   boolean empty = false;
+   return createState(job, empty);
+ }
+
+ public static LDAState createState(Configuration job, boolean empty) throws IOException {
String statePath = job.get(STATE_IN_KEY);
int numTopics = Integer.parseInt(job.get(NUM_TOPICS_KEY));
int numWords = Integer.parseInt(job.get(NUM_WORDS_KEY));
@@ -82,10 +103,14 @@ public final class LDADriver extends Abs
Path dir = new Path(statePath);
+ // TODO scalability bottleneck: numWords * numTopics * 8bytes for the driver *and* M/R classes
DenseMatrix pWgT = new DenseMatrix(numTopics, numWords);
double[] logTotals = new double[numTopics];
+ Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
double ll = 0.0;
-
+ if(empty) {
+ return new LDAState(numTopics, numWords, topicSmoothing, pWgT, logTotals, ll);
+ }
for (Pair<IntPairWritable,DoubleWritable> record
: new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(new Path(dir, "part-*"),
PathType.GLOB,
@@ -126,6 +151,7 @@ public final class LDADriver extends Abs
"v",
"The total number of words in the corpus (can be approximate, needs to exceed the actual value)");
addOption(TOPIC_SMOOTHING_OPTION, "a", "Topic smoothing parameter. Default is 50/numTopics.", "-1.0");
+ // addOption(SEQUENTIAL_OPTION, "seq", "Run sequentially (not Hadoop-based). Default is false.", "false");
addOption(DefaultOptionCreator.maxIterationsOption().withRequired(false).create());
if (parseArguments(args) == null) {
@@ -144,30 +170,61 @@ public final class LDADriver extends Abs
if (topicSmoothing < 1) {
topicSmoothing = 50.0 / numTopics;
}
+ boolean runSequential = false; // Boolean.parseBoolean(getOption(SEQUENTIAL_OPTION));
- run(getConf(), input, output, numTopics, numWords, topicSmoothing, maxIterations);
+ run(getConf(), input, output, numTopics, numWords, topicSmoothing, maxIterations, runSequential);
return 0;
}
- private static void run(Configuration conf,
+ /**
+  * Finds the entry matching {@code state-*} under {@code stateDir} with the largest numeric
+  * suffix, so that a previous run can be resumed.
+  *
+  * @param conf     configuration used to resolve the file system of {@code stateDir}
+  * @param stateDir directory that holds the per-iteration {@code state-N} outputs
+  * @return the path with the highest iteration number, or {@code null} if nothing matches
+  * @throws IOException if the file system cannot be queried, or a matching name does not
+  *                     carry an integer after its first {@code '-'}
+  */
+ private Path getLastKnownStatePath(Configuration conf, Path stateDir) throws IOException {
+   // Resolve the file system from the path itself so a non-default output FS works,
+   // matching what writeState/findLL do.
+   FileSystem fs = stateDir.getFileSystem(conf);
+   FileStatus[] candidates = fs.globStatus(new Path(stateDir, "state-*"));
+   if (candidates == null) {
+     // globStatus may yield null rather than an empty array when nothing exists yet
+     return null;
+   }
+   Path lastPath = null;
+   int maxIteration = Integer.MIN_VALUE;
+   for (FileStatus fstatus : candidates) {
+     Path candidate = fstatus.getPath();
+     String[] nameParts = candidate.getName().split("-");
+     if (nameParts.length < 2) {
+       throw new IOException("State directory has no iteration suffix: " + candidate);
+     }
+     int iteration;
+     try {
+       iteration = Integer.parseInt(nameParts[1]);
+     } catch (NumberFormatException nfe) {
+       throw new IOException("State directory has a non-numeric iteration suffix: " + candidate, nfe);
+     }
+     if (iteration > maxIteration) {
+       maxIteration = iteration;
+       lastPath = candidate;
+     }
+   }
+   return lastPath;
+ }
+
+ private void run(Configuration conf,
Path input,
Path output,
int numTopics,
int numWords,
double topicSmoothing,
- int maxIterations)
+ int maxIterations,
+ boolean runSequential)
throws IOException, InterruptedException, ClassNotFoundException {
- Path stateIn = new Path(output, "state-0");
- writeInitialState(stateIn, numTopics, numWords);
+ Path lastKnownState = getLastKnownStatePath(conf, output);
+ Path stateIn;
+ if(lastKnownState == null) {
+ stateIn = new Path(output, "state-0");
+ writeInitialState(stateIn, numTopics, numWords);
+ } else {
+ stateIn = lastKnownState;
+ }
+ conf.set(STATE_IN_KEY, stateIn.toString());
+ conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
+ conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
+ conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
double oldLL = Double.NEGATIVE_INFINITY;
boolean converged = false;
-
- for (int iteration = 1; (maxIterations < 1 || iteration <= maxIterations) && !converged; iteration++) {
+ int iteration = Integer.parseInt(stateIn.getName().split("-")[1]) + 1;
+ for (; ((maxIterations < 1) || (iteration <= maxIterations)) && !converged; iteration++) {
log.info("LDA Iteration {}", iteration);
+ conf.set(STATE_IN_KEY, stateIn.toString());
// point the output to a new directory per iteration
Path stateOut = new Path(output, "state-" + iteration);
- double ll = runIteration(conf, input, stateIn, stateOut, numTopics, numWords, topicSmoothing);
+ double ll = runSequential ? runIterationSequential(conf, input, stateOut) : runIteration(conf, input, stateIn, stateOut);
double relChange = (oldLL - ll) / oldLL;
// now point the input to the old output directory
@@ -175,13 +232,18 @@ public final class LDADriver extends Abs
log.info("(Old LL: {})", oldLL);
log.info("(Rel Change: {})", relChange);
- converged = iteration > 3 && relChange < OVERALL_CONVERGENCE;
+ converged = (iteration > 3) && (relChange < OVERALL_CONVERGENCE);
stateIn = stateOut;
oldLL = ll;
}
+ if(runSequential) {
+ computeDocumentTopicProbabilitiesSequential(conf, input, new Path(output, "docTopics"));
+ } else {
+ computeDocumentTopicProbabilities(conf, input, stateIn, new Path(output, "docTopics"), numTopics, numWords, topicSmoothing);
+ }
}
- private static void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
+ private void writeInitialState(Path statePath, int numTopics, int numWords) throws IOException {
Configuration job = new Configuration();
FileSystem fs = statePath.getFileSystem(job);
@@ -210,7 +272,33 @@ public final class LDADriver extends Abs
}
}
- private static double findLL(Path statePath, Configuration job) throws IOException {
+ /**
+  * Writes {@code state} under {@code statePath} as sequence files of
+  * (IntPairWritable, DoubleWritable): one {@code part-k} file per topic holding a value per
+  * word followed by the topic's log total, plus one file holding the log-likelihood.
+  *
+  * @param job       configuration used to resolve the file system and create the writers
+  * @param state     the state to persist
+  * @param statePath directory that receives the part files
+  * @throws IOException if any part file cannot be created or written
+  */
+ private void writeState(Configuration job, LDAState state, Path statePath) throws IOException {
+   FileSystem fs = statePath.getFileSystem(job);
+   DoubleWritable v = new DoubleWritable();
+
+   for (int k = 0; k < state.getNumTopics(); ++k) {
+     Path path = new Path(statePath, "part-" + k);
+     SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
+     try {
+       for (int w = 0; w < state.getNumWords(); ++w) {
+         Writable kw = new IntPairWritable(k, w);
+         // logProbWordGivenTopic subtracts the topic's log total; add it back before storing
+         v.set(state.logProbWordGivenTopic(w,k) + state.getLogTotal(k));
+         writer.append(kw, v);
+       }
+       Writable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
+       v.set(state.getLogTotal(k));
+       writer.append(kTsk, v);
+     } finally {
+       // release the file handle even when an append fails
+       writer.close();
+     }
+   }
+
+   Path path = new Path(statePath, "part-" + LOG_LIKELIHOOD_KEY);
+   SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class, DoubleWritable.class);
+   try {
+     Writable kTsk = new IntPairWritable(LOG_LIKELIHOOD_KEY,LOG_LIKELIHOOD_KEY);
+     v.set(state.getLogLikelihood());
+     writer.append(kTsk, v);
+   } finally {
+     writer.close();
+   }
+ }
+
+ private double findLL(Path statePath, Configuration job) throws IOException {
FileSystem fs = statePath.getFileSystem(job);
double ll = 0.0;
for (FileStatus status : fs.globStatus(new Path(statePath, "part-*"))) {
@@ -229,6 +317,66 @@ public final class LDADriver extends Abs
return ll;
}
+ /**
+  * Runs one LDA iteration in-process instead of as a map/reduce job: infers every document
+  * of the cached training corpus against the current state, accumulates the word/topic
+  * statistics into a fresh state, writes that state to {@code stateOut} and makes it current.
+  *
+  * @param conf     configuration carrying the state path, topic count and word count keys
+  * @param input    directory whose {@code part-*} sequence files hold the corpus
+  * @param stateOut directory that receives the state produced by this iteration
+  * @return the sum of the per-document log-likelihoods
+  * @throws IOException           if the corpus or the state cannot be read or written
+  * @throws IllegalStateException if inference indexes outside the model's arrays
+  */
+ private double runIterationSequential(Configuration conf, Path input, Path stateOut)
+   throws IOException, InterruptedException {
+   if (state == null) {
+     state = createState(conf);
+   }
+   if (trainingCorpus == null) {
+     trainingCorpus = loadTrainingCorpus(conf, input);
+   }
+   if (inference == null) {
+     inference = new LDAInference(state);
+   }
+   double ll = 0;
+   // NOTE(review): createState(conf, true) appears to fill only logTotals with -infinity and
+   // leave the topic/word matrix at DenseMatrix's initial values - confirm that is intended.
+   newState = createState(conf, true);
+   for (Pair<Writable, VectorWritable> slice : trainingCorpus) {
+     LDAInference.InferredDocument doc;
+     Vector wordCounts = slice.getSecond().get();
+     try {
+       doc = inference.infer(wordCounts);
+     } catch (ArrayIndexOutOfBoundsException e1) {
+       throw new IllegalStateException(
+        "This is probably because the --numWords argument is set too small.  \n"
+        + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+        + "\tlarger if some storage inefficiency can be tolerated.", e1);
+     }
+
+     for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
+       Vector.Element e = iter.next();
+       int w = e.index();
+
+       for (int k = 0; k < state.getNumTopics(); ++k) {
+         double vwUpdate = doc.phi(k, w) + Math.log(e.get());
+         newState.updateLogProbGivenTopic(w, k, vwUpdate); // update state.topicWordProbabilities[v,w]!
+         newState.updateLogTotals(k, vwUpdate);
+       }
+     }
+     // Count each document's log-likelihood exactly once (not once per non-zero term),
+     // matching what LDAWordTopicMapper emits per document.
+     ll += doc.getLogLikelihood();
+   }
+   newState.setLogLikelihood(ll);
+   writeState(conf, newState, stateOut);
+   state = newState;
+   // Rebind inference so later iterations and the final doc/topic pass use the updated state.
+   inference = new LDAInference(state);
+   newState = null;
+
+   return ll;
+ }
+
+ /** Reads every (key, vector) record from the {@code part-*} files under {@code input} into memory. */
+ private List<Pair<Writable,VectorWritable>> loadTrainingCorpus(Configuration conf, Path input)
+   throws IOException {
+   Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
+   List<Pair<Writable,VectorWritable>> corpus = new LinkedList<Pair<Writable, VectorWritable>>();
+   FileSystem fs = input.getFileSystem(conf);
+   FileStatus[] parts = fs.globStatus(new Path(input, "part-*"));
+   if (parts == null) {
+     return corpus;
+   }
+   for (FileStatus fileStatus : parts) {
+     SequenceFile.Reader reader = new SequenceFile.Reader(fs, fileStatus.getPath(), conf);
+     try {
+       // a fresh key/value pair per record, because the reader fills the instances it is given
+       Writable key = ReflectionUtils.newInstance(keyClass, conf);
+       VectorWritable value = new VectorWritable();
+       while (reader.next(key, value)) {
+         corpus.add(new Pair<Writable,VectorWritable>(key, value));
+         key = ReflectionUtils.newInstance(keyClass, conf);
+         value = new VectorWritable();
+       }
+     } finally {
+       reader.close();
+     }
+   }
+   return corpus;
+ }
+
/**
* Run the job using supplied arguments
* @param input
@@ -237,21 +385,13 @@ public final class LDADriver extends Abs
* the directory pathname for input state
* @param stateOut
* the directory pathname for output state
- * @param numTopics
- * the number of clusters
*/
- private static double runIteration(Configuration conf,
+ private double runIteration(Configuration conf,
Path input,
Path stateIn,
- Path stateOut,
- int numTopics,
- int numWords,
- double topicSmoothing)
+ Path stateOut)
throws IOException, InterruptedException, ClassNotFoundException {
conf.set(STATE_IN_KEY, stateIn.toString());
- conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
- conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
- conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
Job job = new Job(conf, "LDA Driver running runIteration over stateIn: " + stateIn);
job.setOutputKeyClass(IntPairWritable.class);
@@ -259,7 +399,7 @@ public final class LDADriver extends Abs
FileInputFormat.addInputPaths(job, input.toString());
FileOutputFormat.setOutputPath(job, stateOut);
- job.setMapperClass(LDAMapper.class);
+ job.setMapperClass(LDAWordTopicMapper.class);
job.setReducerClass(LDAReducer.class);
job.setCombinerClass(LDAReducer.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
@@ -271,4 +411,69 @@ public final class LDADriver extends Abs
}
return findLL(stateOut, conf);
}
+
+ /**
+  * Launches a map-only job that runs {@link LDADocumentTopicMapper} over {@code input} using
+  * the state at {@code stateIn}, writing one vector per document to {@code outputPath}.
+  *
+  * @param conf           job configuration; the state path and model parameters are stored in it
+  * @param input          path(s) of the corpus sequence files
+  * @param stateIn        directory holding the model state to infer with
+  * @param outputPath     directory that receives the per-document vectors
+  * @param numTopics      number of topics in the model
+  * @param numWords       vocabulary size the model was built with
+  * @param topicSmoothing topic smoothing parameter
+  * @throws InterruptedException if the job does not complete successfully
+  */
+ private void computeDocumentTopicProbabilities(Configuration conf,
+                                                Path input,
+                                                Path stateIn,
+                                                Path outputPath,
+                                                int numTopics,
+                                                int numWords,
+                                                double topicSmoothing)
+   throws IOException, InterruptedException, ClassNotFoundException {
+   // model parameters must be in the configuration before the Job is built from it
+   conf.set(STATE_IN_KEY, stateIn.toString());
+   conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
+   conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
+   conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
+
+   String jobName = "LDA Driver computing p(topic|doc) for all docs/topics with stateIn: " + stateIn;
+   Job job = new Job(conf, jobName);
+
+   // output keys keep whatever type the input sequence file uses
+   Class<? extends Writable> docKeyClass = peekAtSequenceFileForKeyType(conf, input);
+   job.setOutputKeyClass(docKeyClass);
+   job.setOutputValueClass(VectorWritable.class);
+   FileInputFormat.addInputPaths(job, input.toString());
+   FileOutputFormat.setOutputPath(job, outputPath);
+
+   job.setMapperClass(LDADocumentTopicMapper.class);
+   job.setNumReduceTasks(0);
+   job.setOutputFormatClass(SequenceFileOutputFormat.class);
+   job.setInputFormatClass(SequenceFileInputFormat.class);
+   job.setJarByClass(LDADriver.class);
+
+   boolean succeeded = job.waitForCompletion(true);
+   if (!succeeded) {
+     throw new InterruptedException("LDA failed to compute and output document topic probabilities with: "+ stateIn);
+   }
+ }
+
+ /**
+  * In-process counterpart of the document/topic job: infers every document of the cached
+  * training corpus with the current inference engine and appends, under the document's own
+  * key, the inferred gamma vector after {@code normalize(1)} - the same value
+  * {@link LDADocumentTopicMapper} emits.
+  *
+  * @param conf       configuration used to resolve the file system and create the writer
+  * @param input      corpus location, used only to discover the key class
+  * @param outputPath sequence file that receives the per-document vectors
+  * @throws IOException           if the output cannot be created or written
+  * @throws IllegalStateException if no sequential iteration has populated the corpus and the
+  *                               inference engine, or inference indexes outside the model
+  */
+ private void computeDocumentTopicProbabilitiesSequential(Configuration conf, Path input, Path outputPath)
+   throws IOException, ClassNotFoundException {
+   if (trainingCorpus == null || inference == null) {
+     throw new IllegalStateException(
+         "No sequential iteration has been run, so there is no corpus or model to infer with");
+   }
+   // the writer targets outputPath, so resolve the file system from that path
+   FileSystem fs = outputPath.getFileSystem(conf);
+   Class<? extends Writable> keyClass = peekAtSequenceFileForKeyType(conf, input);
+   SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, outputPath, keyClass, VectorWritable.class);
+   try {
+     for (Pair<Writable, VectorWritable> slice : trainingCorpus) {
+       LDAInference.InferredDocument doc;
+       Vector wordCounts = slice.getSecond().get();
+       try {
+         doc = inference.infer(wordCounts);
+       } catch (ArrayIndexOutOfBoundsException e1) {
+         throw new IllegalStateException(
+          "This is probably because the --numWords argument is set too small.  \n"
+          + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+          + "\tlarger if some storage inefficiency can be tolerated.", e1);
+       }
+       // write the inferred result for this document rather than an empty key/value pair
+       writer.append(slice.getFirst(), new VectorWritable(doc.getGamma().normalize(1)));
+     }
+   } finally {
+     writer.close();
+   }
+ }
+
+ /**
+  * Opens {@code input} as a sequence file just long enough to read its key class.
+  *
+  * @param conf  configuration used to resolve the file system and open the reader
+  * @param input sequence file to inspect
+  * @return the key class recorded in the file, or {@code Text.class} if it cannot be opened
+  */
+ private static Class<? extends Writable> peekAtSequenceFileForKeyType(Configuration conf, Path input) {
+   SequenceFile.Reader reader = null;
+   try {
+     reader = new SequenceFile.Reader(input.getFileSystem(conf), input, conf);
+     return (Class<? extends Writable>) reader.getKeyClass();
+   } catch (IOException ioe) {
+     // best-effort fallback kept from the original behaviour, but no longer silent
+     log.warn("Could not read the key class from " + input + "; assuming Text", ioe);
+     return Text.class;
+   } finally {
+     if (reader != null) {
+       try {
+         reader.close();
+       } catch (IOException ignored) {
+         // the key class has already been determined; a failed close must not change the result
+       }
+     }
+   }
+ }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Sat May 7 00:18:51 2011
@@ -43,8 +43,8 @@ public class LDAInference {
}
/**
- * An estimate of the probabilitys for each document. Gamma(k) is the probability of seeing topic k in the
- * document, phi(k,w) is the probability of topic k generating w in this document.
+ * An estimate of the probabilities for each document. Gamma(k) is the probability of seeing topic k in the
+ * document, phi(k,w) is the (log) probability of topic k generating w in this document.
*/
public static class InferredDocument {
@@ -137,7 +137,11 @@ public class LDAInference {
return new InferredDocument(wordCounts, gamma, map, phi, oldLL);
}
-
+
+ /**
+ * @param gamma
+ * @return a vector whose entries are digamma(oldEntry) - digamma(gamma.zSum())
+ */
private Vector digammaGamma(Vector gamma) {
// digamma is expensive, precompute
Vector digammaGamma = digamma(gamma);
@@ -156,7 +160,21 @@ public class LDAInference {
phi.assign(0);
}
}
-
+
+ /**
+ * diGamma(x) = gamma'(x)/gamma(x)
+ * logGamma(x) = log(gamma(x))
+ *
+ * ll = log(gamma(smooth*numTop) / smooth^numTop) +
+ * sum_{i < numTop} (smooth - g[i])*(digamma(g[i]) - digamma(|g|)) + log(gamma(g[i]))
+ * Computes the log likelihood of the wordCounts vector, given \phi, \gamma, and \digamma(gamma)
+ * @param wordCounts
+ * @param map
+ * @param phi
+ * @param gamma
+ * @param digammaGamma
+ * @return
+ */
private double computeLikelihood(Vector wordCounts, int[] map, Matrix phi, Vector gamma, Vector digammaGamma) {
double ll = 0.0;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAState.java Sat May 7 00:18:51 2011
@@ -25,7 +25,7 @@ public class LDAState {
private final double topicSmoothing;
private final Matrix topicWordProbabilities; // log p(w|t) for topic=1..nTopics
private final double[] logTotals; // log \sum p(w|t) for topic=1..nTopics
- private final double logLikelihood; // log \sum p(w|t) for topic=1..nTopics
+ private double logLikelihood; // log \sum p(w|t) for topic=1..nTopics
public LDAState(int numTopics,
int numWords,
@@ -42,10 +42,14 @@ public class LDAState {
}
public double logProbWordGivenTopic(int word, int topic) {
- double logProb = topicWordProbabilities.getQuick(topic, word);
+ double logProb = topicWordProbabilities.get(topic, word);
return logProb == Double.NEGATIVE_INFINITY ? -100.0 : logProb - logTotals[topic];
}
+ /**
+  * @param topic index of the topic
+  * @return the entry of {@code logTotals} stored for {@code topic}
+  */
+ public double getLogTotal(int topic) {
+   final double totalForTopic = logTotals[topic];
+   return totalForTopic;
+ }
+
public int getNumTopics() {
return numTopics;
}
@@ -62,4 +66,16 @@ public class LDAState {
return logLikelihood;
}
+ /**
+  * Folds {@code logProbGivenTopic} into the stored (topic, word) entry by replacing the entry
+  * with {@code LDAUtil.logSum} of the new value and the current one.
+  *
+  * @param word              word index (matrix column)
+  * @param topic             topic index (matrix row)
+  * @param logProbGivenTopic value to combine with the stored entry
+  */
+ public void updateLogProbGivenTopic(int word, int topic, double logProbGivenTopic) {
+   double current = topicWordProbabilities.getQuick(topic, word);
+   double combined = LDAUtil.logSum(logProbGivenTopic, current);
+   topicWordProbabilities.set(topic, word, combined);
+ }
+
+ /**
+  * Folds {@code logTotal} into the running total kept for {@code topic} by replacing it with
+  * {@code LDAUtil.logSum} of the current total and the new value.
+  *
+  * @param topic    topic index
+  * @param logTotal value to combine with the stored total
+  */
+ public void updateLogTotals(int topic, double logTotal) {
+   double previous = logTotals[topic];
+   logTotals[topic] = LDAUtil.logSum(previous, logTotal);
+ }
+
+ /**
+  * Replaces the log-likelihood recorded on this state.
+  *
+  * @param logLikelihood the new value returned by {@code getLogLikelihood()}
+  */
+ public void setLogLikelihood(double logLikelihood) {
+   double updated = logLikelihood;
+   this.logLikelihood = updated;
+ }
+
}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java?rev=1100421&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAWordTopicMapper.java Sat May 7 00:18:51 2011
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.lda;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.IntPairWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Infers each input document (a sparse vector of word counts) against the current LDA state
+ * and emits the sufficient statistics of the word/topic assignments: one value per
+ * (topic, word) pair seen in the document, one running total per topic, and the document's
+ * log-likelihood.
+ */
+public class LDAWordTopicMapper extends Mapper<WritableComparable<?>,VectorWritable,IntPairWritable,DoubleWritable> {
+
+  private LDAState state;
+  private LDAInference infer;
+
+  /** Builds the model state from the job configuration before any record is mapped. */
+  @Override
+  protected void setup(Context context) {
+    configure(context.getConfiguration());
+  }
+
+  /**
+   * Loads the LDA state described by {@code job} and installs it on this mapper.
+   *
+   * @throws IllegalStateException if the state cannot be read
+   */
+  public void configure(Configuration job) {
+    LDAState loaded;
+    try {
+      loaded = LDADriver.createState(job);
+    } catch (IOException e) {
+      throw new IllegalStateException("Error creating LDA State!", e);
+    }
+    configure(loaded);
+  }
+
+  /** Installs an already-built state and creates the inference engine that uses it. */
+  public void configure(LDAState myState) {
+    this.state = myState;
+    this.infer = new LDAInference(state);
+  }
+
+  /**
+   * Emits, for one document: a value for every (topic, non-zero word) pair, then one total
+   * per topic keyed with {@code LDADriver.TOPIC_SUM_KEY}, then the document log-likelihood
+   * keyed with {@code LDADriver.LOG_LIKELIHOOD_KEY}.
+   *
+   * @throws IllegalStateException if inference indexes outside the model's arrays
+   */
+  @Override
+  protected void map(WritableComparable<?> key,
+                     VectorWritable wordCountsWritable,
+                     Context context) throws IOException, InterruptedException {
+    Vector wordCounts = wordCountsWritable.get();
+    LDAInference.InferredDocument doc;
+    try {
+      doc = infer.infer(wordCounts);
+    } catch (ArrayIndexOutOfBoundsException e1) {
+      throw new IllegalStateException(
+          "This is probably because the --numWords argument is set too small.  \n"
+          + "\tIt needs to be >= than the number of words (terms actually) in the corpus and can be \n"
+          + "\tlarger if some storage inefficiency can be tolerated.", e1);
+    }
+
+    final int numTopics = state.getNumTopics();
+    double[] topicLogTotals = new double[numTopics];
+    Arrays.fill(topicLogTotals, Double.NEGATIVE_INFINITY);
+
+    // The same DoubleWritable is reused for every record that is written.
+    DoubleWritable v = new DoubleWritable();
+
+    // Per-word sufficient statistics, i.e. pseudo-log counts.
+    Iterator<Vector.Element> nonZeros = wordCounts.iterateNonZero();
+    while (nonZeros.hasNext()) {
+      Vector.Element e = nonZeros.next();
+      int w = e.index();
+      double logCount = Math.log(e.get());
+
+      for (int k = 0; k < numTopics; ++k) {
+        double contribution = doc.phi(k, w) + logCount;
+        v.set(contribution);
+
+        // (topic, word)'s logProb contribution
+        context.write(new IntPairWritable(k, w), v);
+        topicLogTotals[k] = LDAUtil.logSum(topicLogTotals[k], contribution);
+      }
+    }
+
+    // Per-topic totals, emitted so that normalizing downstream is easier.
+    for (int k = 0; k < numTopics; ++k) {
+      v.set(topicLogTotals[k]);
+      assert !Double.isNaN(v.get());
+      context.write(new IntPairWritable(k, LDADriver.TOPIC_SUM_KEY), v);
+    }
+
+    // Document log-likelihood.
+    v.set(doc.getLogLikelihood());
+    context.write(new IntPairWritable(LDADriver.LOG_LIKELIHOOD_KEY, LDADriver.LOG_LIKELIHOOD_KEY), v);
+  }
+
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java Sat May 7 00:18:51 2011
@@ -93,12 +93,12 @@ public final class TestMapReduce extends
@Test
public void testMapper() throws Exception {
LDAState state = generateRandomState(100,NUM_TOPICS);
- LDAMapper mapper = new LDAMapper();
+ LDAWordTopicMapper mapper = new LDAWordTopicMapper();
mapper.configure(state);
for(int i = 0; i < NUM_TESTS; ++i) {
RandomAccessSparseVector v = generateRandomDoc(100,0.3);
int myNumWords = numNonZero(v);
- LDAMapper.Context mock = EasyMock.createMock(LDAMapper.Context.class);
+ LDAWordTopicMapper.Context mock = EasyMock.createMock(LDAWordTopicMapper.Context.class);
mock.write(EasyMock.isA(IntPairWritable.class), EasyMock.isA(DoubleWritable.class));
EasyMock.expectLastCall().times(myNumWords * NUM_TOPICS + NUM_TOPICS + 1);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java Sat May 7 00:18:51 2011
@@ -33,8 +33,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
+import java.util.Arrays;
public final class TestDistributedLanczosSolverCLI extends MahoutTestCase {
private static final Logger log = LoggerFactory.getLogger(TestDistributedLanczosSolverCLI.class);
Modified: mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java?rev=1100421&r1=1100420&r2=1100421&view=diff
==============================================================================
--- mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java (original)
+++ mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/lda/LDAPrintTopics.java Sat May 7 00:18:51 2011
@@ -17,19 +17,6 @@
package org.apache.mahout.clustering.lda;
-import java.io.File;
-import java.io.IOException;
-import java.io.Writer;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.PriorityQueue;
-import java.util.Queue;
-
-import com.google.common.base.Charsets;
-import com.google.common.io.Files;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
import org.apache.commons.cli2.Option;
@@ -48,6 +35,19 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.utils.vectors.VectorHelper;
+import java.io.File;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Queue;
+
/**
* Class to print out the top K words for each topic.
*/
@@ -56,7 +56,7 @@ public final class LDAPrintTopics {
private LDAPrintTopics() { }
private static class StringDoublePair implements Comparable<StringDoublePair> {
- private final double score;
+ private double score;
private final String word;
StringDoublePair(double score, String word) {
@@ -153,18 +153,16 @@ public final class LDAPrintTopics {
throw new IllegalArgumentException("Invalid dictionary format");
}
- List<List<String>> topWords = topWordsForTopics(input, config, wordList, numWords);
-
+ List<PriorityQueue<StringDoublePair>> topWords = topWordsForTopics(input, config, wordList, numWords);
+
+ File output = null;
if (cmdLine.hasOption(outOpt)) {
- File output = new File(cmdLine.getValue(outOpt).toString());
+ output = new File(cmdLine.getValue(outOpt).toString());
if (!output.exists() && !output.mkdirs()) {
throw new IOException("Could not create directory: " + output);
}
- writeTopWords(topWords, output);
- } else {
- printTopWords(topWords);
}
-
+ printTopWords(topWords, output);
} catch (OptionException e) {
CommandLineUtil.printHelp(group);
throw e;
@@ -181,63 +179,60 @@ public final class LDAPrintTopics {
}
}
- private static void printTopWords(List<List<String>> topWords) {
+ private static void printTopWords(List<PriorityQueue<StringDoublePair>> topWords, File outputDir)
+ throws IOException {
for (int i = 0; i < topWords.size(); ++i) {
- List<String> topK = topWords.get(i);
- System.out.println("Topic " + i);
- System.out.println("===========");
- for (String word : topK) {
- System.out.println(word);
+ PriorityQueue<StringDoublePair> topK = topWords.get(i);
+ PrintWriter out;
+ if(outputDir != null) {
+ out = new PrintWriter(new File(outputDir, "topic_" + i));
+ } else {
+ out = new PrintWriter(System.out);
+ out.println("Topic " + i);
+ out.println("===========");
+ }
+ List<StringDoublePair> topKasList = new ArrayList<StringDoublePair>(topK.size());
+ for(StringDoublePair wordWithScore : topK) {
+ topKasList.add(wordWithScore);
}
+ Collections.sort(topKasList, Collections.reverseOrder());
+ for(StringDoublePair wordWithScore : topKasList) {
+ out.println(wordWithScore.word + " [p(" + wordWithScore.word + "|topic_" + i +") = "
+ + wordWithScore.score);
+ }
+ out.close();
}
}
- private static List<List<String>> topWordsForTopics(String dir,
+ private static List<PriorityQueue<StringDoublePair>> topWordsForTopics(String dir,
Configuration job,
List<String> wordList,
int numWordsToPrint) {
List<PriorityQueue<StringDoublePair>> queues = new ArrayList<PriorityQueue<StringDoublePair>>();
-
+ Map<Integer,Double> expSums = new HashMap<Integer, Double>();
for (Pair<IntPairWritable,DoubleWritable> record :
new SequenceFileDirIterable<IntPairWritable, DoubleWritable>(
new Path(dir, "part-*"), PathType.GLOB, null, null, true, job)) {
IntPairWritable key = record.getFirst();
int topic = key.getFirst();
int word = key.getSecond();
-
ensureQueueSize(queues, topic);
if (word >= 0 && topic >= 0) {
double score = record.getSecond().get();
+ if(expSums.get(topic) == null) {
+ expSums.put(topic, 0d);
+ }
+ expSums.put(topic, expSums.get(topic) + Math.exp(score));
String realWord = wordList.get(word);
maybeEnqueue(queues.get(topic), realWord, score, numWordsToPrint);
}
}
-
- List<List<String>> result = new ArrayList<List<String>>();
- for (int i = 0; i < queues.size(); ++i) {
- result.add(i, new LinkedList<String>());
- for (StringDoublePair sdp : queues.get(i)) {
- result.get(i).add(0, sdp.word); // prepend
+ for(int i=0; i<queues.size(); i++) {
+ PriorityQueue<StringDoublePair> queue = queues.get(i);
+ for(StringDoublePair pair : queue) {
+ pair.score = Math.exp(pair.score) / expSums.get(i);
}
}
-
- return result;
+ return queues;
}
-
- private static void writeTopWords(List<List<String>> topWords, File output) throws IOException {
- for (int i = 0; i < topWords.size(); ++i) {
- List<String> topK = topWords.get(i);
- Writer writer = Files.newWriter(new File(output, "topic-" + i), Charsets.UTF_8);
- try {
- writer.write("Topic " + i + '\n');
- writer.write("===========\n");
- for (String word : topK) {
- writer.write(word + '\n');
- }
- } finally {
- writer.close();
- }
- }
- }
-
}