You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:08:05 UTC

[34/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy into mahout-hdfs and mahout-mr, closes apache/mahout#86

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
new file mode 100644
index 0000000..0f88a70
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
@@ -0,0 +1,332 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.Arrays;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+
+/** Train a {@link MultilayerPerceptron}. */
+public final class TrainMultilayerPerceptron {
+
+  private static final Logger log = LoggerFactory.getLogger(TrainMultilayerPerceptron.class);
+  
+  /**  The parameters used by MLP. */
+  static class Parameters {
+    double learningRate;
+    double momemtumWeight;
+    double regularizationWeight;
+
+    String inputFilePath;
+    boolean skipHeader;
+    Map<String, Integer> labelsIndex = Maps.newHashMap();
+
+    String modelFilePath;
+    boolean updateModel;
+    List<Integer> layerSizeList = Lists.newArrayList();
+    String squashingFunctionName;
+  }
+
+  /*
+  private double learningRate;
+  private double momemtumWeight;
+  private double regularizationWeight;
+
+  private String inputFilePath;
+  private boolean skipHeader;
+  private Map<String, Integer> labelsIndex = Maps.newHashMap();
+
+  private String modelFilePath;
+  private boolean updateModel;
+  private List<Integer> layerSizeList = Lists.newArrayList();
+  private String squashingFunctionName;*/
+
+  public static void main(String[] args) throws Exception {
+    Parameters parameters = new Parameters();
+    
+    if (parseArgs(args, parameters)) {
+      log.info("Validate model...");
+      // check whether the model already exists
+      Path modelPath = new Path(parameters.modelFilePath);
+      FileSystem modelFs = modelPath.getFileSystem(new Configuration());
+      MultilayerPerceptron mlp;
+
+      if (modelFs.exists(modelPath) && parameters.updateModel) {
+        // incrementally update existing model
+        log.info("Build model from existing model...");
+        mlp = new MultilayerPerceptron(parameters.modelFilePath);
+      } else {
+        if (modelFs.exists(modelPath)) {
+          modelFs.delete(modelPath, true); // delete the existing file
+        }
+        log.info("Build model from scratch...");
+        mlp = new MultilayerPerceptron();
+        for (int i = 0; i < parameters.layerSizeList.size(); ++i) {
+          if (i != parameters.layerSizeList.size() - 1) {
+            mlp.addLayer(parameters.layerSizeList.get(i), false, parameters.squashingFunctionName);
+          } else {
+            mlp.addLayer(parameters.layerSizeList.get(i), true, parameters.squashingFunctionName);
+          }
+          mlp.setCostFunction("Minus_Squared");
+          mlp.setLearningRate(parameters.learningRate)
+             .setMomentumWeight(parameters.momemtumWeight)
+             .setRegularizationWeight(parameters.regularizationWeight);
+        }
+        mlp.setModelPath(parameters.modelFilePath);
+      }
+
+      // set the parameters
+      mlp.setLearningRate(parameters.learningRate)
+         .setMomentumWeight(parameters.momemtumWeight)
+         .setRegularizationWeight(parameters.regularizationWeight);
+
+      // train by the training data
+      Path trainingDataPath = new Path(parameters.inputFilePath);
+      FileSystem dataFs = trainingDataPath.getFileSystem(new Configuration());
+
+      Preconditions.checkArgument(dataFs.exists(trainingDataPath), "Training dataset %s cannot be found!",
+                                  parameters.inputFilePath);
+
+      log.info("Read data and train model...");
+      BufferedReader reader = null;
+
+      try {
+        reader = new BufferedReader(new InputStreamReader(dataFs.open(trainingDataPath)));
+        String line;
+
+        // read training data line by line
+        if (parameters.skipHeader) {
+          reader.readLine();
+        }
+
+        int labelDimension = parameters.labelsIndex.size();
+        while ((line = reader.readLine()) != null) {
+          String[] token = line.split(",");
+          String label = token[token.length - 1];
+          int labelIndex = parameters.labelsIndex.get(label);
+
+          double[] instances = new double[token.length - 1 + labelDimension];
+          for (int i = 0; i < token.length - 1; ++i) {
+            instances[i] = Double.parseDouble(token[i]);
+          }
+          for (int i = 0; i < labelDimension; ++i) {
+            instances[token.length - 1 + i] = 0;
+          }
+          // set the corresponding dimension
+          instances[token.length - 1 + labelIndex] = 1;
+
+          Vector instance = new DenseVector(instances).viewPart(0, instances.length);
+          mlp.trainOnline(instance);
+        }
+
+        // write model back
+        log.info("Write trained model to {}", parameters.modelFilePath);
+        mlp.writeModelToFile();
+        mlp.close();
+      } finally {
+        Closeables.close(reader, true);
+      }
+    }
+  }
+
+  /**
+   * Parse the input arguments.
+   * 
+   * @param args The input arguments
+   * @param parameters The parameters parsed.
+   * @return Whether the input arguments are valid.
+   * @throws Exception
+   */
+  private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
+    // build the options
+    log.info("Validate and parse arguments...");
+    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+    GroupBuilder groupBuilder = new GroupBuilder();
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    // whether skip the first row of the input file
+    Option skipHeaderOption = optionBuilder.withLongName("skipHeader")
+        .withShortName("sh").create();
+
+    Group skipHeaderGroup = groupBuilder.withOption(skipHeaderOption).create();
+
+    Option inputOption = optionBuilder
+        .withLongName("input")
+        .withShortName("i")
+        .withRequired(true)
+        .withChildren(skipHeaderGroup)
+        .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+                .create()).withDescription("the file path of training dataset")
+        .create();
+
+    Option labelsOption = optionBuilder
+        .withLongName("labels")
+        .withShortName("labels")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
+        .withDescription("label names").create();
+
+    Option updateOption = optionBuilder
+        .withLongName("update")
+        .withShortName("u")
+        .withDescription("whether to incrementally update model if the model exists")
+        .create();
+
+    Group modelUpdateGroup = groupBuilder.withOption(updateOption).create();
+
+    Option modelOption = optionBuilder
+        .withLongName("model")
+        .withShortName("mo")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create())
+        .withDescription("the path to store the trained model")
+        .withChildren(modelUpdateGroup).create();
+
+    Option layerSizeOption = optionBuilder
+        .withLongName("layerSize")
+        .withShortName("ls")
+        .withRequired(true)
+        .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create())
+        .withDescription("the size of each layer").create();
+
+    Option squashingFunctionOption = optionBuilder
+        .withLongName("squashingFunction")
+        .withShortName("sf")
+        .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1)
+            .withDefault("Sigmoid").create())
+        .withDescription("the name of squashing function (currently only supports Sigmoid)")
+        .create();
+
+    Option learningRateOption = optionBuilder
+        .withLongName("learningRate")
+        .withShortName("l")
+        .withArgument(argumentBuilder.withName("learning rate").withMaximum(1)
+            .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_LEARNING_RATE).create())
+        .withDescription("learning rate").create();
+
+    Option momemtumOption = optionBuilder
+        .withLongName("momemtumWeight")
+        .withShortName("m")
+        .withArgument(argumentBuilder.withName("momemtum weight").withMaximum(1)
+            .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_MOMENTUM_WEIGHT).create())
+        .withDescription("momemtum weight").create();
+
+    Option regularizationOption = optionBuilder
+        .withLongName("regularizationWeight")
+        .withShortName("r")
+        .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1)
+            .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_REGULARIZATION_WEIGHT).create())
+        .withDescription("regularization weight").create();
+
+    // parse the input
+    Parser parser = new Parser();
+    Group normalOptions = groupBuilder.withOption(inputOption)
+        .withOption(skipHeaderOption).withOption(updateOption)
+        .withOption(labelsOption).withOption(modelOption)
+        .withOption(layerSizeOption).withOption(squashingFunctionOption)
+        .withOption(learningRateOption).withOption(momemtumOption)
+        .withOption(regularizationOption).create();
+
+    parser.setGroup(normalOptions);
+
+    CommandLine commandLine = parser.parseAndHelp(args);
+    if (commandLine == null) {
+      return false;
+    }
+
+    parameters.learningRate = getDouble(commandLine, learningRateOption);
+    parameters.momemtumWeight = getDouble(commandLine, momemtumOption);
+    parameters.regularizationWeight = getDouble(commandLine, regularizationOption);
+
+    parameters.inputFilePath = getString(commandLine, inputOption);
+    parameters.skipHeader = commandLine.hasOption(skipHeaderOption);
+
+    List<String> labelsList = getStringList(commandLine, labelsOption);
+    int currentIndex = 0;
+    for (String label : labelsList) {
+      parameters.labelsIndex.put(label, currentIndex++);
+    }
+
+    parameters.modelFilePath = getString(commandLine, modelOption);
+    parameters.updateModel = commandLine.hasOption(updateOption);
+
+    parameters.layerSizeList = getIntegerList(commandLine, layerSizeOption);
+
+    parameters.squashingFunctionName = getString(commandLine, squashingFunctionOption);
+
+    System.out.printf("Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, Learning rate: %f," +
+        " Momemtum weight: %f, Regularization Weight: %f\n", parameters.inputFilePath, parameters.modelFilePath, 
+        parameters.updateModel, Arrays.toString(parameters.layerSizeList.toArray()), 
+        parameters.squashingFunctionName, parameters.learningRate, parameters.momemtumWeight, 
+        parameters.regularizationWeight);
+
+    return true;
+  }
+
+  static Double getDouble(CommandLine commandLine, Option option) {
+    Object val = commandLine.getValue(option);
+    if (val != null) {
+      return Double.parseDouble(val.toString());
+    }
+    return null;
+  }
+
+  static String getString(CommandLine commandLine, Option option) {
+    Object val = commandLine.getValue(option);
+    if (val != null) {
+      return val.toString();
+    }
+    return null;
+  }
+
+  static List<Integer> getIntegerList(CommandLine commandLine, Option option) {
+    List<String> list = commandLine.getValues(option);
+    List<Integer> valList = Lists.newArrayList();
+    for (String str : list) {
+      valList.add(Integer.parseInt(str));
+    }
+    return valList;
+  }
+
+  static List<String> getStringList(CommandLine commandLine, Option option) {
+    return commandLine.getValues(option);
+  }
+
+}
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
new file mode 100644
index 0000000..f0794b3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Class implementing the Naive Bayes Classifier Algorithm. Note that this class
+ * supports {@link #classifyFull}, but not {@code classify} or
+ * {@code classifyScalar}. The reason that these two methods are not
+ * supported is because the scores computed by a NaiveBayesClassifier do not
+ * represent probabilities.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+
+  private final NaiveBayesModel model;
+  
+  protected AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+    this.model = model;
+  }
+
+  protected NaiveBayesModel getModel() {
+    return model;
+  }
+  
+  protected abstract double getScoreForLabelFeature(int label, int feature);
+
+  protected double getScoreForLabelInstance(int label, Vector instance) {
+    double result = 0.0;
+    for (Element e : instance.nonZeroes()) {
+      result += e.get() * getScoreForLabelFeature(label, e.index());
+    }
+    return result;
+  }
+  
+  @Override
+  public int numCategories() {
+    return model.numLabels();
+  }
+
+  @Override
+  public Vector classifyFull(Vector instance) {
+    return classifyFull(model.createScoringVector(), instance);
+  }
+  
+  @Override
+  public Vector classifyFull(Vector r, Vector instance) {
+    for (int label = 0; label < model.numLabels(); label++) {
+      r.setQuick(label, getScoreForLabelInstance(label, instance));
+    }
+    return r;
+  }
+
+  /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+  @Override
+  public double classifyScalar(Vector instance) {
+    throw new UnsupportedOperationException("Not supported in Naive Bayes");
+  }
+  
+  /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+  @Override
+  public Vector classify(Vector instance) {
+    throw new UnsupportedOperationException("probabilites not supported in Naive Bayes");
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
new file mode 100644
index 0000000..1e5171c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.io.Closeables;
+
+/** Static helpers for reading and writing Naive Bayes models and label indexes. */
+public final class BayesUtils {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  private BayesUtils() {}
+
+  /**
+   * Assembles a {@link NaiveBayesModel} from the training job's output under {@code base}.
+   *
+   * <p>Reads the per-feature and per-label weight vectors from {@code WEIGHTS}, the per-label rows of
+   * the weight matrix from {@code SUMMED_OBSERVATIONS} and, for a complementary model only, the
+   * per-label theta normalizer from {@code THETAS}.
+   *
+   * @param base directory containing the training job's output
+   * @param conf configuration; {@code ThetaMapper.ALPHA_I} (default 1.0) and
+   *     {@code NaiveBayesModel.COMPLEMENTARY_MODEL} (default true) are read from it
+   * @return the assembled model
+   * @throws NullPointerException if a required vector is missing from the training output
+   */
+  public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {
+
+    float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+    boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true);
+
+    // read feature sums and label sums
+    Vector scoresPerLabel = null;
+    Vector scoresPerFeature = null;
+    for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      String key = record.getFirst().toString();
+      VectorWritable value = record.getSecond();
+      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+        scoresPerFeature = value.get();
+      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+        scoresPerLabel = value.get();
+      }
+    }
+
+    Preconditions.checkNotNull(scoresPerFeature, "%s not found under %s",
+        TrainNaiveBayesJob.WEIGHTS_PER_FEATURE, base);
+    Preconditions.checkNotNull(scoresPerLabel, "%s not found under %s",
+        TrainNaiveBayesJob.WEIGHTS_PER_LABEL, base);
+
+    Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
+    for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+    }
+
+    // perLabelThetaNormalizer is only used by the complementary model, we do not instantiate it for the standard model
+    Vector perLabelThetaNormalizer = null;
+    if (isComplementary) {
+      // start from null (not an all-zero vector) so a missing normalizer is reported instead of
+      // silently producing divisions by zero when scoring
+      for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+          new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+        if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+          perLabelThetaNormalizer = entry.getSecond().get();
+        }
+      }
+      Preconditions.checkNotNull(perLabelThetaNormalizer, "%s not found under %s",
+          TrainNaiveBayesJob.LABEL_THETA_NORMALIZER, base);
+    }
+
+    return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
+        alphaI, isComplementary);
+  }
+
+  /**
+   * Writes the labels into a sequence file at {@code indexPath}, mapping each label (Text) to its
+   * position in iteration order (IntWritable).
+   *
+   * @return the number of labels written
+   */
+  public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+    throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+    int i = 0;
+    try {
+      for (String label : labels) {
+        writer.append(new Text(label), new IntWritable(i++));
+      }
+    } finally {
+      Closeables.close(writer, false);
+    }
+    return i;
+  }
+
+  /**
+   * Extracts the distinct labels from the keys of {@code labels} and writes them, numbered in order
+   * of first appearance, into a sequence file at {@code indexPath}. The label is taken to be the
+   * second {@code '/'}-separated segment of each key (presumably keys look like "/label/docId").
+   *
+   * @return the number of distinct labels written
+   * @throws IllegalArgumentException if a key has no second {@code '/'}-separated segment
+   */
+  public static int writeLabelIndex(Configuration conf, Path indexPath,
+                                    Iterable<Pair<Text,IntWritable>> labels) throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+    Collection<String> seen = Sets.newHashSet();
+    int i = 0;
+    try {
+      // wildcard element type: only the key's string form is used
+      for (Pair<?, ?> label : labels) {
+        String key = label.getFirst().toString();
+        String[] segments = SLASH.split(key);
+        Preconditions.checkArgument(segments.length > 1, "Cannot extract a label from key '%s'", key);
+        String theLabel = segments[1];
+        if (seen.add(theLabel)) {
+          writer.append(new Text(theLabel), new IntWritable(i++));
+        }
+      }
+    } finally {
+      Closeables.close(writer, false);
+    }
+    return i;
+  }
+
+  /** Reads a label index written by {@code writeLabelIndex} into a map from label id to label name. */
+  public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) {
+    Map<Integer, String> labelMap = new HashMap<>();
+    for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) {
+      labelMap.put(pair.getSecond().get(), pair.getFirst().toString());
+    }
+    return labelMap;
+  }
+
+  /** Reads the (key, int) index stored in the single file of the distributed cache. */
+  public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+    OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>();
+    for (Pair<Writable,IntWritable> entry
+        : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) {
+      index.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return index;
+  }
+
+  /** Reads all named vectors from the part files under the single cached path, keyed by name. */
+  public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException {
+    Map<String,Vector> sumVectors = Maps.newHashMap();
+    for (Pair<Text,VectorWritable> entry
+        : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf),
+          PathType.LIST, PathFilters.partFilter(), conf)) {
+      sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return sumVectors;
+  }
+
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
new file mode 100644
index 0000000..18bd3d6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/**
+ * Complementary Naive Bayes classifier (Rennie et al., ICML 2003). The weight of a feature for a
+ * label is computed from the counts of all <em>other</em> labels, and is then divided by that
+ * label's theta normalizer.
+ */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    NaiveBayesModel nb = getModel();
+    double featureTotal = nb.featureWeight(feature);
+    double featureInLabel = nb.weight(label, feature);
+    double complementWeight = computeWeight(featureTotal, featureInLabel, nb.totalWeightSum(),
+        nb.labelWeight(label), nb.alphaI(), nb.numFeatures());
+    // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+    return complementWeight / nb.thetaNormalizer(label);
+  }
+
+  /**
+   * Smoothed negative log of the complement frequency of a feature, i.e. of its frequency in all
+   * labels except the given one.
+   * See http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias.
+   *
+   * @param featureWeight weight of the feature across all labels
+   * @param featureLabelWeight weight of the feature within the label
+   * @param totalWeight total weight across all labels and features
+   * @param labelWeight total weight of the label
+   * @param alphaI additive smoothing parameter
+   * @param numFeatures number of features, used to scale the smoothing in the denominator
+   * @return {@code -log((featureWeight - featureLabelWeight + alphaI)
+   *     / (totalWeight - labelWeight + alphaI * numFeatures))}
+   */
+  public static double computeWeight(double featureWeight, double featureLabelWeight,
+      double totalWeight, double labelWeight, double alphaI, double numFeatures) {
+    double complementFeatureCount = featureWeight - featureLabelWeight + alphaI;
+    double complementTotal = totalWeight - labelWeight + alphaI * numFeatures;
+    double ratio = complementFeatureCount / complementTotal;
+    return -Math.log(ratio);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
new file mode 100644
index 0000000..f180e8b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+
+/** NaiveBayesModel holds the weight matrix, the feature and label sums and the weight normalizer vectors.*/
+public class NaiveBayesModel {
+
+  /** Total weight of each label, indexed by label. */
+  private final Vector weightsPerLabel;
+  /** Per-label theta normalizers; {@code null} unless the model is complementary. */
+  private final Vector perlabelThetaNormalizer;
+  /** Per-feature weights, indexed by feature. */
+  private final Vector weightsPerFeature;
+  /** Weight matrix: one row per label, one column per feature. */
+  private final Matrix weightsPerLabelAndFeature;
+  private final float alphaI;
+  /** Number of non-default entries in {@link #weightsPerFeature}. */
+  private final double numFeatures;
+  /** Sum of all entries of {@link #weightsPerLabel}. */
+  private final double totalWeightSum;
+  private final boolean isComplementary;
+
+  public static final String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL";
+
+  /**
+   * Creates a model from pre-computed weights.
+   *
+   * @param weightMatrix label-by-feature weight matrix
+   * @param weightsPerFeature per-feature weights
+   * @param weightsPerLabel per-label weights
+   * @param thetaNormalizer per-label theta normalizers, or {@code null} for a non-complementary model
+   * @param alphaI smoothing parameter
+   * @param isComplementary whether this is a complementary Naive Bayes model
+   */
+  public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer,
+                         float alphaI, boolean isComplementary) {
+    // inputs, stored as given
+    this.weightsPerLabelAndFeature = weightMatrix;
+    this.weightsPerFeature = weightsPerFeature;
+    this.weightsPerLabel = weightsPerLabel;
+    this.perlabelThetaNormalizer = thetaNormalizer;
+    this.alphaI = alphaI;
+    this.isComplementary = isComplementary;
+    // values derived once from the inputs
+    this.numFeatures = weightsPerFeature.getNumNondefaultElements();
+    this.totalWeightSum = weightsPerLabel.zSum();
+  }
+
+  /** @return the total weight of {@code label} */
+  public double labelWeight(int label) {
+    return weightsPerLabel.getQuick(label);
+  }
+
+  /**
+   * @return the theta normalizer of {@code label}
+   * @throws IllegalStateException if the model holds no theta normalizers, e.g. a standard
+   *     (non-complementary) model
+   */
+  public double thetaNormalizer(int label) {
+    Preconditions.checkState(perlabelThetaNormalizer != null,
+        "theta normalizers are only available for complementary models");
+    return perlabelThetaNormalizer.get(label);
+  }
+
+  /** @return the per-feature weight of {@code feature} */
+  public double featureWeight(int feature) {
+    return weightsPerFeature.getQuick(feature);
+  }
+
+  /** @return the weight of {@code feature} within {@code label} */
+  public double weight(int label, int feature) {
+    return weightsPerLabelAndFeature.getQuick(label, feature);
+  }
+
+  /** @return the smoothing parameter */
+  public float alphaI() {
+    return alphaI;
+  }
+
+  /** @return the number of features with a non-default per-feature weight */
+  public double numFeatures() {
+    return numFeatures;
+  }
+
+  /** @return the sum of all per-label weights */
+  public double totalWeightSum() {
+    return totalWeightSum;
+  }
+
+  /** @return the number of labels (size of the per-label weight vector) */
+  public int numLabels() {
+    return weightsPerLabel.size();
+  }
+
+  /** @return a new vector of the same type and size as the per-label weights, to hold one score per label */
+  public Vector createScoringVector() {
+    return weightsPerLabel.like();
+  }
+
+  /** @return whether this is a complementary model (method name misspelled; kept for compatibility) */
+  public boolean isComplemtary(){
+      return isComplementary;
+  }
+  
+  public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
+    FileSystem fs = output.getFileSystem(conf);
+
+    Vector weightsPerLabel = null;
+    Vector perLabelThetaNormalizer = null;
+    Vector weightsPerFeature = null;
+    Matrix weightsPerLabelAndFeature;
+    float alphaI;
+    boolean isComplementary;
+
+    FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
+    try {
+      alphaI = in.readFloat();
+      isComplementary = in.readBoolean();
+      weightsPerFeature = VectorWritable.readVector(in);
+      weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
+      if (isComplementary){
+        perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
+      }
+      weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size());
+      for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
+        weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
+      }
+    } finally {
+      Closeables.close(in, true);
+    }
+    NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
+        perLabelThetaNormalizer, alphaI, isComplementary);
+    model.validate();
+    return model;
+  }
+
+  public void serialize(Path output, Configuration conf) throws IOException {
+    FileSystem fs = output.getFileSystem(conf);
+    FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
+    try {
+      out.writeFloat(alphaI);
+      out.writeBoolean(isComplementary);
+      VectorWritable.writeVector(out, weightsPerFeature);
+      VectorWritable.writeVector(out, weightsPerLabel); 
+      if (isComplementary){
+        VectorWritable.writeVector(out, perlabelThetaNormalizer);
+      }
+      for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
+        VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
+      }
+    } finally {
+      Closeables.close(out, false);
+    }
+  }
+  
+  public void validate() {
+    Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!");
+    Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
+    Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
+    Preconditions.checkNotNull(weightsPerLabel, "the number of labels has to be defined!");
+    Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
+        "the number of labels has to be greater than 0!");
+    Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined");
+    Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
+        "the feature sums have to be greater than 0!");
+    if (isComplementary){
+        Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
+        Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+            "the number of theta normalizers has to be greater than 0!");    
+        Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue()) 
+                == Math.signum(perlabelThetaNormalizer.maxValue()), 
+           "Theta normalizers do not all have the same sign");            
+        Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements() 
+                == perlabelThetaNormalizer.size(), 
+           "Theta normalizers can not have zero value.");
+    }
+    
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
new file mode 100644
index 0000000..e4ce8aa
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/**
+ * Implementation of the Naive Bayes Classifier Algorithm.
+ * <p>
+ * Each label/feature pair is scored with an additively smoothed log ratio; unlike the complementary variant,
+ * no per-label weight normalization is applied.
+ */
+public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public StandardNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  /** Scores {@code feature} for {@code label} using the weights stored in the model. */
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    NaiveBayesModel nb = getModel();
+    double featureLabelWeight = nb.weight(label, feature);
+    double labelWeight = nb.labelWeight(label);
+    return computeWeight(featureLabelWeight, labelWeight, nb.alphaI(), nb.numFeatures());
+  }
+
+  /**
+   * Returns {@code log((featureLabelWeight + alphaI) / (labelWeight + alphaI * numFeatures))}.
+   *
+   * @param featureLabelWeight weight of the feature within the label
+   * @param labelWeight total weight of the label
+   * @param alphaI smoothing parameter added per feature
+   * @param numFeatures number of features, scaling the smoothing mass in the denominator
+   * @return the smoothed log weight
+   */
+  public static double computeWeight(double featureLabelWeight, double labelWeight, double alphaI, double numFeatures) {
+    return Math.log((featureLabelWeight + alphaI) / (labelWeight + alphaI * numFeatures));
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
new file mode 100644
index 0000000..37a3b71
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+/**
+ * Runs each test instance through the trained model.
+ * <p/>
+ * The output key is the expected label, taken from the input key, and the output value is the vector of
+ * per-label scores returned by {@code classifyFull}.
+ */
+public class BayesTestMapper extends Mapper<Text, VectorWritable, Text, VectorWritable> {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  private AbstractNaiveBayesClassifier classifier;
+
+  /**
+   * Loads the model from the distributed cache and picks the standard or complementary classifier based
+   * on the {@link TestNaiveBayesDriver#COMPLEMENTARY} configuration flag.
+   *
+   * @throws IllegalArgumentException if complementary testing is requested for a non-complementary model
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    Path modelPath = HadoopUtil.getSingleCachedFile(conf);
+    NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
+    boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));
+
+    if (isComplementary) {
+      // A complementary model also works for standard classification, but a standard model carries no
+      // theta normalizers, so it cannot be used for complementary classification.
+      Preconditions.checkArgument(model.isComplemtary(),
+          "Complementary mode in model is different than test mode");
+      classifier = new ComplementaryNaiveBayesClassifier(model);
+    } else {
+      classifier = new StandardNaiveBayesClassifier(model);
+    }
+  }
+
+  /**
+   * Classifies one instance and emits (expected label, score vector).
+   */
+  @Override
+  protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
+    Vector result = classifier.classifyFull(value.get());
+    // The expected label is the second '/'-separated component of the key
+    // (presumably keys look like "/label/docId" -- confirm against the input producer).
+    context.write(new Text(SLASH.split(key.toString())[1]), new VectorWritable(result));
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
new file mode 100644
index 0000000..8fd422f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
@@ -0,0 +1,179 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tests the (Complementary) Naive Bayes model that was built during training by classifying every
+ * instance of the test set and comparing the predicted label with the expected one.
+ * <p>
+ * Classification runs either sequentially in this process (option "runSequential") or as a MapReduce job
+ * using {@link BayesTestMapper}. In both cases the per-instance score vectors are written to the output
+ * path and then summarised into a confusion matrix by a {@link ResultAnalyzer}.
+ */
+public class TestNaiveBayesDriver extends AbstractJob {
+
+  private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);
+
+  /** Configuration key under which the complementary flag ("true"/"false") is passed to {@link BayesTestMapper}. */
+  public static final String COMPLEMENTARY = "class";
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
+  }
+
+  /**
+   * Parses the options, classifies the test set and logs the analyzer's summary.
+   *
+   * @return 0 on success, -1 if argument parsing or the MapReduce job fails
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+    addInputOption();
+    addOutputOption();
+    // Register the overwrite option exactly once (it was previously added twice).
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    addOption("model", "m", "The path to the model built during training", true);
+    addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
+    addOption(buildOption("runSequential", "seq", "run sequential?", false, false, String.valueOf(false)));
+    addOption("labelIndex", "l", "The path to the location of the label index", true);
+    Map<String, List<String>> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+      HadoopUtil.delete(getConf(), getOutputPath());
+    }
+
+    if (hasOption("runSequential")) {
+      runSequential();
+    } else if (!runMapReduce()) {
+      return -1;
+    }
+
+    //load the labels
+    Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));
+
+    //loop over the results and create the confusion matrix
+    SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+        new SequenceFileDirIterable<>(getOutputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+    ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
+    analyzeResults(labelMap, dirIterable, analyzer);
+
+    log.info("{} Results: {}", hasOption("testComplementary") ? "Complementary" : "Standard NB", analyzer);
+    return 0;
+  }
+
+  /**
+   * Classifies every entry of the input part files in this process and writes (expected label, scores)
+   * pairs to a single "part-r-00000" file under the output path.
+   */
+  private void runSequential() throws IOException {
+    boolean complementary = hasOption("testComplementary");
+    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+
+    AbstractNaiveBayesClassifier classifier;
+    if (complementary) {
+      // A complementary model also works for standard classification, but a standard model cannot be
+      // used for complementary classification.
+      Preconditions.checkArgument(model.isComplemtary(),
+          "Complementary mode in model is different from test mode");
+      classifier = new ComplementaryNaiveBayesClassifier(model);
+    } else {
+      classifier = new StandardNaiveBayesClassifier(model);
+    }
+
+    Path outputFile = new Path(getOutputPath(), "part-r-00000");
+    // Resolve the FileSystem from the output path, not the default FS, so qualified output paths on a
+    // different filesystem do not fail with a "Wrong FS" error.
+    FileSystem fs = outputFile.getFileSystem(getConf());
+    SequenceFile.Writer writer = SequenceFile.createWriter(fs, getConf(), outputFile,
+        Text.class, VectorWritable.class);
+
+    try {
+      SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+          new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+      // loop through the part-r-* files in getInputPath() and get classification scores for all entries
+      for (Pair<Text, VectorWritable> pair : dirIterable) {
+        writer.append(new Text(SLASH.split(pair.getFirst().toString())[1]),
+            new VectorWritable(classifier.classifyFull(pair.getSecond().get())));
+      }
+    } finally {
+      Closeables.close(writer, false);
+    }
+  }
+
+  /**
+   * Runs the classification as a MapReduce job with the model shipped through the distributed cache.
+   *
+   * @return whether the job completed successfully
+   */
+  private boolean runMapReduce() throws IOException,
+      InterruptedException, ClassNotFoundException {
+    Path model = new Path(getOption("model"));
+    HadoopUtil.cacheFiles(model, getConf());
+    //the output key is the expected value, the output value are the scores for all the labels
+    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
+        Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+
+    boolean complementary = hasOption("testComplementary");
+    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
+    return testJob.waitForCompletion(true);
+  }
+
+  /**
+   * For each (expected label, scores) pair, picks the highest-scoring label index and records the result.
+   * Entries whose score vector yields no candidate are skipped.
+   */
+  private static void analyzeResults(Map<Integer, String> labelMap,
+                                     SequenceFileDirIterable<Text, VectorWritable> dirIterable,
+                                     ResultAnalyzer analyzer) {
+    for (Pair<Text, VectorWritable> pair : dirIterable) {
+      int bestIdx = Integer.MIN_VALUE;
+      // Seed below every finite score. The former Long.MIN_VALUE seed was an arbitrary finite cut-off:
+      // an instance whose scores all fell below it was silently dropped.
+      double bestScore = Double.NEGATIVE_INFINITY;
+      for (Vector.Element element : pair.getSecond().get().all()) {
+        if (element.get() > bestScore) {
+          bestScore = element.get();
+          bestIdx = element.index();
+        }
+      }
+      if (bestIdx != Integer.MIN_VALUE) {
+        ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
+        analyzer.addInstance(pair.getFirst().toString(), classifierResult);
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
new file mode 100644
index 0000000..2b8ee1e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Accumulates per-label theta normalizers for complementary Naive Bayes: for each trained label, the sum of
+ * the absolute complementary weights over every index of that label's weight vector.
+ */
+public class ComplementaryThetaTrainer {
+
+  private final Vector weightsPerFeature;
+  private final Vector weightsPerLabel;
+  private final Vector perLabelThetaNormalizer;
+  private final double alphaI;
+  private final double totalWeightSum;
+  private final double numFeatures;
+
+  /**
+   * @param weightsPerFeature feature sums; must not be null
+   * @param weightsPerLabel label sums; must not be null
+   * @param alphaI smoothing parameter passed through to the weight computation
+   */
+  public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+    this.weightsPerFeature = Preconditions.checkNotNull(weightsPerFeature);
+    this.weightsPerLabel = Preconditions.checkNotNull(weightsPerLabel);
+    this.alphaI = alphaI;
+    this.numFeatures = weightsPerFeature.getNumNondefaultElements();
+    this.totalWeightSum = weightsPerLabel.zSum();
+    this.perLabelThetaNormalizer = weightsPerLabel.like();
+  }
+
+  /**
+   * Adds the absolute complementary weight of every index of {@code perLabelWeight} to the normalizer of
+   * {@code label}.
+   */
+  public void train(int label, Vector perLabelWeight) {
+    final double weightOfLabel = labelWeight(label);
+    // Every index is visited, not just non-zero ones, so features with zero counts still contribute.
+    for (int index = 0; index < perLabelWeight.size(); index++) {
+      Vector.Element element = perLabelWeight.getElement(index);
+      double complementaryWeight = ComplementaryNaiveBayesClassifier.computeWeight(
+          featureWeight(element.index()), element.get(), totalWeightSum(), weightOfLabel, alphaI(), numFeatures());
+      updatePerLabelThetaNormalizer(label, complementaryWeight);
+    }
+  }
+
+  protected double alphaI() {
+    return alphaI;
+  }
+
+  protected double numFeatures() {
+    return numFeatures;
+  }
+
+  protected double labelWeight(int label) {
+    return weightsPerLabel.get(label);
+  }
+
+  protected double totalWeightSum() {
+    return totalWeightSum;
+  }
+
+  protected double featureWeight(int feature) {
+    return weightsPerFeature.get(feature);
+  }
+
+  // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+  protected void updatePerLabelThetaNormalizer(int label, double weight) {
+    double accumulated = perLabelThetaNormalizer.get(label);
+    perLabelThetaNormalizer.set(label, accumulated + Math.abs(weight));
+  }
+
+  /** Returns a copy of the normalizers accumulated so far. */
+  public Vector retrievePerLabelThetaNormalizer() {
+    return perLabelThetaNormalizer.clone();
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
new file mode 100644
index 0000000..40ca2e9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  public enum Counter { SKIPPED_INSTANCES }
+
+  private OpenObjectIntHashMap<String> labelIndex;
+
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
+    labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
+  }
+
+  @Override
+  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+    String label = SLASH.split(labelText.toString())[1];
+    if (labelIndex.containsKey(label)) {
+      ctx.write(new IntWritable(labelIndex.get(label)), instance);
+    } else {
+      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
new file mode 100644
index 0000000..ff2ea40
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Feeds every (label index, weight vector) pair into a {@link ComplementaryThetaTrainer} and, once the
+ * mapper finishes, emits the accumulated per-label theta normalizers under
+ * {@code TrainNaiveBayesJob.LABEL_THETA_NORMALIZER}.
+ */
+public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
+  // Not read in this class; presumably consumed by the training job -- verify there.
+  static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";
+
+  private ComplementaryThetaTrainer trainer;
+
+  /** Builds the trainer from the cached feature/label sums and the configured alphaI (default 1.0). */
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
+    Configuration conf = ctx.getConfiguration();
+    Map<String, Vector> scores = BayesUtils.readScoresFromCache(conf);
+    Vector weightsPerFeature = scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE);
+    Vector weightsPerLabel = scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL);
+    trainer = new ComplementaryThetaTrainer(weightsPerFeature, weightsPerLabel, conf.getFloat(ALPHA_I, 1.0f));
+  }
+
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+    trainer.train(key.get(), value.get());
+  }
+
+  /** Writes this mapper's accumulated normalizers once, after all records have been processed. */
+  @Override
+  protected void cleanup(Context ctx) throws IOException, InterruptedException {
+    Text outputKey = new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER);
+    ctx.write(outputKey, new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
+    super.cleanup(ctx);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
new file mode 100644
index 0000000..ac1c4c9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.mapreduce.VectorSumReducer;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Splitter;
+
+/**
+ * Trains a Naive Bayes Classifier (parameters for both Naive Bayes and Complementary Naive Bayes).
+ *
+ * <p>Training is a pipeline of MapReduce jobs whose intermediate results are written below the temp path:</p>
+ * <ol>
+ *   <li>{@link IndexInstancesMapper}: adds up all vectors with the same label while mapping the labels into
+ *       the label index ({@code summedObservations}).</li>
+ *   <li>{@link WeightsMapper}: sums the weights per feature and per label ({@code weights}).</li>
+ *   <li>{@link ThetaMapper}, only with {@code --trainComplementary}: computes the per-label theta
+ *       normalizers ({@code thetas}).</li>
+ * </ol>
+ * <p>The model is then read back from the temp path, validated and serialized to the output path.</p>
+ */
+public final class TrainNaiveBayesJob extends AbstractJob {
+  private static final String TRAIN_COMPLEMENTARY = "trainComplementary";
+  private static final String ALPHA_I = "alphaI";
+  private static final String LABEL_INDEX = "labelIndex";
+  private static final String EXTRACT_LABELS = "extractLabels";
+  private static final String LABELS = "labels";
+  /** Key of the per-feature weight vector emitted by {@link WeightsMapper}. */
+  public static final String WEIGHTS_PER_FEATURE = "__SPF";
+  /** Key of the per-label weight vector emitted by {@link WeightsMapper}. */
+  public static final String WEIGHTS_PER_LABEL = "__SPL";
+  /** Key of the per-label theta normalizer vector emitted by {@link ThetaMapper}. */
+  public static final String LABEL_THETA_NORMALIZER = "_LTN";
+
+  /** Temp sub-directory holding the per-label sums of the input vectors. */
+  public static final String SUMMED_OBSERVATIONS = "summedObservations";
+  /** Temp sub-directory holding the per-feature and per-label weight vectors. */
+  public static final String WEIGHTS = "weights";
+  /** Temp sub-directory holding the per-label theta normalizers (complementary training only). */
+  public static final String THETAS = "thetas";
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args);
+  }
+
+  /**
+   * Parses the command line and runs the training pipeline.
+   *
+   * @param args command-line arguments
+   * @return 0 on success, -1 if argument parsing or any of the MapReduce steps fails
+   * @throws IllegalArgumentException if neither {@code --labels} nor {@code --extractLabels} is given,
+   *         because then no label index can be built
+   * @throws Exception propagated from Hadoop / file system operations
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption(LABELS, "l", "comma-separated list of labels to include in training", false);
+
+    addOption(buildOption(EXTRACT_LABELS, "el", "Extract the labels from the input", false, false, ""));
+    addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f));
+    addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false)));
+    addOption(LABEL_INDEX, "li", "The path to store the label index in", false);
+    addOption(DefaultOptionCreator.overwriteOption().create());
+    Map<String, List<String>> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+    // Without one of these options createLabelIndex() writes no label index and reports 0 labels, so the
+    // pipeline would fail inside a MapReduce task at the latest when WeightsMapper rejects numLabels == 0.
+    // Fail fast instead, before any existing output is deleted or any job is submitted.
+    if (!hasOption(LABELS) && !hasOption(EXTRACT_LABELS)) {
+      throw new IllegalArgumentException("No labels given: specify either --" + LABELS + " or --"
+          + EXTRACT_LABELS);
+    }
+    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+      HadoopUtil.delete(getConf(), getOutputPath());
+      HadoopUtil.delete(getConf(), getTempPath());
+    }
+    String labPathStr = getOption(LABEL_INDEX);
+    Path labPath = labPathStr != null ? new Path(labPathStr) : getTempPath(LABEL_INDEX);
+    long labelSize = createLabelIndex(labPath);
+    float alphaI = Float.parseFloat(getOption(ALPHA_I));
+    boolean trainComplementary = hasOption(TRAIN_COMPLEMENTARY);
+
+    HadoopUtil.setSerializations(getConf());
+    HadoopUtil.cacheFiles(labPath, getConf());
+
+    // Add up all the vectors with the same labels, while mapping the labels into our index
+    Job indexInstances = prepareJob(getInputPath(),
+                                    getTempPath(SUMMED_OBSERVATIONS),
+                                    SequenceFileInputFormat.class,
+                                    IndexInstancesMapper.class,
+                                    IntWritable.class,
+                                    VectorWritable.class,
+                                    VectorSumReducer.class,
+                                    IntWritable.class,
+                                    VectorWritable.class,
+                                    SequenceFileOutputFormat.class);
+    indexInstances.setCombinerClass(VectorSumReducer.class);
+    boolean succeeded = indexInstances.waitForCompletion(true);
+    if (!succeeded) {
+      return -1;
+    }
+    // Sum up all the weights from the previous step, per label and per feature
+    Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
+                                  getTempPath(WEIGHTS),
+                                  SequenceFileInputFormat.class,
+                                  WeightsMapper.class,
+                                  Text.class,
+                                  VectorWritable.class,
+                                  VectorSumReducer.class,
+                                  Text.class,
+                                  VectorWritable.class,
+                                  SequenceFileOutputFormat.class);
+    weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize));
+    weightSummer.setCombinerClass(VectorSumReducer.class);
+    succeeded = weightSummer.waitForCompletion(true);
+    if (!succeeded) {
+      return -1;
+    }
+
+    // Put the per label and per feature vectors into the cache
+    HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());
+
+    if (trainComplementary) {
+      // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector
+      // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+      Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
+                                   getTempPath(THETAS),
+                                   SequenceFileInputFormat.class,
+                                   ThetaMapper.class,
+                                   Text.class,
+                                   VectorWritable.class,
+                                   VectorSumReducer.class,
+                                   Text.class,
+                                   VectorWritable.class,
+                                   SequenceFileOutputFormat.class);
+      thetaSummer.setCombinerClass(VectorSumReducer.class);
+      thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
+      thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
+      succeeded = thetaSummer.waitForCompletion(true);
+      if (!succeeded) {
+        return -1;
+      }
+      // Put the per label theta normalizers into the cache. The thetas directory is only produced by the
+      // step above, so for the standard model there is nothing to cache.
+      HadoopUtil.cacheFiles(getTempPath(THETAS), getConf());
+    }
+
+    // Validate our model and then write it out to the official output
+    getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
+    getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, trainComplementary);
+    NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
+    naiveBayesModel.validate();
+    naiveBayesModel.serialize(getOutputPath(), getConf());
+
+    return 0;
+  }
+
+  /**
+   * Writes the label index to {@code labPath}. The labels come either from the comma-separated
+   * {@code --labels} option, or, with {@code --extractLabels}, from the input sequence files (read as
+   * Text/IntWritable pairs).
+   *
+   * @param labPath where the label index is written
+   * @return the label count reported by {@code BayesUtils.writeLabelIndex}; 0 if neither option is present
+   * @throws IOException if reading the input or writing the index fails
+   */
+  private long createLabelIndex(Path labPath) throws IOException {
+    long labelSize = 0;
+    if (hasOption(LABELS)) {
+      Iterable<String> labels = Splitter.on(",").split(getOption(LABELS));
+      labelSize = BayesUtils.writeLabelIndex(getConf(), labels, labPath);
+    } else if (hasOption(EXTRACT_LABELS)) {
+      Iterable<Pair<Text,IntWritable>> iterable =
+          new SequenceFileDirIterable<Text, IntWritable>(getInputPath(),
+                                                         PathType.LIST,
+                                                         PathFilters.logsCRCFilter(),
+                                                         getConf());
+      labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
+    }
+    return labelSize;
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
new file mode 100644
index 0000000..5563057
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Aggregates instance vectors into a per-feature and a per-label weight vector.
+ *
+ * <p>Input keys are label indices and values are instance vectors. {@link #cleanup} emits two records, keyed
+ * by {@link TrainNaiveBayesJob#WEIGHTS_PER_FEATURE} and {@link TrainNaiveBayesJob#WEIGHTS_PER_LABEL};
+ * a mapper that received no input emits nothing.</p>
+ */
+public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  /** Configuration key holding the number of labels; must parse as an int greater than zero. */
+  static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels";
+
+  /** Element-wise sum of all instance vectors; allocated lazily from the first instance seen. */
+  private Vector weightsPerFeature;
+  /** Per label index, the sum of all elements of the instance vectors carrying that label. */
+  private Vector weightsPerLabel;
+
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
+    String rawLabelCount = ctx.getConfiguration().get(NUM_LABELS);
+    int labelCount = Integer.parseInt(rawLabelCount);
+    Preconditions.checkArgument(labelCount > 0, "Wrong numLabels: " + labelCount + ". Must be > 0!");
+    weightsPerLabel = new DenseVector(labelCount);
+  }
+
+  @Override
+  protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+    Vector instance = value.get();
+    if (weightsPerFeature == null) {
+      // First record of this task: size the sparse accumulator after the incoming instance.
+      weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
+    }
+    weightsPerFeature.assign(instance, Functions.PLUS);
+
+    int labelIdx = index.get();
+    double labelTotal = weightsPerLabel.get(labelIdx) + instance.zSum();
+    weightsPerLabel.set(labelIdx, labelTotal);
+  }
+
+  @Override
+  protected void cleanup(Context ctx) throws IOException, InterruptedException {
+    boolean sawInput = weightsPerFeature != null;
+    if (sawInput) {
+      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
+      ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
+    }
+    super.cleanup(ctx);
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
new file mode 100644
index 0000000..942a101
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
@@ -0,0 +1,165 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataOutputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.Date;
+import java.util.List;
+import java.util.Scanner;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+/**
+ * Command-line tool for EM (Baum-Welch) training of an HMM.
+ *
+ * <p>A randomly initialized {@link HmmModel} with the requested numbers of hidden and observed states is
+ * trained on the integer observations read from the input file. The result is serialized to the output file
+ * with {@link LossyHmmSerializer} and printed to standard output.</p>
+ */
+public final class BaumWelchTrainer {
+
+  private BaumWelchTrainer() {
+  }
+
+  public static void main(String[] args) throws IOException {
+    DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+    ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+    Option inputOption = DefaultOptionCreator.inputOption().create();
+
+    Option outputOption = DefaultOptionCreator.outputOption().create();
+
+    Option stateNumberOption = optionBuilder.withLongName("nrOfHiddenStates").
+      withDescription("Number of hidden states").
+      withShortName("nh").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Option observedStateNumberOption = optionBuilder.withLongName("nrOfObservedStates").
+      withDescription("Number of observed states").
+      withShortName("no").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Option epsilonOption = optionBuilder.withLongName("epsilon").
+      withDescription("Convergence threshold").
+      withShortName("e").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Option iterationsOption = optionBuilder.withLongName("max-iterations").
+      withDescription("Maximum iterations number").
+      withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+      withName("number").create()).withRequired(true).create();
+
+    Group optionGroup = new GroupBuilder().withOption(inputOption).
+      withOption(outputOption).withOption(stateNumberOption).withOption(observedStateNumberOption).
+      withOption(epsilonOption).withOption(iterationsOption).
+      withName("Options").create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(optionGroup);
+      CommandLine commandLine = parser.parse(args);
+
+      String input = (String) commandLine.getValue(inputOption);
+      String output = (String) commandLine.getValue(outputOption);
+
+      int nrOfHiddenStates = Integer.parseInt((String) commandLine.getValue(stateNumberOption));
+      int nrOfObservedStates = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption));
+
+      double epsilon = Double.parseDouble((String) commandLine.getValue(epsilonOption));
+      int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption));
+
+      // constructing a randomly generated HMM, seeded with the current time
+      HmmModel model = new HmmModel(nrOfHiddenStates, nrOfObservedStates, System.currentTimeMillis());
+
+      int[] observationsArray = readObservations(input);
+
+      // training
+      HmmModel trainedModel = HmmTrainer.trainBaumWelch(model,
+        observationsArray, epsilon, maxIterations, true);
+
+      // serializing the trained model; try-with-resources closes the stream and, unlike closing in a
+      // finally block, cannot let a close failure hide a serialization failure
+      try (DataOutputStream stream = new DataOutputStream(new FileOutputStream(output))) {
+        LossyHmmSerializer.serialize(trainedModel, stream);
+      }
+
+      printModel(trainedModel);
+    } catch (OptionException e) {
+      CommandLineUtil.printHelp(optionGroup);
+    }
+  }
+
+  /**
+   * Reads integer observations from a UTF-8 text file. Reading stops at the end of the file or at the first
+   * token that is not an integer.
+   *
+   * @param input path of the observation file
+   * @return the observations in file order
+   * @throws IOException if the file cannot be opened
+   */
+  private static int[] readObservations(String input) throws IOException {
+    List<Integer> observations = Lists.newArrayList();
+    try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
+      while (scanner.hasNextInt()) {
+        observations.add(scanner.nextInt());
+      }
+    }
+
+    int[] observationsArray = new int[observations.size()];
+    for (int i = 0; i < observationsArray.length; ++i) {
+      observationsArray[i] = observations.get(i);
+    }
+    return observationsArray;
+  }
+
+  /** Prints the initial probabilities, transition matrix and emission matrix of the trained model to stdout. */
+  private static void printModel(HmmModel trainedModel) {
+    int hiddenStates = trainedModel.getNrOfHiddenStates();
+    int outputStates = trainedModel.getNrOfOutputStates();
+
+    System.out.println("Initial probabilities: ");
+    printIndices(hiddenStates);
+    for (int i = 0; i < hiddenStates; ++i) {
+      System.out.print(trainedModel.getInitialProbabilities().get(i) + " ");
+    }
+    System.out.println();
+
+    System.out.println("Transition matrix:");
+    System.out.print("  ");
+    printIndices(hiddenStates);
+    for (int i = 0; i < hiddenStates; ++i) {
+      System.out.print(i + " ");
+      for (int j = 0; j < hiddenStates; ++j) {
+        System.out.print(trainedModel.getTransitionMatrix().get(i, j) + " ");
+      }
+      System.out.println();
+    }
+
+    System.out.println("Emission matrix: ");
+    System.out.print("  ");
+    printIndices(outputStates);
+    for (int i = 0; i < hiddenStates; ++i) {
+      System.out.print(i + " ");
+      for (int j = 0; j < outputStates; ++j) {
+        System.out.print(trainedModel.getEmissionMatrix().get(i, j) + " ");
+      }
+      System.out.println();
+    }
+  }
+
+  /** Prints the indices {@code 0 .. count-1}, each followed by a space, then ends the line. */
+  private static void printIndices(int count) {
+    for (int i = 0; i < count; ++i) {
+      System.out.print(i + " ");
+    }
+    System.out.println();
+  }
+}