Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:08:05 UTC
[34/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
new file mode 100644
index 0000000..0f88a70
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
@@ -0,0 +1,332 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.Arrays;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+
+/** Train a {@link MultilayerPerceptron}. */
+public final class TrainMultilayerPerceptron {
+
+ private static final Logger log = LoggerFactory.getLogger(TrainMultilayerPerceptron.class);
+
+ /** The parameters used by MLP. */
+ static class Parameters {
+ double learningRate;
+ double momemtumWeight;
+ double regularizationWeight;
+
+ String inputFilePath;
+ boolean skipHeader;
+ Map<String, Integer> labelsIndex = Maps.newHashMap();
+
+ String modelFilePath;
+ boolean updateModel;
+ List<Integer> layerSizeList = Lists.newArrayList();
+ String squashingFunctionName;
+ }
+
+ /*
+ private double learningRate;
+ private double momemtumWeight;
+ private double regularizationWeight;
+
+ private String inputFilePath;
+ private boolean skipHeader;
+ private Map<String, Integer> labelsIndex = Maps.newHashMap();
+
+ private String modelFilePath;
+ private boolean updateModel;
+ private List<Integer> layerSizeList = Lists.newArrayList();
+ private String squashingFunctionName;*/
+
+ public static void main(String[] args) throws Exception {
+ Parameters parameters = new Parameters();
+
+ if (parseArgs(args, parameters)) {
+ log.info("Validate model...");
+ // check whether the model already exists
+ Path modelPath = new Path(parameters.modelFilePath);
+ FileSystem modelFs = modelPath.getFileSystem(new Configuration());
+ MultilayerPerceptron mlp;
+
+ if (modelFs.exists(modelPath) && parameters.updateModel) {
+ // incrementally update existing model
+ log.info("Build model from existing model...");
+ mlp = new MultilayerPerceptron(parameters.modelFilePath);
+ } else {
+ if (modelFs.exists(modelPath)) {
+ modelFs.delete(modelPath, true); // delete the existing file
+ }
+ log.info("Build model from scratch...");
+ mlp = new MultilayerPerceptron();
+        for (int i = 0; i < parameters.layerSizeList.size(); ++i) {
+          // the last layer in the list is the output layer
+          boolean isOutputLayer = i == parameters.layerSizeList.size() - 1;
+          mlp.addLayer(parameters.layerSizeList.get(i), isOutputLayer, parameters.squashingFunctionName);
+        }
+        mlp.setCostFunction("Minus_Squared");
+        mlp.setModelPath(parameters.modelFilePath);
+      }
+
+ // set the parameters
+ mlp.setLearningRate(parameters.learningRate)
+ .setMomentumWeight(parameters.momemtumWeight)
+ .setRegularizationWeight(parameters.regularizationWeight);
+
+ // train by the training data
+ Path trainingDataPath = new Path(parameters.inputFilePath);
+ FileSystem dataFs = trainingDataPath.getFileSystem(new Configuration());
+
+ Preconditions.checkArgument(dataFs.exists(trainingDataPath), "Training dataset %s cannot be found!",
+ parameters.inputFilePath);
+
+ log.info("Read data and train model...");
+ BufferedReader reader = null;
+
+ try {
+ reader = new BufferedReader(new InputStreamReader(dataFs.open(trainingDataPath)));
+ String line;
+
+ // read training data line by line
+ if (parameters.skipHeader) {
+ reader.readLine();
+ }
+
+ int labelDimension = parameters.labelsIndex.size();
+ while ((line = reader.readLine()) != null) {
+ String[] token = line.split(",");
+ String label = token[token.length - 1];
+ int labelIndex = parameters.labelsIndex.get(label);
+
+ double[] instances = new double[token.length - 1 + labelDimension];
+ for (int i = 0; i < token.length - 1; ++i) {
+ instances[i] = Double.parseDouble(token[i]);
+ }
+ for (int i = 0; i < labelDimension; ++i) {
+ instances[token.length - 1 + i] = 0;
+ }
+ // set the corresponding dimension
+ instances[token.length - 1 + labelIndex] = 1;
+
+          Vector instance = new DenseVector(instances);
+ mlp.trainOnline(instance);
+ }
+
+ // write model back
+ log.info("Write trained model to {}", parameters.modelFilePath);
+ mlp.writeModelToFile();
+ mlp.close();
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+ }
+
+ /**
+ * Parse the input arguments.
+ *
+ * @param args The input arguments
+ * @param parameters The parameters parsed.
+ * @return Whether the input arguments are valid.
+   * @throws Exception if the command-line arguments cannot be parsed
+ */
+ private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
+ // build the options
+ log.info("Validate and parse arguments...");
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ GroupBuilder groupBuilder = new GroupBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    // whether to skip the first row of the input file
+ Option skipHeaderOption = optionBuilder.withLongName("skipHeader")
+ .withShortName("sh").create();
+
+ Group skipHeaderGroup = groupBuilder.withOption(skipHeaderOption).create();
+
+ Option inputOption = optionBuilder
+ .withLongName("input")
+ .withShortName("i")
+ .withRequired(true)
+ .withChildren(skipHeaderGroup)
+ .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+ .create()).withDescription("the file path of training dataset")
+ .create();
+
+ Option labelsOption = optionBuilder
+ .withLongName("labels")
+ .withShortName("labels")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
+ .withDescription("label names").create();
+
+ Option updateOption = optionBuilder
+ .withLongName("update")
+ .withShortName("u")
+ .withDescription("whether to incrementally update model if the model exists")
+ .create();
+
+ Group modelUpdateGroup = groupBuilder.withOption(updateOption).create();
+
+ Option modelOption = optionBuilder
+ .withLongName("model")
+ .withShortName("mo")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create())
+ .withDescription("the path to store the trained model")
+ .withChildren(modelUpdateGroup).create();
+
+ Option layerSizeOption = optionBuilder
+ .withLongName("layerSize")
+ .withShortName("ls")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create())
+ .withDescription("the size of each layer").create();
+
+ Option squashingFunctionOption = optionBuilder
+ .withLongName("squashingFunction")
+ .withShortName("sf")
+ .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1)
+ .withDefault("Sigmoid").create())
+        .withDescription("the name of the squashing function (currently only Sigmoid is supported)")
+ .create();
+
+ Option learningRateOption = optionBuilder
+ .withLongName("learningRate")
+ .withShortName("l")
+ .withArgument(argumentBuilder.withName("learning rate").withMaximum(1)
+ .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_LEARNING_RATE).create())
+ .withDescription("learning rate").create();
+
+ Option momemtumOption = optionBuilder
+ .withLongName("momemtumWeight")
+ .withShortName("m")
+        .withArgument(argumentBuilder.withName("momentum weight").withMaximum(1)
+            .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_MOMENTUM_WEIGHT).create())
+        .withDescription("momentum weight").create();
+
+ Option regularizationOption = optionBuilder
+ .withLongName("regularizationWeight")
+ .withShortName("r")
+ .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1)
+ .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_REGULARIZATION_WEIGHT).create())
+ .withDescription("regularization weight").create();
+
+ // parse the input
+ Parser parser = new Parser();
+ Group normalOptions = groupBuilder.withOption(inputOption)
+ .withOption(skipHeaderOption).withOption(updateOption)
+ .withOption(labelsOption).withOption(modelOption)
+ .withOption(layerSizeOption).withOption(squashingFunctionOption)
+ .withOption(learningRateOption).withOption(momemtumOption)
+ .withOption(regularizationOption).create();
+
+ parser.setGroup(normalOptions);
+
+ CommandLine commandLine = parser.parseAndHelp(args);
+ if (commandLine == null) {
+ return false;
+ }
+
+ parameters.learningRate = getDouble(commandLine, learningRateOption);
+ parameters.momemtumWeight = getDouble(commandLine, momemtumOption);
+ parameters.regularizationWeight = getDouble(commandLine, regularizationOption);
+
+ parameters.inputFilePath = getString(commandLine, inputOption);
+ parameters.skipHeader = commandLine.hasOption(skipHeaderOption);
+
+ List<String> labelsList = getStringList(commandLine, labelsOption);
+ int currentIndex = 0;
+ for (String label : labelsList) {
+ parameters.labelsIndex.put(label, currentIndex++);
+ }
+
+ parameters.modelFilePath = getString(commandLine, modelOption);
+ parameters.updateModel = commandLine.hasOption(updateOption);
+
+ parameters.layerSizeList = getIntegerList(commandLine, layerSizeOption);
+
+ parameters.squashingFunctionName = getString(commandLine, squashingFunctionOption);
+
+ System.out.printf("Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, Learning rate: %f," +
+        " Momentum weight: %f, Regularization weight: %f\n", parameters.inputFilePath, parameters.modelFilePath,
+ parameters.updateModel, Arrays.toString(parameters.layerSizeList.toArray()),
+ parameters.squashingFunctionName, parameters.learningRate, parameters.momemtumWeight,
+ parameters.regularizationWeight);
+
+ return true;
+ }
+
+ static Double getDouble(CommandLine commandLine, Option option) {
+ Object val = commandLine.getValue(option);
+ if (val != null) {
+ return Double.parseDouble(val.toString());
+ }
+ return null;
+ }
+
+ static String getString(CommandLine commandLine, Option option) {
+ Object val = commandLine.getValue(option);
+ if (val != null) {
+ return val.toString();
+ }
+ return null;
+ }
+
+ static List<Integer> getIntegerList(CommandLine commandLine, Option option) {
+ List<String> list = commandLine.getValues(option);
+ List<Integer> valList = Lists.newArrayList();
+ for (String str : list) {
+ valList.add(Integer.parseInt(str));
+ }
+ return valList;
+ }
+
+ static List<String> getStringList(CommandLine commandLine, Option option) {
+ return commandLine.getValues(option);
+ }
+
+}
\ No newline at end of file
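
For reference, the one-hot label encoding performed by the training loop in
TrainMultilayerPerceptron above can be illustrated with a minimal stand-alone
sketch. The class name, the sample CSV line and the two-label map below are
hypothetical; they only serve to show the resulting instance layout (features
first, then one position per label):

    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class InstanceEncodingSketch {
      public static void main(String[] args) {
        // a CSV line with two features and the label "B", given the label map {"A"=0, "B"=1}
        String line = "0.5,1.2,B";
        String[] token = line.split(",");
        int labelDimension = 2;  // number of distinct labels
        int labelIndex = 1;      // index of "B" in the label map
        double[] values = new double[token.length - 1 + labelDimension];
        for (int i = 0; i < token.length - 1; ++i) {
          values[i] = Double.parseDouble(token[i]);
        }
        values[token.length - 1 + labelIndex] = 1; // one-hot encode the label
        Vector instance = new DenseVector(values); // features [0.5, 1.2] followed by [0.0, 1.0]
        System.out.println(instance);
      }
    }
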
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
new file mode 100644
index 0000000..f0794b3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Base class implementing the Naive Bayes Classifier Algorithm. Note that this class
+ * supports {@link #classifyFull}, but not {@code classify} or
+ * {@code classifyScalar}. These two methods are not supported because the
+ * scores computed by a NaiveBayesClassifier do not represent probabilities.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+
+ private final NaiveBayesModel model;
+
+ protected AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+ this.model = model;
+ }
+
+ protected NaiveBayesModel getModel() {
+ return model;
+ }
+
+ protected abstract double getScoreForLabelFeature(int label, int feature);
+
+ protected double getScoreForLabelInstance(int label, Vector instance) {
+ double result = 0.0;
+ for (Element e : instance.nonZeroes()) {
+ result += e.get() * getScoreForLabelFeature(label, e.index());
+ }
+ return result;
+ }
+
+ @Override
+ public int numCategories() {
+ return model.numLabels();
+ }
+
+ @Override
+ public Vector classifyFull(Vector instance) {
+ return classifyFull(model.createScoringVector(), instance);
+ }
+
+ @Override
+ public Vector classifyFull(Vector r, Vector instance) {
+ for (int label = 0; label < model.numLabels(); label++) {
+ r.setQuick(label, getScoreForLabelInstance(label, instance));
+ }
+ return r;
+ }
+
+ /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+ @Override
+ public double classifyScalar(Vector instance) {
+ throw new UnsupportedOperationException("Not supported in Naive Bayes");
+ }
+
+ /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+ @Override
+ public Vector classify(Vector instance) {
+    throw new UnsupportedOperationException("probabilities not supported in Naive Bayes");
+ }
+}
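
Because classify() and classifyScalar() throw UnsupportedOperationException, callers
obtain a prediction by taking the argmax over the scores returned by classifyFull().
A minimal sketch of that pattern, assuming a model has already been serialized to
modelDir (the class and method names below are illustrative only):

    import java.io.IOException;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
    import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
    import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
    import org.apache.mahout.math.Vector;

    public class NaiveBayesPredictSketch {
      /** Returns the index of the best-scoring label for one feature vector. */
      public static int predict(Path modelDir, Configuration conf, Vector instance) throws IOException {
        NaiveBayesModel model = NaiveBayesModel.materialize(modelDir, conf);
        AbstractNaiveBayesClassifier classifier = new StandardNaiveBayesClassifier(model);
        Vector scores = classifier.classifyFull(instance); // unnormalized scores, not probabilities
        return scores.maxValueIndex();                      // argmax is still the predicted label
      }
    }
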
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
new file mode 100644
index 0000000..1e5171c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.io.Closeables;
+
+public final class BayesUtils {
+
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ private BayesUtils() {}
+
+ public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {
+
+ float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+ boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true);
+
+ // read feature sums and label sums
+ Vector scoresPerLabel = null;
+ Vector scoresPerFeature = null;
+ for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+ new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+ String key = record.getFirst().toString();
+ VectorWritable value = record.getSecond();
+ if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+ scoresPerFeature = value.get();
+ } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+ scoresPerLabel = value.get();
+ }
+ }
+
+ Preconditions.checkNotNull(scoresPerFeature);
+ Preconditions.checkNotNull(scoresPerLabel);
+
+ Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
+ for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
+ new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+ scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+ }
+
+    // perLabelThetaNormalizer is only used by the complementary model; we do not instantiate it for the standard model
+    Vector perLabelThetaNormalizer = null;
+    if (isComplementary) {
+      perLabelThetaNormalizer = scoresPerLabel.like();
+ for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+ new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+ if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+ perLabelThetaNormalizer = entry.getSecond().get();
+ }
+ }
+ Preconditions.checkNotNull(perLabelThetaNormalizer);
+ }
+
+ return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
+ alphaI, isComplementary);
+ }
+
+  /** Write the list of labels into a sequence file. */
+ public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+ throws IOException {
+ FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+ int i = 0;
+ try {
+ for (String label : labels) {
+ writer.append(new Text(label), new IntWritable(i++));
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ return i;
+ }
+
+ public static int writeLabelIndex(Configuration conf, Path indexPath,
+ Iterable<Pair<Text,IntWritable>> labels) throws IOException {
+ FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+ Collection<String> seen = Sets.newHashSet();
+ int i = 0;
+ try {
+      for (Pair<Text, IntWritable> label : labels) {
+        String theLabel = SLASH.split(label.getFirst().toString())[1];
+ if (!seen.contains(theLabel)) {
+ writer.append(new Text(theLabel), new IntWritable(i++));
+ seen.add(theLabel);
+ }
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ return i;
+ }
+
+ public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) {
+ Map<Integer, String> labelMap = new HashMap<>();
+ for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) {
+ labelMap.put(pair.getSecond().get(), pair.getFirst().toString());
+ }
+ return labelMap;
+ }
+
+ public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+ OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>();
+ for (Pair<Writable,IntWritable> entry
+ : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) {
+ index.put(entry.getFirst().toString(), entry.getSecond().get());
+ }
+ return index;
+ }
+
+ public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException {
+ Map<String,Vector> sumVectors = Maps.newHashMap();
+ for (Pair<Text,VectorWritable> entry
+ : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf),
+ PathType.LIST, PathFilters.partFilter(), conf)) {
+ sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+ }
+ return sumVectors;
+ }
+}
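
The writeLabelIndex()/readLabelIndex() pair above round-trips a label dictionary
through a SequenceFile. A minimal sketch, assuming a writable location (the class
name, path and labels below are hypothetical):

    import java.io.IOException;
    import java.util.Arrays;
    import java.util.Map;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.classifier.naivebayes.BayesUtils;

    public class LabelIndexSketch {
      public static void main(String[] args) throws IOException {
        Configuration conf = new Configuration();
        Path indexPath = new Path("/tmp/labelindex");
        // write two labels and get back the number of labels written
        int numLabels = BayesUtils.writeLabelIndex(conf, Arrays.asList("spam", "ham"), indexPath);
        // read them back as a position -> label map, e.g. {0=spam, 1=ham}
        Map<Integer, String> labelMap = BayesUtils.readLabelIndex(conf, indexPath);
        System.out.println(numLabels + " labels: " + labelMap);
      }
    }
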
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
new file mode 100644
index 0000000..18bd3d6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/** Implementation of the Complementary Naive Bayes Classifier Algorithm. */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+ public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+ super(model);
+ }
+
+ @Override
+ public double getScoreForLabelFeature(int label, int feature) {
+ NaiveBayesModel model = getModel();
+ double weight = computeWeight(model.featureWeight(feature), model.weight(label, feature),
+ model.totalWeightSum(), model.labelWeight(label), model.alphaI(), model.numFeatures());
+ // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+ return weight / model.thetaNormalizer(label);
+ }
+
+ // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias
+ public static double computeWeight(double featureWeight, double featureLabelWeight,
+ double totalWeight, double labelWeight, double alphaI, double numFeatures) {
+ double numerator = featureWeight - featureLabelWeight + alphaI;
+ double denominator = totalWeight - labelWeight + alphaI * numFeatures;
+ return -Math.log(numerator / denominator);
+ }
+}
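
In the notation of the Rennie et al. paper referenced above, computeWeight() evaluates
the complementary class weight

    weight(c, f) = -log( (N_f - N_{c,f} + alphaI) / (N - N_c + alphaI * |V|) )

where N_f = featureWeight(f), N_{c,f} = weight(c, f), N = totalWeightSum(),
N_c = labelWeight(c) and |V| = numFeatures(). getScoreForLabelFeature() then divides
this weight by the per-label theta normalizer to correct for weight magnitude errors
(Section 3.2 of the paper).
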
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
new file mode 100644
index 0000000..f180e8b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+
+/** NaiveBayesModel holds the weight matrix, the feature and label sums and the weight normalizer vectors.*/
+public class NaiveBayesModel {
+
+ private final Vector weightsPerLabel;
+ private final Vector perlabelThetaNormalizer;
+ private final Vector weightsPerFeature;
+ private final Matrix weightsPerLabelAndFeature;
+ private final float alphaI;
+ private final double numFeatures;
+ private final double totalWeightSum;
+ private final boolean isComplementary;
+
+ public final static String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL";
+
+ public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer,
+ float alphaI, boolean isComplementary) {
+ this.weightsPerLabelAndFeature = weightMatrix;
+ this.weightsPerFeature = weightsPerFeature;
+ this.weightsPerLabel = weightsPerLabel;
+ this.perlabelThetaNormalizer = thetaNormalizer;
+ this.numFeatures = weightsPerFeature.getNumNondefaultElements();
+ this.totalWeightSum = weightsPerLabel.zSum();
+ this.alphaI = alphaI;
+ this.isComplementary=isComplementary;
+ }
+
+ public double labelWeight(int label) {
+ return weightsPerLabel.getQuick(label);
+ }
+
+ public double thetaNormalizer(int label) {
+ return perlabelThetaNormalizer.get(label);
+ }
+
+ public double featureWeight(int feature) {
+ return weightsPerFeature.getQuick(feature);
+ }
+
+ public double weight(int label, int feature) {
+ return weightsPerLabelAndFeature.getQuick(label, feature);
+ }
+
+ public float alphaI() {
+ return alphaI;
+ }
+
+ public double numFeatures() {
+ return numFeatures;
+ }
+
+ public double totalWeightSum() {
+ return totalWeightSum;
+ }
+
+ public int numLabels() {
+ return weightsPerLabel.size();
+ }
+
+ public Vector createScoringVector() {
+ return weightsPerLabel.like();
+ }
+
+ public boolean isComplemtary(){
+ return isComplementary;
+ }
+
+ public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
+ FileSystem fs = output.getFileSystem(conf);
+
+ Vector weightsPerLabel = null;
+ Vector perLabelThetaNormalizer = null;
+ Vector weightsPerFeature = null;
+ Matrix weightsPerLabelAndFeature;
+ float alphaI;
+ boolean isComplementary;
+
+ FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
+ try {
+ alphaI = in.readFloat();
+ isComplementary = in.readBoolean();
+ weightsPerFeature = VectorWritable.readVector(in);
+ weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
+ if (isComplementary){
+ perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
+ }
+ weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size());
+ for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
+ weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
+ }
+ } finally {
+ Closeables.close(in, true);
+ }
+ NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
+ perLabelThetaNormalizer, alphaI, isComplementary);
+ model.validate();
+ return model;
+ }
+
+ public void serialize(Path output, Configuration conf) throws IOException {
+ FileSystem fs = output.getFileSystem(conf);
+ FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
+ try {
+ out.writeFloat(alphaI);
+ out.writeBoolean(isComplementary);
+ VectorWritable.writeVector(out, weightsPerFeature);
+ VectorWritable.writeVector(out, weightsPerLabel);
+ if (isComplementary){
+ VectorWritable.writeVector(out, perlabelThetaNormalizer);
+ }
+ for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
+ VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
+ }
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ public void validate() {
+ Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!");
+ Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
+ Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
+ Preconditions.checkNotNull(weightsPerLabel, "the number of labels has to be defined!");
+ Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
+ "the number of labels has to be greater than 0!");
+ Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined");
+ Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
+ "the feature sums have to be greater than 0!");
+ if (isComplementary){
+ Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
+ Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+ "the number of theta normalizers has to be greater than 0!");
+ Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue())
+ == Math.signum(perlabelThetaNormalizer.maxValue()),
+ "Theta normalizers do not all have the same sign");
+ Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements()
+ == perlabelThetaNormalizer.size(),
+ "Theta normalizers can not have zero value.");
+ }
+
+ }
+}
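
As implemented by serialize() and materialize() above, the naiveBayesModel.bin file
written into the model directory has the following on-disk layout (each Vector is
written and read with VectorWritable):

    float    alphaI
    boolean  isComplementary
    Vector   weightsPerFeature
    Vector   weightsPerLabel
    Vector   perLabelThetaNormalizer      (present only when isComplementary is true)
    Vector   weightsPerLabelAndFeature row, repeated numLabels() times
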
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
new file mode 100644
index 0000000..e4ce8aa
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/** Implementation of the standard Naive Bayes Classifier Algorithm. */
+public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+ public StandardNaiveBayesClassifier(NaiveBayesModel model) {
+ super(model);
+ }
+
+ @Override
+ public double getScoreForLabelFeature(int label, int feature) {
+ NaiveBayesModel model = getModel();
+ // Standard Naive Bayes does not use weight normalization
+ return computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI(), model.numFeatures());
+ }
+
+ public static double computeWeight(double featureLabelWeight, double labelWeight, double alphaI, double numFeatures) {
+ double numerator = featureLabelWeight + alphaI;
+ double denominator = labelWeight + alphaI * numFeatures;
+ return Math.log(numerator / denominator);
+ }
+}
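
computeWeight() here is the Laplace-smoothed log conditional probability of a feature
given a label:

    weight(c, f) = log( (N_{c,f} + alphaI) / (N_c + alphaI * |V|) )

with N_{c,f} = weight(c, f), N_c = labelWeight(c) and |V| = numFeatures(). Summing these
scores over the non-zero features of an instance (see getScoreForLabelInstance() in
AbstractNaiveBayesClassifier) gives the multinomial Naive Bayes log-likelihood of the
instance under that label, without the label-prior term.
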
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
new file mode 100644
index 0000000..37a3b71
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+/**
+ * Run the test input through the model.
+ * <p/>
+ * The output key is the expected label (the part of the input key after the first slash,
+ * e.g. "/sports/doc123" becomes "sports"); the output value is the vector of classification
+ * scores for all labels.
+ */
+public class BayesTestMapper extends Mapper<Text, VectorWritable, Text, VectorWritable> {
+
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ private AbstractNaiveBayesClassifier classifier;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ Path modelPath = HadoopUtil.getSingleCachedFile(conf);
+ NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
+ boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));
+
+    // Ensure that if we are testing in complementary mode, the model has been trained
+    // complementary. A complementary model will work for standard classification,
+    // but a standard model will not work for complementary classification.
+    if (isComplementary) {
+      Preconditions.checkArgument(model.isComplemtary(),
+          "Complementary mode in model is different than test mode");
+      classifier = new ComplementaryNaiveBayesClassifier(model);
+    } else {
+      classifier = new StandardNaiveBayesClassifier(model);
+    }
+ }
+
+ @Override
+ protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
+ Vector result = classifier.classifyFull(value.get());
+ //the key is the expected value
+ context.write(new Text(SLASH.split(key.toString())[1]), new VectorWritable(result));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
new file mode 100644
index 0000000..8fd422f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
@@ -0,0 +1,179 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Test the (Complementary) Naive Bayes model that was built during training
+ * by iterating over the test set and comparing the predicted labels to the expected ones.
+ */
+public class TestNaiveBayesDriver extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);
+
+ public static final String COMPLEMENTARY = "class"; //b for bayes, c for complementary
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+    addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption("model", "m", "The path to the model built during training", true);
+ addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
+ addOption(buildOption("runSequential", "seq", "run sequential?", false, false, String.valueOf(false)));
+ addOption("labelIndex", "l", "The path to the location of the label index", true);
+ Map<String, List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), getOutputPath());
+ }
+
+    boolean sequential = hasOption("runSequential");
+    if (sequential) {
+      runSequential();
+    } else {
+      boolean succeeded = runMapReduce();
+      if (!succeeded) {
+        return -1;
+      }
+    }
+
+ //load the labels
+ Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));
+
+ //loop over the results and create the confusion matrix
+ SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+ new SequenceFileDirIterable<>(getOutputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+ ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
+ analyzeResults(labelMap, dirIterable, analyzer);
+
+    log.info("{} Results: {}", hasOption("testComplementary") ? "Complementary NB" : "Standard NB", analyzer);
+ return 0;
+ }
+
+ private void runSequential() throws IOException {
+ boolean complementary = hasOption("testComplementary");
+ FileSystem fs = FileSystem.get(getConf());
+ NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+
+    // Ensure that if we are testing in complementary mode, the model has been trained
+    // complementary. A complementary model will work for standard classification,
+    // but a standard model will not work for complementary classification.
+    if (complementary) {
+ Preconditions.checkArgument((model.isComplemtary()),
+ "Complementary mode in model is different from test mode");
+ }
+
+ AbstractNaiveBayesClassifier classifier;
+ if (complementary) {
+ classifier = new ComplementaryNaiveBayesClassifier(model);
+ } else {
+ classifier = new StandardNaiveBayesClassifier(model);
+ }
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, getConf(), new Path(getOutputPath(), "part-r-00000"),
+ Text.class, VectorWritable.class);
+
+ try {
+ SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+ new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+ // loop through the part-r-* files in getInputPath() and get classification scores for all entries
+ for (Pair<Text, VectorWritable> pair : dirIterable) {
+ writer.append(new Text(SLASH.split(pair.getFirst().toString())[1]),
+ new VectorWritable(classifier.classifyFull(pair.getSecond().get())));
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ private boolean runMapReduce() throws IOException,
+ InterruptedException, ClassNotFoundException {
+ Path model = new Path(getOption("model"));
+ HadoopUtil.cacheFiles(model, getConf());
+ //the output key is the expected value, the output value are the scores for all the labels
+ Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
+ Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+
+ boolean complementary = hasOption("testComplementary");
+ testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
+ return testJob.waitForCompletion(true);
+ }
+
+ private static void analyzeResults(Map<Integer, String> labelMap,
+ SequenceFileDirIterable<Text, VectorWritable> dirIterable,
+ ResultAnalyzer analyzer) {
+ for (Pair<Text, VectorWritable> pair : dirIterable) {
+ int bestIdx = Integer.MIN_VALUE;
+      double bestScore = Double.NEGATIVE_INFINITY;
+ for (Vector.Element element : pair.getSecond().get().all()) {
+ if (element.get() > bestScore) {
+ bestScore = element.get();
+ bestIdx = element.index();
+ }
+ }
+ if (bestIdx != Integer.MIN_VALUE) {
+ ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
+ analyzer.addInstance(pair.getFirst().toString(), classifierResult);
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
new file mode 100644
index 0000000..2b8ee1e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.math.Vector;
+
+public class ComplementaryThetaTrainer {
+
+ private final Vector weightsPerFeature;
+ private final Vector weightsPerLabel;
+ private final Vector perLabelThetaNormalizer;
+ private final double alphaI;
+ private final double totalWeightSum;
+ private final double numFeatures;
+
+ public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+ Preconditions.checkNotNull(weightsPerFeature);
+ Preconditions.checkNotNull(weightsPerLabel);
+ this.weightsPerFeature = weightsPerFeature;
+ this.weightsPerLabel = weightsPerLabel;
+ this.alphaI = alphaI;
+ perLabelThetaNormalizer = weightsPerLabel.like();
+ totalWeightSum = weightsPerLabel.zSum();
+ numFeatures = weightsPerFeature.getNumNondefaultElements();
+ }
+
+ public void train(int label, Vector perLabelWeight) {
+ double labelWeight = labelWeight(label);
+ // sum weights for each label including those with zero word counts
+ for(int i = 0; i < perLabelWeight.size(); i++){
+ Vector.Element perLabelWeightElement = perLabelWeight.getElement(i);
+ updatePerLabelThetaNormalizer(label,
+ ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()),
+ perLabelWeightElement.get(), totalWeightSum(), labelWeight, alphaI(), numFeatures()));
+ }
+ }
+
+ protected double alphaI() {
+ return alphaI;
+ }
+
+ protected double numFeatures() {
+ return numFeatures;
+ }
+
+ protected double labelWeight(int label) {
+ return weightsPerLabel.get(label);
+ }
+
+ protected double totalWeightSum() {
+ return totalWeightSum;
+ }
+
+ protected double featureWeight(int feature) {
+ return weightsPerFeature.get(feature);
+ }
+
+ // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+ protected void updatePerLabelThetaNormalizer(int label, double weight) {
+ perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + Math.abs(weight));
+ }
+
+ public Vector retrievePerLabelThetaNormalizer() {
+ return perLabelThetaNormalizer.clone();
+ }
+}
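
For every label c, train() above accumulates the normalizer

    normalizer(c) = sum over all features f of |weight(c, f)|

where weight(c, f) is the complementary weight computed by
ComplementaryNaiveBayesClassifier.computeWeight(). This is the per-label magnitude
correction described in Section 3.2 of the Rennie et al. paper cited in that class.
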
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
new file mode 100644
index 0000000..40ca2e9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ public enum Counter { SKIPPED_INSTANCES }
+
+ private OpenObjectIntHashMap<String> labelIndex;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
+ }
+
+ @Override
+ protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+ String label = SLASH.split(labelText.toString())[1];
+ if (labelIndex.containsKey(label)) {
+ ctx.write(new IntWritable(labelIndex.get(label)), instance);
+ } else {
+ ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
new file mode 100644
index 0000000..ff2ea40
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+ public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
+ static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";
+
+ private ComplementaryThetaTrainer trainer;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ Configuration conf = ctx.getConfiguration();
+
+ float alphaI = conf.getFloat(ALPHA_I, 1.0f);
+ Map<String, Vector> scores = BayesUtils.readScoresFromCache(conf);
+
+ trainer = new ComplementaryThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
+ scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI);
+ }
+
+ @Override
+ protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+ trainer.train(key.get(), value.get());
+ }
+
+ @Override
+ protected void cleanup(Context ctx) throws IOException, InterruptedException {
+ ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER),
+ new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
+ super.cleanup(ctx);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
new file mode 100644
index 0000000..ac1c4c9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.mapreduce.VectorSumReducer;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Splitter;
+
+/** Trains a Naive Bayes Classifier (parameters for both Naive Bayes and Complementary Naive Bayes) */
+public final class TrainNaiveBayesJob extends AbstractJob {
+ private static final String TRAIN_COMPLEMENTARY = "trainComplementary";
+ private static final String ALPHA_I = "alphaI";
+ private static final String LABEL_INDEX = "labelIndex";
+ private static final String EXTRACT_LABELS = "extractLabels";
+ private static final String LABELS = "labels";
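+ // keys under which the summed weight vectors and the theta normalizers are written by the training jobs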
+ public static final String WEIGHTS_PER_FEATURE = "__SPF";
+ public static final String WEIGHTS_PER_LABEL = "__SPL";
+ public static final String LABEL_THETA_NORMALIZER = "_LTN";
+
+ public static final String SUMMED_OBSERVATIONS = "summedObservations";
+ public static final String WEIGHTS = "weights";
+ public static final String THETAS = "thetas";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(LABELS, "l", "comma-separated list of labels to include in training", false);
+
+ addOption(buildOption(EXTRACT_LABELS, "el", "Extract the labels from the input", false, false, ""));
+ addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f));
+ addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false)));
+ addOption(LABEL_INDEX, "li", "The path to store the label index in", false);
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ Map<String, List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), getOutputPath());
+ HadoopUtil.delete(getConf(), getTempPath());
+ }
+ Path labPath;
+ String labPathStr = getOption(LABEL_INDEX);
+ if (labPathStr != null) {
+ labPath = new Path(labPathStr);
+ } else {
+ labPath = getTempPath(LABEL_INDEX);
+ }
+ long labelSize = createLabelIndex(labPath);
+ float alphaI = Float.parseFloat(getOption(ALPHA_I));
+ boolean trainComplementary = hasOption(TRAIN_COMPLEMENTARY);
+
+ HadoopUtil.setSerializations(getConf());
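+ // make the label index available to the mappers via the distributed cache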
+ HadoopUtil.cacheFiles(labPath, getConf());
+
+ // Add up all the vectors with the same labels, while mapping the labels into our index
+ Job indexInstances = prepareJob(getInputPath(),
+ getTempPath(SUMMED_OBSERVATIONS),
+ SequenceFileInputFormat.class,
+ IndexInstancesMapper.class,
+ IntWritable.class,
+ VectorWritable.class,
+ VectorSumReducer.class,
+ IntWritable.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class);
+ indexInstances.setCombinerClass(VectorSumReducer.class);
+ boolean succeeded = indexInstances.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ // Sum up all the weights from the previous step, per label and per feature
+ Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
+ getTempPath(WEIGHTS),
+ SequenceFileInputFormat.class,
+ WeightsMapper.class,
+ Text.class,
+ VectorWritable.class,
+ VectorSumReducer.class,
+ Text.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class);
+ weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize));
+ weightSummer.setCombinerClass(VectorSumReducer.class);
+ succeeded = weightSummer.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ // Put the per label and per feature vectors into the cache
+ HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());
+
+ if (trainComplementary){
+ // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector
+ // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+ Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
+ getTempPath(THETAS),
+ SequenceFileInputFormat.class,
+ ThetaMapper.class,
+ Text.class,
+ VectorWritable.class,
+ VectorSumReducer.class,
+ Text.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class);
+ thetaSummer.setCombinerClass(VectorSumReducer.class);
+ thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
+ thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
+ succeeded = thetaSummer.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ // Put the per label theta normalizers into the cache
+ HadoopUtil.cacheFiles(getTempPath(THETAS), getConf());
+
+ // Validate our model and then write it out to the official output
+ getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
+ getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, trainComplementary);
+ NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
+ naiveBayesModel.validate();
+ naiveBayesModel.serialize(getOutputPath(), getConf());
+
+ return 0;
+ }
+
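+ /**
+  * Writes the label index to labPath, either from the comma-separated "labels" option or by
+  * extracting labels from the keys of the input sequence files; returns the number of labels.
+  */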
+ private long createLabelIndex(Path labPath) throws IOException {
+ long labelSize = 0;
+ if (hasOption(LABELS)) {
+ Iterable<String> labels = Splitter.on(",").split(getOption(LABELS));
+ labelSize = BayesUtils.writeLabelIndex(getConf(), labels, labPath);
+ } else if (hasOption(EXTRACT_LABELS)) {
+ Iterable<Pair<Text,IntWritable>> iterable =
+ new SequenceFileDirIterable<Text, IntWritable>(getInputPath(),
+ PathType.LIST,
+ PathFilters.logsCRCFilter(),
+ getConf());
+ labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
+ }
+ return labelSize;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
new file mode 100644
index 0000000..5563057
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
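+/**
+ * Sums the per-label observation vectors into a single per-feature weight vector and a
+ * per-label weight vector, both emitted during cleanup.
+ */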
+public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+ static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels";
+
+ private Vector weightsPerFeature;
+ private Vector weightsPerLabel;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS));
+ Preconditions.checkArgument(numLabels > 0, "Wrong numLabels: " + numLabels + ". Must be > 0!");
+ weightsPerLabel = new DenseVector(numLabels);
+ }
+
+ @Override
+ protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+ Vector instance = value.get();
+ if (weightsPerFeature == null) {
+ weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
+ }
+
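+ // add this label's summed observation vector to the per-feature totals and record its total mass for the label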
+ int label = index.get();
+ weightsPerFeature.assign(instance, Functions.PLUS);
+ weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
+ }
+
+ @Override
+ protected void cleanup(Context ctx) throws IOException, InterruptedException {
+ if (weightsPerFeature != null) {
+ ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
+ ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
+ }
+ super.cleanup(ctx);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
new file mode 100644
index 0000000..942a101
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
@@ -0,0 +1,165 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataOutputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.Date;
+import java.util.List;
+import java.util.Scanner;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+/**
+ * A command-line tool for training an HMM via the Baum-Welch (EM) algorithm.
+ */
+public final class BaumWelchTrainer {
+
+ private BaumWelchTrainer() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ Option inputOption = DefaultOptionCreator.inputOption().create();
+
+ Option outputOption = DefaultOptionCreator.outputOption().create();
+
+ Option stateNumberOption = optionBuilder.withLongName("nrOfHiddenStates").
+ withDescription("Number of hidden states").
+ withShortName("nh").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Option observedStateNumberOption = optionBuilder.withLongName("nrOfObservedStates").
+ withDescription("Number of observed states").
+ withShortName("no").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Option epsilonOption = optionBuilder.withLongName("epsilon").
+ withDescription("Convergence threshold").
+ withShortName("e").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Option iterationsOption = optionBuilder.withLongName("max-iterations").
+ withDescription("Maximum iterations number").
+ withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Group optionGroup = new GroupBuilder().withOption(inputOption).
+ withOption(outputOption).withOption(stateNumberOption).withOption(observedStateNumberOption).
+ withOption(epsilonOption).withOption(iterationsOption).
+ withName("Options").create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(optionGroup);
+ CommandLine commandLine = parser.parse(args);
+
+ String input = (String) commandLine.getValue(inputOption);
+ String output = (String) commandLine.getValue(outputOption);
+
+ int nrOfHiddenStates = Integer.parseInt((String) commandLine.getValue(stateNumberOption));
+ int nrOfObservedStates = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption));
+
+ double epsilon = Double.parseDouble((String) commandLine.getValue(epsilonOption));
+ int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption));
+
+ //constructing a randomly initialized HMM
+ HmmModel model = new HmmModel(nrOfHiddenStates, nrOfObservedStates, new Date().getTime());
+ List<Integer> observations = Lists.newArrayList();
+
+ //reading whitespace-separated observed state IDs
+ try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
+ while (scanner.hasNextInt()) {
+ observations.add(scanner.nextInt());
+ }
+ }
+
+ int[] observationsArray = new int[observations.size()];
+ for (int i = 0; i < observations.size(); ++i) {
+ observationsArray[i] = observations.get(i);
+ }
+
+ //training
+ HmmModel trainedModel = HmmTrainer.trainBaumWelch(model,
+ observationsArray, epsilon, maxIterations, true);
+
+ //serializing trained model
+ DataOutputStream stream = new DataOutputStream(new FileOutputStream(output));
+ try {
+ LossyHmmSerializer.serialize(trainedModel, stream);
+ } finally {
+ Closeables.close(stream, false);
+ }
+
+ //printing the trained model
+ System.out.println("Initial probabilities: ");
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ }
+ System.out.println();
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(trainedModel.getInitialProbabilities().get(i) + " ");
+ }
+ System.out.println();
+
+ System.out.println("Transition matrix:");
+ System.out.print(" ");
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ }
+ System.out.println();
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ for (int j = 0; j < trainedModel.getNrOfHiddenStates(); ++j) {
+ System.out.print(trainedModel.getTransitionMatrix().get(i, j) + " ");
+ }
+ System.out.println();
+ }
+ System.out.println("Emission matrix: ");
+ System.out.print(" ");
+ for (int i = 0; i < trainedModel.getNrOfOutputStates(); ++i) {
+ System.out.print(i + " ");
+ }
+ System.out.println();
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ for (int j = 0; j < trainedModel.getNrOfOutputStates(); ++j) {
+ System.out.print(trainedModel.getEmissionMatrix().get(i, j) + " ");
+ }
+ System.out.println();
+ }
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(optionGroup);
+ }
+ }
+}