You are viewing a plain-text version of this content. (The link to the canonical version was lost in this plain-text rendering.)
Posted to commits@mahout.apache.org by gs...@apache.org on 2010/01/10 18:48:13 UTC
svn commit: r897668 - /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
Author: gsingers
Date: Sun Jan 10 17:48:13 2010
New Revision: 897668
URL: http://svn.apache.org/viewvc?rev=897668&view=rev
Log:
small clarification
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=897668&r1=897667&r2=897668&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Sun Jan 10 17:48:13 2010
@@ -47,11 +47,11 @@
import org.slf4j.LoggerFactory;
/**
-* Estimates an LDA model from a corpus of documents,
-* which are SparseVectors of word counts. At each
-* phase, it outputs a matrix of log probabilities of
-* each topic.
-*/
+ * Estimates an LDA model from a corpus of documents,
+ * which are SparseVectors of word counts. At each
+ * phase, it outputs a matrix of log probabilities of
+ * each topic.
+ */
public final class LDADriver {
static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
@@ -72,47 +72,47 @@
}
public static void main(String[] args) throws ClassNotFoundException,
- IOException, InterruptedException {
+ IOException, InterruptedException {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
- abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
- "The Output Working Directory").withShortName("o").create();
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Output Working Directory").withShortName("o").create();
Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
- "If set, overwrite the output directory").withShortName("w").create();
+ "If set, overwrite the output directory").withShortName("w").create();
Option topicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(
- abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
- "The number of topics").withShortName("k").create();
+ abuilder.withName("numTopics").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The number of topics").withShortName("k").create();
Option wordsOpt = obuilder.withLongName("numWords").withRequired(true).withArgument(
- abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
- "The number of words in the corpus").withShortName("v").create();
+ abuilder.withName("numWords").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The total number of words in the corpus").withShortName("v").create();
Option topicSmOpt = obuilder.withLongName("topicSmoothing").withRequired(false).withArgument(abuilder
- .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
- "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
+ .withName("topicSmoothing").withDefault(-1.0).withMinimum(0).withMaximum(1).create()).withDescription(
+ "Topic smoothing parameter. Default is 50/numTopics.").withShortName("a").create();
Option maxIterOpt = obuilder.withLongName("maxIter").withRequired(false).withArgument(
- abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
- "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
+ abuilder.withName("maxIter").withDefault(-1).withMinimum(0).withMaximum(1).create()).withDescription(
+ "Max iterations to run (or until convergence). -1 (default) waits until convergence.").create();
Option numReducOpt = obuilder.withLongName("numReducers").withRequired(false).withArgument(
- abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
- "Max iterations to run (or until convergence). Default 10").create();
+ abuilder.withName("numReducers").withDefault(10).withMinimum(0).withMaximum(1).create()).withDescription(
+ "Max iterations to run (or until convergence). Default 10").create();
Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(
- topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
+ topicsOpt).withOption(wordsOpt).withOption(topicSmOpt).withOption(maxIterOpt).withOption(
numReducOpt).withOption(overwriteOutput).withOption(helpOpt).create();
try {
Parser parser = new Parser();
@@ -154,12 +154,12 @@
if (cmdLine.hasOption(topicSmOpt)) {
topicSmoothing = Double.parseDouble(cmdLine.getValue(maxIterOpt).toString());
}
- if(topicSmoothing < 1) {
+ if (topicSmoothing < 1) {
topicSmoothing = 50.0 / numTopics;
}
runJob(input, output, numTopics, numWords, topicSmoothing, maxIterations,
- numReduceTasks);
+ numReduceTasks);
} catch (OptionException e) {
log.error("Exception", e);
@@ -170,18 +170,18 @@
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param output the directory pathname for output points
- * @param numTopics the number of topics
- * @param numWords the number of words
- * @param topicSmoothing pseudocounts for each topic, typically small < .5
- * @param maxIterations the maximum number of iterations
- * @param numReducers the number of Reducers desired
- * @throws IOException
+ * @param input the directory pathname for input points
+ * @param output the directory pathname for output points
+ * @param numTopics the number of topics
+ * @param numWords the number of words
+ * @param topicSmoothing pseudocounts for each topic, typically small < .5
+ * @param maxIterations the maximum number of iterations
+ * @param numReducers the number of Reducers desired
+ * @throws IOException
*/
- public static void runJob(String input, String output, int numTopics,
- int numWords, double topicSmoothing, int maxIterations, int numReducers)
- throws IOException, InterruptedException, ClassNotFoundException {
+ public static void runJob(String input, String output, int numTopics,
+ int numWords, double topicSmoothing, int maxIterations, int numReducers)
+ throws IOException, InterruptedException, ClassNotFoundException {
String stateIn = output + "/state-0";
writeInitialState(stateIn, numTopics, numWords);
@@ -193,7 +193,7 @@
// point the output to a new directory per iteration
String stateOut = output + "/state-" + (iteration + 1);
double ll = runIteration(input, stateIn, stateOut, numTopics,
- numWords, topicSmoothing, numReducers);
+ numWords, topicSmoothing, numReducers);
double relChange = (oldLL - ll) / oldLL;
// now point the input to the old output directory
@@ -207,8 +207,8 @@
}
}
- private static void writeInitialState(String statePath,
- int numTopics, int numWords) throws IOException {
+ private static void writeInitialState(String statePath,
+ int numTopics, int numWords) throws IOException {
Path dir = new Path(statePath);
Configuration job = new Configuration();
FileSystem fs = dir.getFileSystem(job);
@@ -221,7 +221,7 @@
for (int k = 0; k < numTopics; ++k) {
Path path = new Path(dir, "part-" + k);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
- IntPairWritable.class, DoubleWritable.class);
+ IntPairWritable.class, DoubleWritable.class);
kw.setX(k);
double total = 0.0; // total number of pseudo counts we made
@@ -250,7 +250,7 @@
IntPairWritable key = new IntPairWritable();
DoubleWritable value = new DoubleWritable();
- for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
+ for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
Path path = status.getPath();
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
while (reader.next(key, value)) {
@@ -268,21 +268,21 @@
/**
* Run the job using supplied arguments
*
- * @param input the directory pathname for input points
- * @param stateIn the directory pathname for input state
- * @param stateOut the directory pathname for output state
+ * @param input the directory pathname for input points
+ * @param stateIn the directory pathname for input state
+ * @param stateOut the directory pathname for output state
* @param numTopics the number of clusters
- * @param numReducers the number of Reducers desired
+ * @param numReducers the number of Reducers desired
*/
public static double runIteration(String input, String stateIn,
- String stateOut, int numTopics, int numWords, double topicSmoothing,
- int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+ String stateOut, int numTopics, int numWords, double topicSmoothing,
+ int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
Configuration conf = new Configuration();
conf.set(STATE_IN_KEY, stateIn);
conf.set(NUM_TOPICS_KEY, Integer.toString(numTopics));
conf.set(NUM_WORDS_KEY, Integer.toString(numWords));
conf.set(TOPIC_SMOOTHING_KEY, Double.toString(topicSmoothing));
-
+
Job job = new Job(conf);
job.setOutputKeyClass(IntPairWritable.class);
@@ -298,7 +298,7 @@
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setJarByClass(LDADriver.class);
-
+
job.waitForCompletion(true);
return findLL(stateOut, conf);
}
@@ -318,7 +318,7 @@
IntPairWritable key = new IntPairWritable();
DoubleWritable value = new DoubleWritable();
- for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
+ for (FileStatus status : fs.globStatus(new Path(dir, "part-*"))) {
Path path = status.getPath();
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
while (reader.next(key, value)) {
@@ -349,6 +349,6 @@
}
return new LDAState(numTopics, numWords, topicSmoothing,
- pWgT, logTotals, ll);
+ pWgT, logTotals, ll);
}
}