Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:07:59 UTC
[28/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java
new file mode 100644
index 0000000..46fcc7f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class CVB0DocInferenceMapper extends CachingCVB0Mapper {
+
+ private final VectorWritable topics = new VectorWritable();
+
+ @Override
+ public void map(IntWritable docId, VectorWritable doc, Context context)
+ throws IOException, InterruptedException {
+ int numTopics = getNumTopics();
+ Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics);
+ Matrix docModel = new SparseRowMatrix(numTopics, doc.get().size());
+ int maxIters = getMaxIters();
+ ModelTrainer modelTrainer = getModelTrainer();
+ for (int i = 0; i < maxIters; i++) {
+ modelTrainer.getReadModel().trainDocTopicModel(doc.get(), docTopics, docModel);
+ }
+ topics.set(docTopics);
+ context.write(docId, topics);
+ }
+
+ @Override
+ protected void cleanup(Context context) {
+ getModelTrainer().stop();
+ }
+}
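
The map() above is the whole inference step: p(topic|doc) starts uniform and is refined by a fixed number of fixed-point passes against the read model. For readers who want to run the same inference outside Hadoop, here is a minimal sketch (the class name is hypothetical; it assumes access to a TopicModel from this package, constructed as in CachingCVB0Mapper further down):

    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Matrix;
    import org.apache.mahout.math.SparseRowMatrix;
    import org.apache.mahout.math.Vector;

    public class DocInferenceSketch {
      // Refines p(topic|doc) for a single document, mirroring CVB0DocInferenceMapper.map().
      public static Vector inferDocTopics(TopicModel model, Vector doc, int numTopics, int maxIters) {
        Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); // uniform start
        Matrix docModel = new SparseRowMatrix(numTopics, doc.size());          // per-doc scratch space
        for (int i = 0; i < maxIters; i++) {
          model.trainDocTopicModel(doc, docTopics, docModel);
        }
        return docTopics;
      }
    }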
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java
new file mode 100644
index 0000000..3eee446
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java
@@ -0,0 +1,536 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.mapreduce.VectorSumReducer;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.List;
+
+/**
+ * See {@link CachingCVB0Mapper} for more details on scalability and room for improvement.
+ * To try out this LDA implementation without using Hadoop, check out
+ * {@link InMemoryCollapsedVariationalBayes0}. If you want to do training directly in Java code
+ * with your own main(), then look to {@link ModelTrainer} and {@link TopicModel}.
+ *
+ * Usage: {@code ./bin/mahout cvb} <i>options</i>
+ * <p>
+ * Valid options include:
+ * <dl>
+ * <dt>{@code --input path}</dt>
+ * <dd>Input path for {@code SequenceFile<IntWritable, VectorWritable>} document vectors. See
+ * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles}
+ * for details on how to generate this input format.</dd>
+ * <dt>{@code --dictionary path}</dt>
+ * <dd>Path to dictionary file(s) generated during construction of input document vectors (glob
+ * expression supported). If set, this data is scanned to determine an appropriate value for option
+ * {@code --num_terms}.</dd>
+ * <dt>{@code --output path}</dt>
+ * <dd>Output path for topic-term distributions.</dd>
+ * <dt>{@code --doc_topic_output path}</dt>
+ * <dd>Output path for doc-topic distributions.</dd>
+ * <dt>{@code --num_topics k}</dt>
+ * <dd>Number of latent topics.</dd>
+ * <dt>{@code --num_terms nt}</dt>
+ * <dd>Number of unique features defined by input document vectors. If option {@code --dictionary}
+ * is defined and this option is unspecified, term count is calculated from the dictionary.</dd>
+ * <dt>{@code --topic_model_temp_dir path}</dt>
+ * <dd>Path in which to store model state after each iteration.</dd>
+ * <dt>{@code --maxIter i}</dt>
+ * <dd>Maximum number of iterations to perform. If this value is less than or equal to the number of
+ * iteration states found beneath the path specified by option {@code --topic_model_temp_dir}, no
+ * further iterations are performed. Instead, output topic-term and doc-topic distributions are
+ * generated using data from the specified iteration.</dd>
+ * <dt>{@code --max_doc_topic_iters i}</dt>
+ * <dd>Maximum number of iterations per doc for p(topic|doc) learning. Defaults to {@code 10}.</dd>
+ * <dt>{@code --doc_topic_smoothing a}</dt>
+ * <dd>Smoothing for doc-topic distribution. Defaults to {@code 0.0001}.</dd>
+ * <dt>{@code --term_topic_smoothing e}</dt>
+ * <dd>Smoothing for topic-term distribution. Defaults to {@code 0.0001}.</dd>
+ * <dt>{@code --random_seed seed}</dt>
+ * <dd>Integer seed for random number generation.</dd>
+ * <dt>{@code --test_set_fraction p}</dt>
+ * <dd>Fraction of data to hold out for testing. Defaults to {@code 0.0}.</dd>
+ * <dt>{@code --iteration_block_size block}</dt>
+ * <dd>Number of iterations between perplexity checks. Defaults to {@code 10}. This option is
+ * ignored unless option {@code --test_set_fraction} is greater than zero.</dd>
+ * </dl>
+ */
+public class CVB0Driver extends AbstractJob {
+ private static final Logger log = LoggerFactory.getLogger(CVB0Driver.class);
+
+ public static final String NUM_TOPICS = "num_topics";
+ public static final String NUM_TERMS = "num_terms";
+ public static final String DOC_TOPIC_SMOOTHING = "doc_topic_smoothing";
+ public static final String TERM_TOPIC_SMOOTHING = "term_topic_smoothing";
+ public static final String DICTIONARY = "dictionary";
+ public static final String DOC_TOPIC_OUTPUT = "doc_topic_output";
+ public static final String MODEL_TEMP_DIR = "topic_model_temp_dir";
+ public static final String ITERATION_BLOCK_SIZE = "iteration_block_size";
+ public static final String RANDOM_SEED = "random_seed";
+ public static final String TEST_SET_FRACTION = "test_set_fraction";
+ public static final String NUM_TRAIN_THREADS = "num_train_threads";
+ public static final String NUM_UPDATE_THREADS = "num_update_threads";
+ public static final String MAX_ITERATIONS_PER_DOC = "max_doc_topic_iters";
+ public static final String MODEL_WEIGHT = "prev_iter_mult";
+ public static final String NUM_REDUCE_TASKS = "num_reduce_tasks";
+ public static final String BACKFILL_PERPLEXITY = "backfill_perplexity";
+ private static final String MODEL_PATHS = "mahout.lda.cvb.modelPath";
+
+ private static final double DEFAULT_CONVERGENCE_DELTA = 0;
+ private static final double DEFAULT_DOC_TOPIC_SMOOTHING = 0.0001;
+ private static final double DEFAULT_TERM_TOPIC_SMOOTHING = 0.0001;
+ private static final int DEFAULT_ITERATION_BLOCK_SIZE = 10;
+ private static final double DEFAULT_TEST_SET_FRACTION = 0;
+ private static final int DEFAULT_NUM_TRAIN_THREADS = 4;
+ private static final int DEFAULT_NUM_UPDATE_THREADS = 1;
+ private static final int DEFAULT_MAX_ITERATIONS_PER_DOC = 10;
+ private static final int DEFAULT_NUM_REDUCE_TASKS = 10;
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.maxIterationsOption().create());
+ addOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION, "cd", "The convergence delta value",
+ String.valueOf(DEFAULT_CONVERGENCE_DELTA));
+ addOption(DefaultOptionCreator.overwriteOption().create());
+
+ addOption(NUM_TOPICS, "k", "Number of topics to learn", true);
+ addOption(NUM_TERMS, "nt", "Vocabulary size", false);
+ addOption(DOC_TOPIC_SMOOTHING, "a", "Smoothing for document/topic distribution",
+ String.valueOf(DEFAULT_DOC_TOPIC_SMOOTHING));
+ addOption(TERM_TOPIC_SMOOTHING, "e", "Smoothing for topic/term distribution",
+ String.valueOf(DEFAULT_TERM_TOPIC_SMOOTHING));
+ addOption(DICTIONARY, "dict", "Path to term-dictionary file(s) (glob expression supported)", false);
+ addOption(DOC_TOPIC_OUTPUT, "dt", "Output path for the training doc/topic distribution", false);
+ addOption(MODEL_TEMP_DIR, "mt", "Path to intermediate model path (useful for restarting)", false);
+ addOption(ITERATION_BLOCK_SIZE, "block", "Number of iterations per perplexity check",
+ String.valueOf(DEFAULT_ITERATION_BLOCK_SIZE));
+ addOption(RANDOM_SEED, "seed", "Random seed", false);
+ addOption(TEST_SET_FRACTION, "tf", "Fraction of data to hold out for testing",
+ String.valueOf(DEFAULT_TEST_SET_FRACTION));
+ addOption(NUM_TRAIN_THREADS, "ntt", "number of threads per mapper to train with",
+ String.valueOf(DEFAULT_NUM_TRAIN_THREADS));
+ addOption(NUM_UPDATE_THREADS, "nut", "number of threads per mapper to update the model with",
+ String.valueOf(DEFAULT_NUM_UPDATE_THREADS));
+ addOption(MAX_ITERATIONS_PER_DOC, "mipd", "max number of iterations per doc for p(topic|doc) learning",
+ String.valueOf(DEFAULT_MAX_ITERATIONS_PER_DOC));
+ addOption(NUM_REDUCE_TASKS, null, "number of reducers to use during model estimation",
+ String.valueOf(DEFAULT_NUM_REDUCE_TASKS));
+ addOption(buildOption(BACKFILL_PERPLEXITY, null, "enable backfilling of missing perplexity values", false, false,
+ null));
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ int numTopics = Integer.parseInt(getOption(NUM_TOPICS));
+ Path inputPath = getInputPath();
+ Path topicModelOutputPath = getOutputPath();
+ int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
+ int iterationBlockSize = Integer.parseInt(getOption(ITERATION_BLOCK_SIZE));
+ double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
+ double alpha = Double.parseDouble(getOption(DOC_TOPIC_SMOOTHING));
+ double eta = Double.parseDouble(getOption(TERM_TOPIC_SMOOTHING));
+ int numTrainThreads = Integer.parseInt(getOption(NUM_TRAIN_THREADS));
+ int numUpdateThreads = Integer.parseInt(getOption(NUM_UPDATE_THREADS));
+ int maxItersPerDoc = Integer.parseInt(getOption(MAX_ITERATIONS_PER_DOC));
+ Path dictionaryPath = hasOption(DICTIONARY) ? new Path(getOption(DICTIONARY)) : null;
+ int numTerms = hasOption(NUM_TERMS)
+ ? Integer.parseInt(getOption(NUM_TERMS))
+ : getNumTerms(getConf(), dictionaryPath);
+ Path docTopicOutputPath = hasOption(DOC_TOPIC_OUTPUT) ? new Path(getOption(DOC_TOPIC_OUTPUT)) : null;
+ Path modelTempPath = hasOption(MODEL_TEMP_DIR)
+ ? new Path(getOption(MODEL_TEMP_DIR))
+ : getTempPath("topicModelState");
+ long seed = hasOption(RANDOM_SEED)
+ ? Long.parseLong(getOption(RANDOM_SEED))
+ : System.nanoTime() % 10000;
+ float testFraction = hasOption(TEST_SET_FRACTION)
+ ? Float.parseFloat(getOption(TEST_SET_FRACTION))
+ : 0.0f;
+ int numReduceTasks = Integer.parseInt(getOption(NUM_REDUCE_TASKS));
+ boolean backfillPerplexity = hasOption(BACKFILL_PERPLEXITY);
+
+ return run(getConf(), inputPath, topicModelOutputPath, numTopics, numTerms, alpha, eta,
+ maxIterations, iterationBlockSize, convergenceDelta, dictionaryPath, docTopicOutputPath,
+ modelTempPath, seed, testFraction, numTrainThreads, numUpdateThreads, maxItersPerDoc,
+ numReduceTasks, backfillPerplexity);
+ }
+
+ private static int getNumTerms(Configuration conf, Path dictionaryPath) throws IOException {
+ FileSystem fs = dictionaryPath.getFileSystem(conf);
+ Text key = new Text();
+ IntWritable value = new IntWritable();
+ int maxTermId = -1;
+ for (FileStatus stat : fs.globStatus(dictionaryPath)) {
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, stat.getPath(), conf);
+ while (reader.next(key, value)) {
+ maxTermId = Math.max(maxTermId, value.get());
+ }
+ }
+ return maxTermId + 1;
+ }
+
+ public int run(Configuration conf,
+ Path inputPath,
+ Path topicModelOutputPath,
+ int numTopics,
+ int numTerms,
+ double alpha,
+ double eta,
+ int maxIterations,
+ int iterationBlockSize,
+ double convergenceDelta,
+ Path dictionaryPath,
+ Path docTopicOutputPath,
+ Path topicModelStateTempPath,
+ long randomSeed,
+ float testFraction,
+ int numTrainThreads,
+ int numUpdateThreads,
+ int maxItersPerDoc,
+ int numReduceTasks,
+ boolean backfillPerplexity)
+ throws ClassNotFoundException, IOException, InterruptedException {
+
+ setConf(conf);
+
+ // verify arguments
+ Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0,
+ "Expected 'testFraction' value in range [0, 1] but found value '%s'", testFraction);
+ Preconditions.checkArgument(!backfillPerplexity || testFraction > 0.0,
+ "Expected 'testFraction' value in range (0, 1] but found value '%s'", testFraction);
+
+ String infoString = "Will run Collapsed Variational Bayes (0th-derivative approximation) "
+ + "learning for LDA on {} (numTerms: {}), finding {}-topics, with document/topic prior {}, "
+ + "topic/term prior {}. Maximum iterations to run will be {}, unless the change in "
+ + "perplexity is less than {}. Topic model output (p(term|topic) for each topic) will be "
+ + "stored {}. Random initialization seed is {}, holding out {} of the data for perplexity "
+ + "check\n";
+ log.info(infoString, inputPath, numTerms, numTopics, alpha, eta, maxIterations,
+ convergenceDelta, topicModelOutputPath, randomSeed, testFraction);
+ infoString = dictionaryPath == null
+ ? "" : "Dictionary to be used located " + dictionaryPath.toString() + '\n';
+ infoString += docTopicOutputPath == null
+ ? "" : "p(topic|docId) will be stored " + docTopicOutputPath.toString() + '\n';
+ log.info(infoString);
+
+ FileSystem fs = FileSystem.get(topicModelStateTempPath.toUri(), conf);
+ int iterationNumber = getCurrentIterationNumber(conf, topicModelStateTempPath, maxIterations);
+ log.info("Current iteration number: {}", iterationNumber);
+
+ conf.set(NUM_TOPICS, String.valueOf(numTopics));
+ conf.set(NUM_TERMS, String.valueOf(numTerms));
+ conf.set(DOC_TOPIC_SMOOTHING, String.valueOf(alpha));
+ conf.set(TERM_TOPIC_SMOOTHING, String.valueOf(eta));
+ conf.set(RANDOM_SEED, String.valueOf(randomSeed));
+ conf.set(NUM_TRAIN_THREADS, String.valueOf(numTrainThreads));
+ conf.set(NUM_UPDATE_THREADS, String.valueOf(numUpdateThreads));
+ conf.set(MAX_ITERATIONS_PER_DOC, String.valueOf(maxItersPerDoc));
+ conf.set(MODEL_WEIGHT, "1"); // TODO
+ conf.set(TEST_SET_FRACTION, String.valueOf(testFraction));
+
+ List<Double> perplexities = Lists.newArrayList();
+ for (int i = 1; i <= iterationNumber; i++) {
+ // form path to model
+ Path modelPath = modelPath(topicModelStateTempPath, i);
+
+ // read perplexity
+ double perplexity = readPerplexity(conf, topicModelStateTempPath, i);
+ if (Double.isNaN(perplexity)) {
+ if (!(backfillPerplexity && i % iterationBlockSize == 0)) {
+ continue;
+ }
+ log.info("Backfilling perplexity at iteration {}", i);
+ if (!fs.exists(modelPath)) {
+ log.error("Model path '{}' does not exist; Skipping iteration {} perplexity calculation",
+ modelPath.toString(), i);
+ continue;
+ }
+ perplexity = calculatePerplexity(conf, inputPath, modelPath, i);
+ }
+
+ // register and log perplexity
+ perplexities.add(perplexity);
+ log.info("Perplexity at iteration {} = {}", i, perplexity);
+ }
+
+ long startTime = System.currentTimeMillis();
+ while (iterationNumber < maxIterations) {
+ // test convergence
+ if (convergenceDelta > 0.0) {
+ double delta = rateOfChange(perplexities);
+ if (delta < convergenceDelta) {
+ log.info("Convergence achieved at iteration {} with perplexity {} and delta {}",
+ iterationNumber, perplexities.get(perplexities.size() - 1), delta);
+ break;
+ }
+ }
+
+ // update model
+ iterationNumber++;
+ log.info("About to run iteration {} of {}", iterationNumber, maxIterations);
+ Path modelInputPath = modelPath(topicModelStateTempPath, iterationNumber - 1);
+ Path modelOutputPath = modelPath(topicModelStateTempPath, iterationNumber);
+ runIteration(conf, inputPath, modelInputPath, modelOutputPath, iterationNumber,
+ maxIterations, numReduceTasks);
+
+ // calculate perplexity
+ if (testFraction > 0 && iterationNumber % iterationBlockSize == 0) {
+ perplexities.add(calculatePerplexity(conf, inputPath, modelOutputPath, iterationNumber));
+ log.info("Current perplexity = {}", perplexities.get(perplexities.size() - 1));
+ log.info("(p_{} - p_{}) / p_0 = {}; target = {}", iterationNumber, iterationNumber - iterationBlockSize,
+ rateOfChange(perplexities), convergenceDelta);
+ }
+ }
+ log.info("Completed {} iterations in {} seconds", iterationNumber,
+ (System.currentTimeMillis() - startTime) / 1000);
+ log.info("Perplexities: ({})", Joiner.on(", ").join(perplexities));
+
+ // write final topic-term and doc-topic distributions
+ Path finalIterationData = modelPath(topicModelStateTempPath, iterationNumber);
+ Job topicModelOutputJob = topicModelOutputPath != null
+ ? writeTopicModel(conf, finalIterationData, topicModelOutputPath)
+ : null;
+ Job docInferenceJob = docTopicOutputPath != null
+ ? writeDocTopicInference(conf, inputPath, finalIterationData, docTopicOutputPath)
+ : null;
+ if (topicModelOutputJob != null && !topicModelOutputJob.waitForCompletion(true)) {
+ return -1;
+ }
+ if (docInferenceJob != null && !docInferenceJob.waitForCompletion(true)) {
+ return -1;
+ }
+ return 0;
+ }
+
+ private static double rateOfChange(List<Double> perplexities) {
+ int sz = perplexities.size();
+ if (sz < 2) {
+ return Double.MAX_VALUE;
+ }
+ return Math.abs(perplexities.get(sz - 1) - perplexities.get(sz - 2)) / perplexities.get(0);
+ }
+
+ private double calculatePerplexity(Configuration conf, Path corpusPath, Path modelPath, int iteration)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ String jobName = "Calculating perplexity for " + modelPath;
+ log.info("About to run: {}", jobName);
+
+ Path outputPath = perplexityPath(modelPath.getParent(), iteration);
+ Job job = prepareJob(corpusPath, outputPath, CachingCVB0PerplexityMapper.class, DoubleWritable.class,
+ DoubleWritable.class, DualDoubleSumReducer.class, DoubleWritable.class, DoubleWritable.class);
+
+ job.setJobName(jobName);
+ job.setCombinerClass(DualDoubleSumReducer.class);
+ job.setNumReduceTasks(1);
+ setModelPaths(job, modelPath);
+ HadoopUtil.delete(conf, outputPath);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("Failed to calculate perplexity for: " + modelPath);
+ }
+ return readPerplexity(conf, modelPath.getParent(), iteration);
+ }
+
+ /**
+ * Sums keys and values independently.
+ */
+ public static class DualDoubleSumReducer extends
+ Reducer<DoubleWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
+ private final DoubleWritable outKey = new DoubleWritable();
+ private final DoubleWritable outValue = new DoubleWritable();
+
+ @Override
+ public void run(Context context) throws IOException,
+ InterruptedException {
+ double keySum = 0.0;
+ double valueSum = 0.0;
+ while (context.nextKey()) {
+ keySum += context.getCurrentKey().get();
+ for (DoubleWritable value : context.getValues()) {
+ valueSum += value.get();
+ }
+ }
+ outKey.set(keySum);
+ outValue.set(valueSum);
+ context.write(outKey, outValue);
+ }
+ }
+
+ /**
+ * @param conf the Hadoop configuration.
+ * @param topicModelStateTemp path in which model state is stored between iterations.
+ * @param iteration the iteration whose perplexity should be read.
+ * @return the total perplexity of the documents sampled during perplexity computation,
+ * normalized by their total model weight, or {@code Double.NaN} if no perplexity data
+ * exists for the given iteration.
+ * @throws IOException if the perplexity data cannot be read.
+ */
+ public static double readPerplexity(Configuration conf, Path topicModelStateTemp, int iteration)
+ throws IOException {
+ Path perplexityPath = perplexityPath(topicModelStateTemp, iteration);
+ FileSystem fs = FileSystem.get(perplexityPath.toUri(), conf);
+ if (!fs.exists(perplexityPath)) {
+ log.warn("Perplexity path {} does not exist, returning NaN", perplexityPath);
+ return Double.NaN;
+ }
+ double perplexity = 0;
+ double modelWeight = 0;
+ long n = 0;
+ for (Pair<DoubleWritable, DoubleWritable> pair : new SequenceFileDirIterable<DoubleWritable, DoubleWritable>(
+ perplexityPath, PathType.LIST, PathFilters.partFilter(), null, true, conf)) {
+ modelWeight += pair.getFirst().get();
+ perplexity += pair.getSecond().get();
+ n++;
+ }
+ log.info("Read {} entries with total perplexity {} and model weight {}", n,
+ perplexity, modelWeight);
+ return perplexity / modelWeight;
+ }
+
+ private Job writeTopicModel(Configuration conf, Path modelInput, Path output)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ String jobName = String.format("Writing final topic/term distributions from %s to %s", modelInput, output);
+ log.info("About to run: {}", jobName);
+
+ Job job = prepareJob(modelInput, output, SequenceFileInputFormat.class, CVB0TopicTermVectorNormalizerMapper.class,
+ IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName);
+ job.submit();
+ return job;
+ }
+
+ private Job writeDocTopicInference(Configuration conf, Path corpus, Path modelInput, Path output)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ String jobName = String.format("Writing final document/topic inference from %s to %s", corpus, output);
+ log.info("About to run: {}", jobName);
+
+ Job job = prepareJob(corpus, output, SequenceFileInputFormat.class, CVB0DocInferenceMapper.class,
+ IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName);
+
+ FileSystem fs = FileSystem.get(corpus.toUri(), conf);
+ if (modelInput != null && fs.exists(modelInput)) {
+ FileStatus[] statuses = fs.listStatus(modelInput, PathFilters.partFilter());
+ URI[] modelUris = new URI[statuses.length];
+ for (int i = 0; i < statuses.length; i++) {
+ modelUris[i] = statuses[i].getPath().toUri();
+ }
+ DistributedCache.setCacheFiles(modelUris, conf);
+ setModelPaths(job, modelInput);
+ }
+ job.submit();
+ return job;
+ }
+
+ public static Path modelPath(Path topicModelStateTempPath, int iterationNumber) {
+ return new Path(topicModelStateTempPath, "model-" + iterationNumber);
+ }
+
+ public static Path perplexityPath(Path topicModelStateTempPath, int iterationNumber) {
+ return new Path(topicModelStateTempPath, "perplexity-" + iterationNumber);
+ }
+
+ private static int getCurrentIterationNumber(Configuration config, Path modelTempDir, int maxIterations)
+ throws IOException {
+ FileSystem fs = FileSystem.get(modelTempDir.toUri(), config);
+ int iterationNumber = 1;
+ Path iterationPath = modelPath(modelTempDir, iterationNumber);
+ while (fs.exists(iterationPath) && iterationNumber <= maxIterations) {
+ log.info("Found previous state: {}", iterationPath);
+ iterationNumber++;
+ iterationPath = modelPath(modelTempDir, iterationNumber);
+ }
+ return iterationNumber - 1;
+ }
+
+ public void runIteration(Configuration conf, Path corpusInput, Path modelInput, Path modelOutput,
+ int iterationNumber, int maxIterations, int numReduceTasks)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ String jobName = String.format("Iteration %d of %d, input path: %s",
+ iterationNumber, maxIterations, modelInput);
+ log.info("About to run: {}", jobName);
+ Job job = prepareJob(corpusInput, modelOutput, CachingCVB0Mapper.class, IntWritable.class, VectorWritable.class,
+ VectorSumReducer.class, IntWritable.class, VectorWritable.class);
+ job.setCombinerClass(VectorSumReducer.class);
+ job.setNumReduceTasks(numReduceTasks);
+ job.setJobName(jobName);
+ setModelPaths(job, modelInput);
+ HadoopUtil.delete(conf, modelOutput);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException(String.format("Failed to complete iteration %d stage 1",
+ iterationNumber));
+ }
+ }
+
+ private static void setModelPaths(Job job, Path modelPath) throws IOException {
+ Configuration conf = job.getConfiguration();
+ if (modelPath == null || !FileSystem.get(modelPath.toUri(), conf).exists(modelPath)) {
+ return;
+ }
+ FileStatus[] statuses = FileSystem.get(modelPath.toUri(), conf).listStatus(modelPath, PathFilters.partFilter());
+ Preconditions.checkState(statuses.length > 0, "No part files found in model path '%s'", modelPath.toString());
+ String[] modelPaths = new String[statuses.length];
+ for (int i = 0; i < statuses.length; i++) {
+ modelPaths[i] = statuses[i].getPath().toUri().toString();
+ }
+ conf.setStrings(MODEL_PATHS, modelPaths);
+ }
+
+ public static Path[] getModelPaths(Configuration conf) {
+ String[] modelPathNames = conf.getStrings(MODEL_PATHS);
+ if (modelPathNames == null || modelPathNames.length == 0) {
+ return null;
+ }
+ Path[] modelPaths = new Path[modelPathNames.length];
+ for (int i = 0; i < modelPathNames.length; i++) {
+ modelPaths[i] = new Path(modelPathNames[i]);
+ }
+ return modelPaths;
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new CVB0Driver(), args);
+ }
+}
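
For illustration, the driver can be launched from Java just as main() does, passing the options documented in the class javadoc. A hedged sketch (class name and all paths are placeholders):

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.util.ToolRunner;

    public class RunCVB0 {
      public static void main(String[] args) throws Exception {
        int exitCode = ToolRunner.run(new Configuration(), new CVB0Driver(), new String[] {
            "--input", "/path/to/corpus",                 // SequenceFile<IntWritable, VectorWritable>
            "--output", "/path/to/topic-term",            // p(term|topic) distributions
            "--num_topics", "20",
            "--dictionary", "/path/to/dictionary.file-*", // used to derive --num_terms
            "--doc_topic_output", "/path/to/doc-topic",
            "--topic_model_temp_dir", "/tmp/topicModelState",
            "--maxIter", "30"
        });
        System.exit(exitCode);
      }
    }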
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java
new file mode 100644
index 0000000..1253942
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+
+/**
+ * Performs L1 normalization of input vectors.
+ */
+public class CVB0TopicTermVectorNormalizerMapper extends
+ Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ @Override
+ protected void map(IntWritable key, VectorWritable value, Context context) throws IOException,
+ InterruptedException {
+ value.get().assign(Functions.div(value.get().norm(1.0)));
+ context.write(key, value);
+ }
+}
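
Functions.div(norm) yields a unary function x -> x / norm, so the assign() call above rescales each topic's term-count vector to sum to 1, turning raw counts into p(term|topic). A tiny self-contained sketch of the same operation (hypothetical class name):

    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;
    import org.apache.mahout.math.function.Functions;

    public class L1NormalizeSketch {
      public static void main(String[] args) {
        Vector v = new DenseVector(new double[] {3.0, 1.0, 4.0, 2.0});
        v.assign(Functions.div(v.norm(1.0))); // divide every element by the L1 norm
        System.out.println(v.norm(1.0));      // ~1.0: entries now form a distribution
      }
    }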
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java
new file mode 100644
index 0000000..96f36d4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java
@@ -0,0 +1,133 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Runs ensemble learning by loading the {@link ModelTrainer} with two {@link TopicModel} instances:
+ * one from the previous iteration, the other empty. Inference is done on the first, while the
+ * learning updates are accumulated in the second and only emitted at cleanup().
+ * <p/>
+ * In terms of obvious performance improvements still available, the memory footprint of this
+ * Mapper could be halved if we accumulated model updates onto the model we're using for
+ * inference, which might also speed up convergence, as we'd be able to take advantage of
+ * learning <em>during</em> an iteration, not just after each one is done. We most likely don't
+ * need to accumulate double values in the model either; floats would probably be sufficient.
+ * Between these two changes, we could squeeze another factor of 4 in memory efficiency.
+ * <p/>
+ * In terms of CPU, we re-learn the p(topic|doc) distribution on every iteration, starting
+ * from scratch. This is usually only 10 fixed-point iterations per doc, but that is still 10x
+ * more work than a single pass. To avoid it, we would need a map-side join of the unchanging
+ * corpus with the continually-improving p(topic|doc) matrix, and then emit multiple outputs
+ * from the mappers to make sure we can do the reduce-side model averaging as well. Tricky, but
+ * possibly worth it.
+ * <p/>
+ * {@link ModelTrainer} already takes advantage (in a perhaps inelegant way) of multi-core
+ * availability by doing multithreaded learning; see that class for details.
+ */
+public class CachingCVB0Mapper
+ extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(CachingCVB0Mapper.class);
+
+ private ModelTrainer modelTrainer;
+ private TopicModel readModel;
+ private TopicModel writeModel;
+ private int maxIters;
+ private int numTopics;
+
+ protected ModelTrainer getModelTrainer() {
+ return modelTrainer;
+ }
+
+ protected int getMaxIters() {
+ return maxIters;
+ }
+
+ protected int getNumTopics() {
+ return numTopics;
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ log.info("Retrieving configuration");
+ Configuration conf = context.getConfiguration();
+ float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
+ float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
+ long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
+ numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
+ int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
+ int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
+ int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
+ maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
+ float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
+
+ log.info("Initializing read model");
+ Path[] modelPaths = CVB0Driver.getModelPaths(conf);
+ if (modelPaths != null && modelPaths.length > 0) {
+ readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
+ } else {
+ log.info("No model files found");
+ readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
+ numTrainThreads, modelWeight);
+ }
+
+ log.info("Initializing write model");
+ writeModel = modelWeight == 1
+ ? new TopicModel(numTopics, numTerms, eta, alpha, null, numUpdateThreads)
+ : readModel;
+
+ log.info("Initializing model trainer");
+ modelTrainer = new ModelTrainer(readModel, writeModel, numTrainThreads, numTopics, numTerms);
+ modelTrainer.start();
+ }
+
+ @Override
+ public void map(IntWritable docId, VectorWritable document, Context context)
+ throws IOException, InterruptedException {
+ // p(topic|doc) is re-learned from a uniform start on every iteration; see the class javadoc.
+ Vector topicVector = new DenseVector(numTopics).assign(1.0 / numTopics);
+ modelTrainer.train(document.get(), topicVector, true, maxIters);
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException, InterruptedException {
+ log.info("Stopping model trainer");
+ modelTrainer.stop();
+
+ log.info("Writing model");
+ TopicModel readFrom = modelTrainer.getReadModel();
+ for (MatrixSlice topic : readFrom) {
+ context.write(new IntWritable(topic.index()), new VectorWritable(topic.vector()));
+ }
+ readModel.stop();
+ writeModel.stop();
+ }
+}
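
If this mapper is driven from a custom job rather than CVB0Driver, setup() above expects its parameters under CVB0Driver's configuration keys. A minimal sketch of the required settings (values are illustrative; assumes the code can see CVB0Driver's public constants):

    import org.apache.hadoop.conf.Configuration;

    public class CachingCVB0ConfSketch {
      public static Configuration baseConf() {
        Configuration conf = new Configuration();
        conf.setInt(CVB0Driver.NUM_TOPICS, 20);                  // mapper falls back to -1
        conf.setInt(CVB0Driver.NUM_TERMS, 50000);                // mapper falls back to -1
        conf.setFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, 0.0001f);  // alpha
        conf.setFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, 0.0001f); // eta
        conf.setLong(CVB0Driver.RANDOM_SEED, 1234L);
        conf.setInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
        conf.setFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);            // "prev_iter_mult"
        return conf;
      }
    }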
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java
new file mode 100644
index 0000000..da77baf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MemoryUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Random;
+
+public class CachingCVB0PerplexityMapper extends
+ Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable> {
+ /**
+ * Hadoop counters for {@link CachingCVB0PerplexityMapper}, to aid in debugging.
+ */
+ public enum Counters {
+ SAMPLED_DOCUMENTS
+ }
+
+ private static final Logger log = LoggerFactory.getLogger(CachingCVB0PerplexityMapper.class);
+
+ private ModelTrainer modelTrainer;
+ private TopicModel readModel;
+ private int maxIters;
+ private int numTopics;
+ private float testFraction;
+ private Random random;
+ private Vector topicVector;
+ private final DoubleWritable outKey = new DoubleWritable();
+ private final DoubleWritable outValue = new DoubleWritable();
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ MemoryUtil.startMemoryLogger(5000);
+
+ log.info("Retrieving configuration");
+ Configuration conf = context.getConfiguration();
+ float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
+ float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
+ long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
+ random = RandomUtils.getRandom(seed);
+ numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
+ int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
+ int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
+ int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
+ maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
+ float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
+ testFraction = conf.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f);
+
+ log.info("Initializing read model");
+ Path[] modelPaths = CVB0Driver.getModelPaths(conf);
+ if (modelPaths != null && modelPaths.length > 0) {
+ readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
+ } else {
+ log.info("No model files found");
+ readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
+ numTrainThreads, modelWeight);
+ }
+
+ log.info("Initializing model trainer");
+ modelTrainer = new ModelTrainer(readModel, null, numTrainThreads, numTopics, numTerms);
+
+ log.info("Initializing topic vector");
+ topicVector = new DenseVector(new double[numTopics]);
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException, InterruptedException {
+ readModel.stop();
+ MemoryUtil.stopMemoryLogger();
+ }
+
+ @Override
+ public void map(IntWritable docId, VectorWritable document, Context context)
+ throws IOException, InterruptedException {
+ if (testFraction < 1.0f && random.nextFloat() >= testFraction) {
+ return;
+ }
+ context.getCounter(Counters.SAMPLED_DOCUMENTS).increment(1);
+ outKey.set(document.get().norm(1));
+ outValue.set(modelTrainer.calculatePerplexity(document.get(), topicVector.assign(1.0 / numTopics), maxIters));
+ context.write(outKey, outValue);
+ }
+}
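
Each sampled document contributes its L1 norm as the key (its model weight) and its perplexity as the value; DualDoubleSumReducer sums keys and values independently, and CVB0Driver.readPerplexity then reports sum(values) / sum(keys). A plain-Java sketch of that aggregation (hypothetical class name):

    public class PerplexityAggregationSketch {
      // pairs[i][0] = document weight (L1 norm), pairs[i][1] = perplexity contribution
      public static double aggregate(double[][] pairs) {
        double weightSum = 0.0;
        double perplexitySum = 0.0;
        for (double[] pair : pairs) {
          weightSum += pair[0];
          perplexitySum += pair[1];
        }
        return perplexitySum / weightSum; // weighted average perplexity, as in readPerplexity()
      }
    }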
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
new file mode 100644
index 0000000..07ae100
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
@@ -0,0 +1,515 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.DistributedRowMatrixWriter;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.NamedVector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Runs the same algorithm as {@link CVB0Driver}, but sequentially, in memory. The current memory
+ * requirements are: the entire corpus is read into RAM, two copies of the model (each of size
+ * numTerms * numTopics) are kept, and another matrix of size numDocs * numTopics is held in
+ * memory to store p(topic|doc) for all docs.
+ *
+ * If all of this fits in memory, this can be significantly faster than an iterative MR job.
+ */
+public class InMemoryCollapsedVariationalBayes0 extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(InMemoryCollapsedVariationalBayes0.class);
+
+ private int numTopics;
+ private int numTerms;
+ private int numDocuments;
+ private double alpha;
+ private double eta;
+ //private int minDfCt;
+ //private double maxDfPct;
+ private boolean verbose = false;
+ private String[] terms; // of length numTerms;
+ private Matrix corpusWeights; // length numDocs;
+ private double totalCorpusWeight;
+ private double initialModelCorpusFraction;
+ private Matrix docTopicCounts;
+ private int numTrainingThreads;
+ private int numUpdatingThreads;
+ private ModelTrainer modelTrainer;
+
+ private InMemoryCollapsedVariationalBayes0() {
+ // only for main usage
+ }
+
+ public void setVerbose(boolean verbose) {
+ this.verbose = verbose;
+ }
+
+ public InMemoryCollapsedVariationalBayes0(Matrix corpus,
+ String[] terms,
+ int numTopics,
+ double alpha,
+ double eta,
+ int numTrainingThreads,
+ int numUpdatingThreads,
+ double modelCorpusFraction) {
+ //this.seed = seed;
+ this.numTopics = numTopics;
+ this.alpha = alpha;
+ this.eta = eta;
+ //this.minDfCt = 0;
+ //this.maxDfPct = 1.0f;
+ corpusWeights = corpus;
+ numDocuments = corpus.numRows();
+ this.terms = terms;
+ this.initialModelCorpusFraction = modelCorpusFraction;
+ numTerms = terms != null ? terms.length : corpus.numCols();
+ Map<String, Integer> termIdMap = Maps.newHashMap();
+ if (terms != null) {
+ for (int t = 0; t < terms.length; t++) {
+ termIdMap.put(terms[t], t);
+ }
+ }
+ this.numTrainingThreads = numTrainingThreads;
+ this.numUpdatingThreads = numUpdatingThreads;
+ postInitCorpus();
+ initializeModel();
+ }
+
+ private void postInitCorpus() {
+ totalCorpusWeight = 0;
+ int numNonZero = 0;
+ for (int i = 0; i < numDocuments; i++) {
+ Vector v = corpusWeights.viewRow(i);
+ double norm;
+ if (v != null && (norm = v.norm(1)) != 0) {
+ numNonZero += v.getNumNondefaultElements();
+ totalCorpusWeight += norm;
+ }
+ }
+ String s = "Initializing corpus with %d docs, %d terms, %d nonzero entries, total termWeight %f";
+ log.info(String.format(s, numDocuments, numTerms, numNonZero, totalCorpusWeight));
+ }
+
+ private void initializeModel() {
+ TopicModel topicModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(), terms,
+ numUpdatingThreads, initialModelCorpusFraction == 0 ? 1 : initialModelCorpusFraction * totalCorpusWeight);
+ topicModel.setConf(getConf());
+
+ TopicModel updatedModel = initialModelCorpusFraction == 0
+ ? new TopicModel(numTopics, numTerms, eta, alpha, null, terms, numUpdatingThreads, 1)
+ : topicModel;
+ updatedModel.setConf(getConf());
+ docTopicCounts = new DenseMatrix(numDocuments, numTopics);
+ docTopicCounts.assign(1.0 / numTopics);
+ modelTrainer = new ModelTrainer(topicModel, updatedModel, numTrainingThreads, numTopics, numTerms);
+ }
+
+ /*
+ private void inferDocuments(double convergence, int maxIter, boolean recalculate) {
+ for (int docId = 0; docId < corpusWeights.numRows() ; docId++) {
+ Vector inferredDocument = topicModel.infer(corpusWeights.viewRow(docId),
+ docTopicCounts.viewRow(docId));
+ // do what now?
+ }
+ }
+ */
+
+ public void trainDocuments() {
+ trainDocuments(0);
+ }
+
+ public void trainDocuments(double testFraction) {
+ long start = System.nanoTime();
+ modelTrainer.start();
+ for (int docId = 0; docId < corpusWeights.numRows(); docId++) {
+ if (testFraction == 0 || docId % (1 / testFraction) != 0) {
+ Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); // docTopicCounts.getRow(docId)
+ modelTrainer.trainSync(corpusWeights.viewRow(docId), docTopics, true, 10);
+ }
+ }
+ modelTrainer.stop();
+ logTime("train documents", System.nanoTime() - start);
+ }
+
+ /*
+ private double error(int docId) {
+ Vector docTermCounts = corpusWeights.viewRow(docId);
+ if (docTermCounts == null) {
+ return 0;
+ } else {
+ Vector expectedDocTermCounts =
+ topicModel.infer(corpusWeights.viewRow(docId), docTopicCounts.viewRow(docId));
+ double expectedNorm = expectedDocTermCounts.norm(1);
+ return expectedDocTermCounts.times(docTermCounts.norm(1)/expectedNorm)
+ .minus(docTermCounts).norm(1);
+ }
+ }
+
+ private double error() {
+ long time = System.nanoTime();
+ double error = 0;
+ for (int docId = 0; docId < numDocuments; docId++) {
+ error += error(docId);
+ }
+ logTime("error calculation", System.nanoTime() - time);
+ return error / totalCorpusWeight;
+ }
+ */
+
+ public double iterateUntilConvergence(double minFractionalErrorChange,
+ int maxIterations, int minIter) {
+ return iterateUntilConvergence(minFractionalErrorChange, maxIterations, minIter, 0);
+ }
+
+ public double iterateUntilConvergence(double minFractionalErrorChange,
+ int maxIterations, int minIter, double testFraction) {
+ int iter = 0;
+ double oldPerplexity = 0;
+ while (iter < minIter) {
+ trainDocuments(testFraction);
+ if (verbose) {
+ log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
+ }
+ log.info("iteration {} complete", iter);
+ oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts,
+ testFraction);
+ log.info("{} = perplexity", oldPerplexity);
+ iter++;
+ }
+ double newPerplexity = 0;
+ double fractionalChange = Double.MAX_VALUE;
+ while (iter < maxIterations && fractionalChange > minFractionalErrorChange) {
+ trainDocuments();
+ if (verbose) {
+ log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
+ }
+ newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts,
+ testFraction);
+ log.info("{} = perplexity", newPerplexity);
+ iter++;
+ fractionalChange = Math.abs(newPerplexity - oldPerplexity) / oldPerplexity;
+ log.info("{} = fractionalChange", fractionalChange);
+ oldPerplexity = newPerplexity;
+ }
+ if (iter < maxIterations) {
+ log.info(String.format("Converged! fractional error change: %f, error %f",
+ fractionalChange, newPerplexity));
+ } else {
+ log.info(String.format("Reached max iteration count (%d), fractional error change: %f, error: %f",
+ maxIterations, fractionalChange, newPerplexity));
+ }
+ return newPerplexity;
+ }
+
+ public void writeModel(Path outputPath) throws IOException {
+ modelTrainer.persist(outputPath);
+ }
+
+ private static void logTime(String label, long nanos) {
+ log.info("{} time: {}ms", label, nanos / 1.0e6);
+ }
+
+ public static int main2(String[] args, Configuration conf) throws Exception {
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option helpOpt = DefaultOptionCreator.helpOption();
+
+ Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Directory on HDFS containing the collapsed, properly formatted files having "
+ + "one doc per line").withShortName("i").create();
+
+ Option dictOpt = obuilder.withLongName("dictionary").withRequired(false).withArgument(
+ abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The path to the term-dictionary format is ... ").withShortName("d").create();
+
+ Option dfsOpt = obuilder.withLongName("dfs").withRequired(false).withArgument(
+ abuilder.withName("dfs").withMinimum(1).withMaximum(1).create()).withDescription(
+ "HDFS namenode URI").withShortName("dfs").create();
+
+ Option numTopicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(abuilder
+ .withName("numTopics").withMinimum(1).withMaximum(1)
+ .create()).withDescription("Number of topics to learn").withShortName("top").create();
+
+ Option outputTopicFileOpt = obuilder.withLongName("topicOutputFile").withRequired(true).withArgument(
+ abuilder.withName("topicOutputFile").withMinimum(1).withMaximum(1).create())
+ .withDescription("File to write out p(term | topic)").withShortName("to").create();
+
+ Option outputDocFileOpt = obuilder.withLongName("docOutputFile").withRequired(true).withArgument(
+ abuilder.withName("docOutputFile").withMinimum(1).withMaximum(1).create())
+ .withDescription("File to write out p(topic | docid)").withShortName("do").create();
+
+ Option alphaOpt = obuilder.withLongName("alpha").withRequired(false).withArgument(abuilder
+ .withName("alpha").withMinimum(1).withMaximum(1).withDefault("0.1").create())
+ .withDescription("Smoothing parameter for p(topic | document) prior").withShortName("a").create();
+
+ Option etaOpt = obuilder.withLongName("eta").withRequired(false).withArgument(abuilder
+ .withName("eta").withMinimum(1).withMaximum(1).withDefault("0.1").create())
+ .withDescription("Smoothing parameter for p(term | topic)").withShortName("e").create();
+
+ Option maxIterOpt = obuilder.withLongName("maxIterations").withRequired(false).withArgument(abuilder
+ .withName("maxIterations").withMinimum(1).withMaximum(1).withDefault("10").create())
+ .withDescription("Maximum number of training passes").withShortName("m").create();
+
+ Option modelCorpusFractionOption = obuilder.withLongName("modelCorpusFraction")
+ .withRequired(false).withArgument(abuilder.withName("modelCorpusFraction").withMinimum(1)
+ .withMaximum(1).withDefault("0.0").create()).withShortName("mcf")
+ .withDescription("For online updates, initial value of |model|/|corpus|").create();
+
+ Option burnInOpt = obuilder.withLongName("burnInIterations").withRequired(false).withArgument(abuilder
+ .withName("burnInIterations").withMinimum(1).withMaximum(1).withDefault("5").create())
+ .withDescription("Minimum number of iterations").withShortName("b").create();
+
+ Option convergenceOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(abuilder
+ .withName("convergence").withMinimum(1).withMaximum(1).withDefault("0.0").create())
+ .withDescription("Fractional rate of perplexity to consider convergence").withShortName("c").create();
+
+ Option reInferDocTopicsOpt = obuilder.withLongName("reInferDocTopics").withRequired(false)
+ .withArgument(abuilder.withName("reInferDocTopics").withMinimum(1).withMaximum(1)
+ .withDefault("no").create())
+ .withDescription("re-infer p(topic | doc) : [no | randstart | continue]")
+ .withShortName("rdt").create();
+
+ Option numTrainThreadsOpt = obuilder.withLongName("numTrainThreads").withRequired(false)
+ .withArgument(abuilder.withName("numTrainThreads").withMinimum(1).withMaximum(1)
+ .withDefault("1").create())
+ .withDescription("number of threads to train with")
+ .withShortName("ntt").create();
+
+ Option numUpdateThreadsOpt = obuilder.withLongName("numUpdateThreads").withRequired(false)
+ .withArgument(abuilder.withName("numUpdateThreads").withMinimum(1).withMaximum(1)
+ .withDefault("1").create())
+ .withDescription("number of threads to update the model with")
+ .withShortName("nut").create();
+
+ Option verboseOpt = obuilder.withLongName("verbose").withRequired(false)
+ .withArgument(abuilder.withName("verbose").withMinimum(1).withMaximum(1)
+ .withDefault("false").create())
+ .withDescription("print verbose information, like top-terms in each topic, during iteration")
+ .withShortName("v").create();
+
+ Group group = gbuilder.withName("Options").withOption(inputDirOpt).withOption(numTopicsOpt)
+ .withOption(alphaOpt).withOption(etaOpt)
+ .withOption(maxIterOpt).withOption(burnInOpt).withOption(convergenceOpt)
+ .withOption(dictOpt).withOption(reInferDocTopicsOpt)
+ .withOption(outputDocFileOpt).withOption(outputTopicFileOpt).withOption(dfsOpt)
+ .withOption(numTrainThreadsOpt).withOption(numUpdateThreadsOpt)
+ .withOption(modelCorpusFractionOption).withOption(verboseOpt).create();
+
+ try {
+ Parser parser = new Parser();
+
+ parser.setGroup(group);
+ parser.setHelpOption(helpOpt);
+ CommandLine cmdLine = parser.parse(args);
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return -1;
+ }
+
+ String inputDirString = (String) cmdLine.getValue(inputDirOpt);
+ String dictDirString = cmdLine.hasOption(dictOpt) ? (String)cmdLine.getValue(dictOpt) : null;
+ int numTopics = Integer.parseInt((String) cmdLine.getValue(numTopicsOpt));
+ double alpha = Double.parseDouble((String)cmdLine.getValue(alphaOpt));
+ double eta = Double.parseDouble((String)cmdLine.getValue(etaOpt));
+ int maxIterations = Integer.parseInt((String)cmdLine.getValue(maxIterOpt));
+ int burnInIterations = Integer.parseInt((String)cmdLine.getValue(burnInOpt));
+ double minFractionalErrorChange = Double.parseDouble((String) cmdLine.getValue(convergenceOpt));
+ int numTrainThreads = Integer.parseInt((String)cmdLine.getValue(numTrainThreadsOpt));
+ int numUpdateThreads = Integer.parseInt((String)cmdLine.getValue(numUpdateThreadsOpt));
+ String topicOutFile = (String)cmdLine.getValue(outputTopicFileOpt);
+ String docOutFile = (String)cmdLine.getValue(outputDocFileOpt);
+ //String reInferDocTopics = (String)cmdLine.getValue(reInferDocTopicsOpt);
+ boolean verbose = Boolean.parseBoolean((String) cmdLine.getValue(verboseOpt));
+ double modelCorpusFraction = Double.parseDouble((String)cmdLine.getValue(modelCorpusFractionOption));
+
+ long start = System.nanoTime();
+
+ if (conf.get("fs.default.name") == null) {
+ String dfsNameNode = (String)cmdLine.getValue(dfsOpt);
+ conf.set("fs.default.name", dfsNameNode);
+ }
+ String[] terms = loadDictionary(dictDirString, conf);
+ logTime("dictionary loading", System.nanoTime() - start);
+ start = System.nanoTime();
+ Matrix corpus = loadVectors(inputDirString, conf);
+ logTime("vector seqfile corpus loading", System.nanoTime() - start);
+ start = System.nanoTime();
+ InMemoryCollapsedVariationalBayes0 cvb0 =
+ new InMemoryCollapsedVariationalBayes0(corpus, terms, numTopics, alpha, eta,
+ numTrainThreads, numUpdateThreads, modelCorpusFraction);
+ logTime("cvb0 init", System.nanoTime() - start);
+
+ start = System.nanoTime();
+ cvb0.setVerbose(verbose);
+ cvb0.iterateUntilConvergence(minFractionalErrorChange, maxIterations, burnInIterations);
+ logTime("total training time", System.nanoTime() - start);
+
+ /*
+ if ("randstart".equalsIgnoreCase(reInferDocTopics)) {
+ cvb0.inferDocuments(0.0, 100, true);
+ } else if ("continue".equalsIgnoreCase(reInferDocTopics)) {
+ cvb0.inferDocuments(0.0, 100, false);
+ }
+ */
+
+ start = System.nanoTime();
+ cvb0.writeModel(new Path(topicOutFile));
+ DistributedRowMatrixWriter.write(new Path(docOutFile), conf, cvb0.docTopicCounts);
+ logTime("printTopics", System.nanoTime() - start);
+ } catch (OptionException e) {
+ log.error("Error while parsing options", e);
+ CommandLineUtil.printHelp(group);
+ return -1;
+ }
+ return 0;
+ }
+
+ /*
+ private static Map<Integer, Map<String, Integer>> loadCorpus(String path) throws IOException {
+ List<String> lines = Resources.readLines(Resources.getResource(path), Charsets.UTF_8);
+ Map<Integer, Map<String, Integer>> corpus = Maps.newHashMap();
+ for (int i=0; i<lines.size(); i++) {
+ String line = lines.get(i);
+ Map<String, Integer> doc = Maps.newHashMap();
+ for (String s : line.split(" ")) {
+ s = s.replaceAll("\\W", "").toLowerCase().trim();
+ if (s.length() == 0) {
+ continue;
+ }
+ if (!doc.containsKey(s)) {
+ doc.put(s, 0);
+ }
+ doc.put(s, doc.get(s) + 1);
+ }
+ corpus.put(i, doc);
+ }
+ return corpus;
+ }
+ */
+
+ private static String[] loadDictionary(String dictionaryPath, Configuration conf) {
+ if (dictionaryPath == null) {
+ return null;
+ }
+ Path dictionaryFile = new Path(dictionaryPath);
+ List<Pair<Integer, String>> termList = Lists.newArrayList();
+ int maxTermId = 0;
+ // in the dictionary sequence file, the key is the term and the value is its integer id
+ for (Pair<Writable, IntWritable> record
+ : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
+ termList.add(new Pair<>(record.getSecond().get(),
+ record.getFirst().toString()));
+ maxTermId = Math.max(maxTermId, record.getSecond().get());
+ }
+ String[] terms = new String[maxTermId + 1];
+ for (Pair<Integer, String> pair : termList) {
+ terms[pair.getFirst()] = pair.getSecond();
+ }
+ return terms;
+ }
+
+ @Override
+ public Configuration getConf() {
+ return super.getConf();
+ }
+
+ private static Matrix loadVectors(String vectorPathString, Configuration conf)
+ throws IOException {
+ Path vectorPath = new Path(vectorPathString);
+ FileSystem fs = vectorPath.getFileSystem(conf);
+ List<Path> subPaths = Lists.newArrayList();
+ if (fs.isFile(vectorPath)) {
+ subPaths.add(vectorPath);
+ } else {
+ for (FileStatus fileStatus : fs.listStatus(vectorPath, PathFilters.logsCRCFilter())) {
+ subPaths.add(fileStatus.getPath());
+ }
+ }
+ List<Pair<Integer, Vector>> rowList = Lists.newArrayList();
+ int numRows = Integer.MIN_VALUE;
+ int numCols = -1;
+ boolean sequentialAccess = false;
+ for (Path subPath : subPaths) {
+ for (Pair<IntWritable, VectorWritable> record
+ : new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) {
+ int id = record.getFirst().get();
+ Vector vector = record.getSecond().get();
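+ // unwrap NamedVector so the matrix stores the underlying delegate vector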
+ if (vector instanceof NamedVector) {
+ vector = ((NamedVector)vector).getDelegate();
+ }
+ if (numCols < 0) {
+ numCols = vector.size();
+ sequentialAccess = vector.isSequentialAccess();
+ }
+ rowList.add(Pair.of(id, vector));
+ numRows = Math.max(numRows, id);
+ }
+ }
+ numRows++;
+ Vector[] rowVectors = new Vector[numRows];
+ for (Pair<Integer, Vector> pair : rowList) {
+ rowVectors[pair.getFirst()] = pair.getSecond();
+ }
+ return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess);
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ return main2(strings, getConf());
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new InMemoryCollapsedVariationalBayes0(), args);
+ }
+}
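For anyone trying the tool, here is a minimal sketch of launching it through ToolRunner, as main() above does. Only the long option names visible in the parser setup above (--maxIterations, --burnInIterations, --convergence, --numTrainThreads, --numUpdateThreads) come from the source; the input, dictionary, topic-count, and output option names are illustrative assumptions, since those options are defined earlier in the file.

import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0;

public class RunInMemoryCVB0 {
  public static void main(String[] args) throws Exception {
    // Option names marked "assumed" are defined earlier in the file and are
    // illustrative guesses; the rest appear in the parser setup above.
    String[] cvbArgs = {
        "--input", "/tmp/corpus",            // assumed option name
        "--dictionary", "/tmp/dictionary",   // assumed option name
        "--numTopics", "20",                 // assumed option name
        "--maxIterations", "25",
        "--burnInIterations", "5",
        "--convergence", "0.001",
        "--numTrainThreads", "4",
        "--numUpdateThreads", "2",
        "--topicOutputFile", "/tmp/topics",  // assumed option name
        "--docOutputFile", "/tmp/docTopics"  // assumed option name
    };
    System.exit(ToolRunner.run(new InMemoryCollapsedVariationalBayes0(), cvbArgs));
  }
}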
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java
new file mode 100644
index 0000000..912b6d5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java
@@ -0,0 +1,301 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Multithreaded LDA model trainer, which operates by running a "map/reduce"-style
+ * computation entirely in local memory (i.e. not as a Hadoop job). The "map" step takes
+ * the read-only {@link TopicModel} and uses it to iteratively learn the p(topic|term, doc)
+ * distribution for each document; this can be done in parallel across many documents,
+ * precisely because the model being read is never modified. The outputs of this step are
+ * then "reduced" onto the write model. These updates are not parallelizable in the same
+ * way: different threads must not write to the same entries at the same time, but updates
+ * from one document to the same term across many topics can safely be applied in
+ * parallel, so they are.
+ *
+ * Because computation is done asynchronously, it is important to call the stop() method
+ * once iteration is done; it blocks until all outstanding work is complete.
+ *
+ * Setting the read model and the write model to be the same object is not yet reliable,
+ * because concurrent reads and updates may conflict.
+ */
+public class ModelTrainer {
+
+ private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class);
+
+ private final int numTopics;
+ private final int numTerms;
+ private TopicModel readModel;
+ private TopicModel writeModel;
+ private ThreadPoolExecutor threadPool;
+ private BlockingQueue<Runnable> workQueue;
+ private final int numTrainThreads;
+ private final boolean isReadWrite;
+
+ public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel,
+ int numTrainThreads, int numTopics, int numTerms) {
+ this.readModel = initialReadModel;
+ this.writeModel = initialWriteModel;
+ this.numTrainThreads = numTrainThreads;
+ this.numTopics = numTopics;
+ this.numTerms = numTerms;
+ isReadWrite = initialReadModel == initialWriteModel;
+ }
+
+ /**
+ * WARNING: this constructor has not been verified to behave well: with a single model,
+ * updates are applied to the same object that is concurrently being read for inference,
+ * and the two may conflict.
+ * @param model model to be used for both reading (inference) and accumulating (learning)
+ * @param numTrainThreads number of threads to train with
+ * @param numTopics number of topics in the model
+ * @param numTerms number of terms in the vocabulary
+ */
+ public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) {
+ this(model, model, numTrainThreads, numTopics, numTerms);
+ }
+
+ public TopicModel getReadModel() {
+ return readModel;
+ }
+
+ public void start() {
+ log.info("Starting training threadpool with {} threads", numTrainThreads);
+ workQueue = new ArrayBlockingQueue<>(numTrainThreads * 10);
+ threadPool = new ThreadPoolExecutor(numTrainThreads, numTrainThreads, 0, TimeUnit.SECONDS,
+ workQueue);
+ threadPool.allowCoreThreadTimeOut(false);
+ threadPool.prestartAllCoreThreads();
+ writeModel.reset();
+ }
+
+ public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
+ train(matrix, docTopicCounts, 1);
+ }
+
+ public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) {
+ return calculatePerplexity(matrix, docTopicCounts, 0);
+ }
+
+ public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts,
+ double testFraction) {
+ Iterator<MatrixSlice> docIterator = matrix.iterator();
+ Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
+ double perplexity = 0;
+ double matrixNorm = 0;
+ while (docIterator.hasNext() && docTopicIterator.hasNext()) {
+ MatrixSlice docSlice = docIterator.next();
+ MatrixSlice topicSlice = docTopicIterator.next();
+ int docId = docSlice.index();
+ Vector document = docSlice.vector();
+ Vector topicDist = topicSlice.vector();
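+ // testFraction == 0 evaluates every document; otherwise roughly one in (1/testFraction) is sampled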
+ if (testFraction == 0 || docId % (1 / testFraction) == 0) {
+ trainSync(document, topicDist, false, 10);
+ perplexity += readModel.perplexity(document, topicDist);
+ matrixNorm += document.norm(1);
+ }
+ }
+ return perplexity / matrixNorm;
+ }
+
+ public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) {
+ start();
+ Iterator<MatrixSlice> docIterator = matrix.iterator();
+ Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
+ long startTime = System.nanoTime();
+ int i = 0;
+ double[] times = new double[100];
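+ // ring buffer of recent per-token training times (ms), used only for debug logging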
+ Map<Vector, Vector> batch = Maps.newHashMap();
+ int numTokensInBatch = 0;
+ long batchStart = System.nanoTime();
+ while (docIterator.hasNext() && docTopicIterator.hasNext()) {
+ i++;
+ Vector document = docIterator.next().vector();
+ Vector topicDist = docTopicIterator.next().vector();
+ if (isReadWrite) {
+ if (batch.size() < numTrainThreads) {
+ batch.put(document, topicDist);
+ if (log.isDebugEnabled()) {
+ numTokensInBatch += document.getNumNondefaultElements();
+ }
+ } else {
+ batchTrain(batch, true, numDocTopicIters);
+ long time = System.nanoTime();
+ log.debug("trained {} docs with {} tokens, start time {}, end time {}",
+ numTrainThreads, numTokensInBatch, batchStart, time);
+ batchStart = time;
+ numTokensInBatch = 0;
+ // start a fresh batch with the current document, which would otherwise be dropped
+ batch.clear();
+ batch.put(document, topicDist);
+ }
+ } else {
+ long start = System.nanoTime();
+ train(document, topicDist, true, numDocTopicIters);
+ if (log.isDebugEnabled()) {
+ times[i % times.length] =
+ (System.nanoTime() - start) / (1.0e6 * document.getNumNondefaultElements());
+ if (i % 100 == 0) {
+ long time = System.nanoTime() - startTime;
+ log.debug("trained {} documents in {}ms", i, time / 1.0e6);
+ if (i % 500 == 0) {
+ Arrays.sort(times);
+ log.debug("training took median {}ms per token-instance", times[times.length / 2]);
+ }
+ }
+ }
+ }
+ }
+ // flush any remaining partial batch before shutting down
+ if (isReadWrite && !batch.isEmpty()) {
+ batchTrain(batch, true, numDocTopicIters);
+ }
+ stop();
+ }
+
+ public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) {
+ while (true) {
+ try {
+ List<TrainerRunnable> runnables = Lists.newArrayList();
+ for (Map.Entry<Vector, Vector> entry : batch.entrySet()) {
+ runnables.add(new TrainerRunnable(readModel, null, entry.getKey(),
+ entry.getValue(), new SparseRowMatrix(numTopics, numTerms, true),
+ numDocTopicsIters));
+ }
+ threadPool.invokeAll(runnables);
+ if (update) {
+ for (TrainerRunnable runnable : runnables) {
+ writeModel.update(runnable.docTopicModel);
+ }
+ }
+ break;
+ } catch (InterruptedException e) {
+ log.warn("Interrupted during batch training, retrying!", e);
+ }
+ }
+ }
+
+ public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
+ while (true) {
+ try {
+ workQueue.put(new TrainerRunnable(readModel, update
+ ? writeModel
+ : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters));
+ return;
+ } catch (InterruptedException e) {
+ log.warn("Interrupted waiting to submit document to work queue: {}", document, e);
+ }
+ }
+ }
+
+ public void trainSync(Vector document, Vector docTopicCounts, boolean update,
+ int numDocTopicIters) {
+ new TrainerRunnable(readModel, update
+ ? writeModel
+ : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters).run();
+ }
+
+ public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) {
+ TrainerRunnable runner = new TrainerRunnable(readModel, null, document, docTopicCounts,
+ new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters);
+ return runner.call();
+ }
+
+ public void stop() {
+ long startTime = System.nanoTime();
+ log.info("Initiating stopping of training threadpool");
+ try {
+ threadPool.shutdown();
+ if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
+ log.warn("Threadpool timed out on await termination - jobs still running!");
+ }
+ long newTime = System.nanoTime();
+ log.info("threadpool took: {}ms", (newTime - startTime) / 1.0e6);
+ startTime = newTime;
+ readModel.stop();
+ newTime = System.nanoTime();
+ log.info("readModel.stop() took {}ms", (newTime - startTime) / 1.0e6);
+ startTime = newTime;
+ writeModel.stop();
+ newTime = System.nanoTime();
+ log.info("writeModel.stop() took {}ms", (newTime - startTime) / 1.0e6);
+ TopicModel tmpModel = writeModel;
+ writeModel = readModel;
+ readModel = tmpModel;
+ } catch (InterruptedException e) {
+ log.error("Interrupted shutting down!", e);
+ }
+ }
+
+ public void persist(Path outputPath) throws IOException {
+ readModel.persist(outputPath, true);
+ }
+
+ private static final class TrainerRunnable implements Runnable, Callable<Double> {
+ private final TopicModel readModel;
+ private final TopicModel writeModel;
+ private final Vector document;
+ private final Vector docTopics;
+ private final Matrix docTopicModel;
+ private final int numDocTopicIters;
+
+ private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document,
+ Vector docTopics, Matrix docTopicModel, int numDocTopicIters) {
+ this.readModel = readModel;
+ this.writeModel = writeModel;
+ this.document = document;
+ this.docTopics = docTopics;
+ this.docTopicModel = docTopicModel;
+ this.numDocTopicIters = numDocTopicIters;
+ }
+
+ @Override
+ public void run() {
+ for (int i = 0; i < numDocTopicIters; i++) {
+ // synchronous read-only call:
+ readModel.trainDocTopicModel(document, docTopics, docTopicModel);
+ }
+ if (writeModel != null) {
+ // parallel call which is read-only on the docTopicModel, and write-only on the writeModel
+ // this method does not return until all rows of the docTopicModel have been submitted
+ // to write work queues
+ writeModel.update(docTopicModel);
+ }
+ }
+
+ @Override
+ public Double call() {
+ run();
+ return readModel.perplexity(document, docTopics);
+ }
+ }
+}
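To make the class javadoc's read/write-model contract concrete, here is a minimal usage sketch of ModelTrainer. The two TopicModel instances are assumed to be constructed elsewhere (their constructor signatures are not shown in this patch); everything else uses only the API above.

import java.io.IOException;

import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.lda.cvb.ModelTrainer;
import org.apache.mahout.clustering.lda.cvb.TopicModel;
import org.apache.mahout.math.Matrix;

public class ModelTrainerSketch {
  /** One training pass: corpus rows are documents, docTopicCounts rows are p(topic|doc). */
  public static void trainOnePass(TopicModel readModel, TopicModel writeModel,
                                  Matrix corpus, Matrix docTopicCounts,
                                  int numTopics, int numTerms) throws IOException {
    ModelTrainer trainer = new ModelTrainer(readModel, writeModel, 4, numTopics, numTerms);
    // train() starts the threadpool, feeds each document to a TrainerRunnable
    // (10 p(topic|doc) iterations per document here), and finishes with stop(),
    // which blocks until all updates are applied and then swaps read and write models
    trainer.train(corpus, docTopicCounts, 10);
    // perplexity of the full corpus under the freshly swapped-in model
    double perplexity = trainer.calculatePerplexity(corpus, docTopicCounts);
    System.out.println("corpus perplexity: " + perplexity);
    trainer.persist(new Path("/tmp/model"));  // output path is an arbitrary example
  }
}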