Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:41 UTC

[13/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
new file mode 100644
index 0000000..56b1a04
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.TreeSet;
+
+/**
+ * <p>Optimized implementation of IgSplit.
+ * This class can be used when the criterion variable is a categorical attribute.</p>
+ *
+ * <p>This code was changed in MAHOUT-1419 to deal in sampled splits among numeric
+ * features to fix a performance problem. To generate some synthetic data that exercises
+ * the issue, try for example generating 4 features of Uniform(0,1) values with a random
+ * boolean 0/1 categorical feature. In Scala:</p>
+ *
+ * {@code
+ *  val r = new scala.util.Random()
+ *  val pw = new java.io.PrintWriter("random.csv")
+ *  (1 to 10000000).foreach(e =>
+ *    pw.println(r.nextDouble() + "," +
+ *               r.nextDouble() + "," +
+ *               r.nextDouble() + "," +
+ *               r.nextDouble() + "," +
+ *               (if (r.nextBoolean()) 1 else 0))
+ *   )
+ *   pw.close()
+ * }
+ */
+@Deprecated
+public class OptIgSplit extends IgSplit {
+
+  private static final int MAX_NUMERIC_SPLITS = 16;
+
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      return numericalSplit(data, attr);
+    } else {
+      return categoricalSplit(data, attr);
+    }
+  }
+
+  /**
+   * Computes the split for a CATEGORICAL attribute
+   */
+  private static Split categoricalSplit(Data data, int attr) {
+    double[] values = data.values(attr).clone();
+
+    double[] splitPoints = chooseCategoricalSplitPoints(values);
+
+    int numLabels = data.getDataset().nblabels();
+    int[][] counts = new int[splitPoints.length][numLabels];
+    int[] countAll = new int[numLabels];
+
+    computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+    int size = data.size();
+    double hy = entropy(countAll, size); // H(Y)
+    double hyx = 0.0; // H(Y|X)
+    double invDataSize = 1.0 / size;
+
+    for (int index = 0; index < splitPoints.length; index++) {
+      size = DataUtils.sum(counts[index]);
+      hyx += size * invDataSize * entropy(counts[index], size);
+    }
+
+    double ig = hy - hyx;
+    return new Split(attr, ig);
+  }
+
+  static void computeFrequencies(Data data,
+                                 int attr,
+                                 double[] splitPoints,
+                                 int[][] counts,
+                                 int[] countAll) {
+    Dataset dataset = data.getDataset();
+
+    for (int index = 0; index < data.size(); index++) {
+      Instance instance = data.get(index);
+      int label = (int) dataset.getLabel(instance);
+      double value = instance.get(attr);
+      int split = 0;
+      while (split < splitPoints.length && value > splitPoints[split]) {
+        split++;
+      }
+      if (split < splitPoints.length) {
+        counts[split][label]++;
+      } // Otherwise the value lies above the last split point; that bucket needs no explicit count
+      countAll[label]++;
+    }
+  }
+
+  /**
+   * Computes the best split for a NUMERICAL attribute
+   */
+  static Split numericalSplit(Data data, int attr) {
+    double[] values = data.values(attr).clone();
+    Arrays.sort(values);
+
+    double[] splitPoints = chooseNumericSplitPoints(values);
+
+    int numLabels = data.getDataset().nblabels();
+    int[][] counts = new int[splitPoints.length][numLabels];
+    int[] countAll = new int[numLabels];
+    int[] countLess = new int[numLabels];
+
+    computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+    int size = data.size();
+    double hy = entropy(countAll, size);
+    double invDataSize = 1.0 / size;
+
+    int best = -1;
+    double bestIg = -1.0;
+
+    // try each possible split value
+    for (int index = 0; index < splitPoints.length; index++) {
+      double ig = hy;
+
+      DataUtils.add(countLess, counts[index]);
+      DataUtils.dec(countAll, counts[index]);
+
+      // instances with attribute value <= splitPoints[index]
+      size = DataUtils.sum(countLess);
+      ig -= size * invDataSize * entropy(countLess, size);
+      // instances with attribute value > splitPoints[index]
+      size = DataUtils.sum(countAll);
+      ig -= size * invDataSize * entropy(countAll, size);
+
+      if (ig > bestIg) {
+        bestIg = ig;
+        best = index;
+      }
+    }
+
+    if (best == -1) {
+      throw new IllegalStateException("no best split found !");
+    }
+    return new Split(attr, bestIg, splitPoints[best]);
+  }
+
+  /**
+   * @return an array of values to split the numeric feature's values on when
+   *  building candidate splits. When the input size is <= MAX_NUMERIC_SPLITS + 1, it
+   *  returns the averages between successive values as split points. When larger, it
+   *  returns MAX_NUMERIC_SPLITS approximate percentiles through the data.
+   */
+  private static double[] chooseNumericSplitPoints(double[] values) {
+    if (values.length <= 1) {
+      return values;
+    }
+    if (values.length <= MAX_NUMERIC_SPLITS + 1) {
+      double[] splitPoints = new double[values.length - 1];
+      for (int i = 1; i < values.length; i++) {
+        splitPoints[i-1] = (values[i] + values[i-1]) / 2.0;
+      }
+      return splitPoints;
+    }
+    Percentile distribution = new Percentile();
+    distribution.setData(values);
+    double[] percentiles = new double[MAX_NUMERIC_SPLITS];
+    for (int i = 0; i < percentiles.length; i++) {
+      double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
+      percentiles[i] = distribution.evaluate(p);
+    }
+    return percentiles;
+  }
+
+  private static double[] chooseCategoricalSplitPoints(double[] values) {
+    // There is no great reason to believe that categorical value order matters,
+    // but the original code worked this way, and it's not terrible in the absence
+    // of more sophisticated analysis
+    Collection<Double> uniqueOrderedCategories = new TreeSet<>();
+    for (double v : values) {
+      uniqueOrderedCategories.add(v);
+    }
+    double[] uniqueValues = new double[uniqueOrderedCategories.size()];
+    Iterator<Double> it = uniqueOrderedCategories.iterator();
+    for (int i = 0; i < uniqueValues.length; i++) {
+      uniqueValues[i] = it.next();
+    }
+    return uniqueValues;
+  }
+
+  /**
+   * Computes the entropy of a label distribution
+   *
+   * @param counts   counts[i] = numInstances with label i
+   * @param dataSize numInstances
+   */
+  private static double entropy(int[] counts, int dataSize) {
+    if (dataSize == 0) {
+      return 0.0;
+    }
+
+    double entropy = 0.0;
+
+    for (int count : counts) {
+      if (count > 0) {
+        double p = count / (double) dataSize;
+        entropy -= p * Math.log(p);
+      }
+    }
+
+    return entropy / LOG2;
+  }
+
+}
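
The heart of the MAHOUT-1419 change above is chooseNumericSplitPoints: with at most MAX_NUMERIC_SPLITS + 1 distinct sorted values it returns exact midpoints, and with more it samples MAX_NUMERIC_SPLITS approximate percentiles, capping the number of candidate splits per numeric feature. That method is private, so the following is a minimal standalone sketch of the same two strategies; the class name and sample data are illustrative, not part of Mahout:

    import java.util.Arrays;

    import org.apache.commons.math3.stat.descriptive.rank.Percentile;

    public class SplitPointsSketch {

      static final int MAX_NUMERIC_SPLITS = 16; // same cap as OptIgSplit

      public static void main(String[] args) {
        // Small input: one midpoint between each pair of successive sorted values.
        double[] few = {1.0, 2.0, 4.0};
        System.out.println(Arrays.toString(midpoints(few))); // [1.5, 3.0]

        // Large input: reduced to MAX_NUMERIC_SPLITS approximate percentiles.
        double[] many = new double[100000];
        for (int i = 0; i < many.length; i++) {
          many[i] = Math.random();
        }
        Arrays.sort(many);
        System.out.println(percentileSplits(many).length); // 16
      }

      // Mirrors the "values.length <= MAX_NUMERIC_SPLITS + 1" branch above.
      static double[] midpoints(double[] sortedValues) {
        double[] splits = new double[sortedValues.length - 1];
        for (int i = 1; i < sortedValues.length; i++) {
          splits[i - 1] = (sortedValues[i] + sortedValues[i - 1]) / 2.0;
        }
        return splits;
      }

      // Mirrors the percentile-sampling branch above.
      static double[] percentileSplits(double[] sortedValues) {
        Percentile distribution = new Percentile();
        distribution.setData(sortedValues);
        double[] percentiles = new double[MAX_NUMERIC_SPLITS];
        for (int i = 0; i < percentiles.length; i++) {
          double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
          percentiles[i] = distribution.evaluate(p);
        }
        return percentiles;
      }
    }

Evaluating a fixed number of percentile candidates keeps the per-attribute work bounded no matter how many distinct values the feature has, which is what fixed the performance problem described in the class Javadoc.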

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
new file mode 100644
index 0000000..38695a3
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
@@ -0,0 +1,177 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Comparator;
+
+/**
+ * Regression problem implementation of IgSplit. This class can be used when the criterion variable is a numerical
+ * attribute.
+ */
+@Deprecated
+public class RegressionSplit extends IgSplit {
+  
+  /**
+   * Comparator that sorts Instances by the value of a given attribute
+   */
+  private static class InstanceComparator implements Comparator<Instance>, Serializable {
+    private final int attr;
+
+    InstanceComparator(int attr) {
+      this.attr = attr;
+    }
+    
+    @Override
+    public int compare(Instance arg0, Instance arg1) {
+      return Double.compare(arg0.get(attr), arg1.get(attr));
+    }
+  }
+  
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      return numericalSplit(data, attr);
+    } else {
+      return categoricalSplit(data, attr);
+    }
+  }
+  
+  /**
+   * Computes the split for a CATEGORICAL attribute
+   */
+  private static Split categoricalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[data.getDataset().nbValues(attr)];
+    double[] sk = new double[data.getDataset().nbValues(attr)];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+    FullRunningAverage totalRa = new FullRunningAverage();
+    double totalSk = 0.0;
+
+    for (int i = 0; i < data.size(); i++) {
+      // update the per-category running variance (Welford's method)
+      Instance instance = data.get(i);
+      int value = (int) instance.get(attr);
+      double xk = data.getDataset().getLabel(instance);
+      if (ra[value].getCount() == 0) {
+        ra[value].addDatum(xk);
+        sk[value] = 0.0;
+      } else {
+        double mk = ra[value].getAverage();
+        ra[value].addDatum(xk);
+        sk[value] += (xk - mk) * (xk - ra[value].getAverage());
+      }
+
+      // total variance
+      if (i == 0) {
+        totalRa.addDatum(xk);
+        totalSk = 0.0;
+      } else {
+        double mk = totalRa.getAverage();
+        totalRa.addDatum(xk);
+        totalSk += (xk - mk) * (xk - totalRa.getAverage());
+      }
+    }
+
+    // computes the variance gain
+    double ig = totalSk;
+    for (double aSk : sk) {
+      ig -= aSk;
+    }
+
+    return new Split(attr, ig);
+  }
+  
+  /**
+   * Computes the best split for a NUMERICAL attribute
+   */
+  private static Split numericalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[2];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+
+    // sort the instances by the value of attribute 'attr'
+    Instance[] instances = new Instance[data.size()];
+    for (int i = 0; i < data.size(); i++) {
+      instances[i] = data.get(i);
+    }
+    Arrays.sort(instances, new InstanceComparator(attr));
+
+    double[] sk = new double[2];
+    for (Instance instance : instances) {
+      double xk = data.getDataset().getLabel(instance);
+      if (ra[1].getCount() == 0) {
+        ra[1].addDatum(xk);
+        sk[1] = 0.0;
+      } else {
+        double mk = ra[1].getAverage();
+        ra[1].addDatum(xk);
+        sk[1] += (xk - mk) * (xk - ra[1].getAverage());
+      }
+    }
+    double totalSk = sk[1];
+
+    // find the best split point
+    double split = Double.NaN;
+    double preSplit = Double.NaN;
+    double bestVal = Double.MAX_VALUE;
+    double bestSk = 0.0;
+
+    // walk the sorted instances, moving each one from the right partition to the left
+    for (Instance instance : instances) {
+      double xk = data.getDataset().getLabel(instance);
+
+      if (instance.get(attr) > preSplit) {
+        double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount();
+        if (curVal < bestVal) {
+          bestVal = curVal;
+          bestSk = sk[0] + sk[1];
+          split = (instance.get(attr) + preSplit) / 2.0;
+        }
+      }
+
+      // computes the variance
+      if (ra[0].getCount() == 0) {
+        ra[0].addDatum(xk);
+        sk[0] = 0.0;
+      } else {
+        double mk = ra[0].getAverage();
+        ra[0].addDatum(xk);
+        sk[0] += (xk - mk) * (xk - ra[0].getAverage());
+      }
+
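+      // reverse Welford update: remove xk from the right-hand partition (values >= the candidate split)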
+      double mk = ra[1].getAverage();
+      ra[1].removeDatum(xk);
+      sk[1] -= (xk - mk) * (xk - ra[1].getAverage());
+
+      preSplit = instance.get(attr);
+    }
+
+    // computes the variance gain
+    double ig = totalSk - bestSk;
+
+    return new Split(attr, ig, split);
+  }
+}
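
Both split methods above maintain a running mean (FullRunningAverage) plus a running sum of squared deviations sk, updated as sk += (xk - mk) * (xk - new mean); this is Welford's online variance recurrence, and numericalSplit additionally applies the symmetric removal step as instances migrate between partitions. A minimal sketch of the forward recurrence, on illustrative data:

    public class WelfordSketch {
      public static void main(String[] args) {
        double[] xs = {2, 4, 4, 4, 5, 5, 7, 9}; // illustrative data
        int n = 0;
        double mean = 0.0;
        double sk = 0.0; // running sum of squared deviations from the mean
        for (double xk : xs) {
          n++;
          double mk = mean;         // mean before xk, as in the code above
          mean += (xk - mean) / n;  // what FullRunningAverage.addDatum maintains
          sk += (xk - mk) * (xk - mean);
        }
        System.out.println(sk / n); // population variance: 4.0
      }
    }

Dividing sk by the count gives the population variance, which is what curVal compares across candidate split points in numericalSplit.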

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
new file mode 100644
index 0000000..2a6a322
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import java.util.Locale;
+
+/**
+ * Contains enough information to identify each split
+ */
+@Deprecated
+public final class Split {
+  
+  private final int attr;
+  private final double ig;
+  private final double split;
+  
+  public Split(int attr, double ig, double split) {
+    this.attr = attr;
+    this.ig = ig;
+    this.split = split;
+  }
+  
+  public Split(int attr, double ig) {
+    this(attr, ig, Double.NaN);
+  }
+
+  /**
+   * @return attribute to split on
+   */
+  public int getAttr() {
+    return attr;
+  }
+
+  /**
+   * @return Information Gain of the split
+   */
+  public double getIg() {
+    return ig;
+  }
+
+  /**
+   * @return split value for NUMERICAL attributes
+   */
+  public double getSplit() {
+    return split;
+  }
+
+  @Override
+  public String toString() {
+    return String.format(Locale.ENGLISH, "attr: %d, ig: %f, split: %f", attr, ig, split);
+  }
+
+}
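
Split is just the value object returned by the IgSplit implementations. As a hedged sketch of how a consumer might pick an attribute (a hypothetical helper, not Mahout code; the real selection logic lives in the decision-tree builders):

    import org.apache.mahout.classifier.df.data.Data;
    import org.apache.mahout.classifier.df.split.IgSplit;
    import org.apache.mahout.classifier.df.split.OptIgSplit;
    import org.apache.mahout.classifier.df.split.Split;

    // Hypothetical helper: scans the attributes and keeps the candidate
    // with the highest information gain.
    public class BestSplitSketch {
      public static Split bestSplit(Data data) {
        IgSplit igSplit = new OptIgSplit();
        Split best = null;
        for (int attr = 0; attr < data.getDataset().nbAttributes(); attr++) {
          if (attr == data.getDataset().getLabelId()) {
            continue; // never split on the label itself
          }
          Split candidate = igSplit.computeSplit(data, attr);
          if (best == null || candidate.getIg() > best.getIg()) {
            best = candidate;
          }
        }
        return best; // getAttr()/getSplit() of the result define the node's test
      }
    }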

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
new file mode 100644
index 0000000..f29faed
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
@@ -0,0 +1,166 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.DescriptorUtils;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Generates a file descriptor for a given dataset
+ */
+public final class Describe implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(Describe.class);
+
+  private Describe() {}
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Describe(), args);
+  }
+
+  @Override
+  public int run(String[] args) throws Exception {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option pathOpt = obuilder.withLongName("path").withShortName("p").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+    Option descriptorOpt = obuilder.withLongName("descriptor").withShortName("d").withRequired(true)
+        .withArgument(abuilder.withName("descriptor").withMinimum(1).create()).withDescription(
+            "data descriptor").create();
+
+    Option descPathOpt = obuilder.withLongName("file").withShortName("f").withRequired(true).withArgument(
+        abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Path to generated descriptor file").create();
+
+    Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r")
+        .create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+
+    Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption(
+        descriptorOpt).withOption(regOpt).withOption(helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return -1;
+      }
+
+      String dataPath = cmdLine.getValue(pathOpt).toString();
+      String descPath = cmdLine.getValue(descPathOpt).toString();
+      List<String> descriptor = convert(cmdLine.getValues(descriptorOpt));
+      boolean regression = cmdLine.hasOption(regOpt);
+
+      log.debug("Data path : {}", dataPath);
+      log.debug("Descriptor path : {}", descPath);
+      log.debug("Descriptor : {}", descriptor);
+      log.debug("Regression : {}", regression);
+
+      runTool(dataPath, descriptor, descPath, regression);
+    } catch (OptionException e) {
+      log.warn(e.toString());
+      CommandLineUtil.printHelp(group);
+    }
+    return 0;
+  }
+
+  private void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression)
+    throws DescriptorException, IOException {
+    log.info("Generating the descriptor...");
+    String descriptor = DescriptorUtils.generateDescriptor(description);
+
+    Path fPath = validateOutput(filePath);
+
+    log.info("generating the dataset...");
+    Dataset dataset = generateDataset(descriptor, dataPath, regression);
+
+    log.info("storing the dataset description");
+    String json = dataset.toJSON();
+    DFUtils.storeString(conf, fPath, json);
+  }
+
+  private Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException,
+      DescriptorException {
+    Path path = new Path(dataPath);
+    FileSystem fs = path.getFileSystem(conf);
+
+    return DataLoader.generateDataset(descriptor, regression, fs, path);
+  }
+
+  private Path validateOutput(String filePath) throws IOException {
+    Path path = new Path(filePath);
+    FileSystem fs = path.getFileSystem(conf);
+    if (fs.exists(path)) {
+      throw new IllegalStateException("Descriptor's file already exists");
+    }
+
+    return path;
+  }
+
+  private static List<String> convert(Collection<?> values) {
+    List<String> list = new ArrayList<>(values.size());
+    for (Object value : values) {
+      list.add(value.toString());
+    }
+    return list;
+  }
+
+  private Configuration conf;
+
+  @Override
+  public void setConf(Configuration entries) {
+    this.conf = entries;
+  }
+
+  @Override
+  public Configuration getConf() {
+    return conf;
+  }
+}
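
For reference, a hedged usage sketch (file names are illustrative; the -d tokens follow the I/N/C/L descriptor grammar parsed by DescriptorUtils, where a number repeats the token after it):

    // Hypothetical file names; "-d I 9 N L" reads: ignore the first column,
    // then 9 numerical attributes, then the label.
    String[] args = {
        "-p", "glass.data",   // input data path
        "-f", "glass.info",   // descriptor file to generate
        "-d", "I", "9", "N", "L"
    };
    Describe.main(args);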

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
new file mode 100644
index 0000000..b421c4e
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool visualizes a decision forest
+ */
+@Deprecated
+public final class ForestVisualizer {
+
+  private static final Logger log = LoggerFactory.getLogger(ForestVisualizer.class);
+
+  private ForestVisualizer() {
+  }
+
+  public static String toString(DecisionForest forest, Dataset dataset, String[] attrNames) {
+
+    List<Node> trees;
+    try {
+      Method getTrees = forest.getClass().getDeclaredMethod("getTrees");
+      getTrees.setAccessible(true);
+      trees = (List<Node>) getTrees.invoke(forest);
+    } catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
+      throw new IllegalStateException(e);
+    }
+
+    int cnt = 1;
+    StringBuilder buff = new StringBuilder();
+    for (Node tree : trees) {
+      buff.append("Tree[").append(cnt).append("]:");
+      buff.append(TreeVisualizer.toString(tree, dataset, attrNames));
+      buff.append('\n');
+      cnt++;
+    }
+    return buff.toString();
+  }
+
+  /**
+   * Decision Forest to String
+   * @param forestPath
+   *          path to the Decision Forest
+   * @param datasetPath
+   *          dataset path
+   * @param attrNames
+   *          attribute names
+   */
+  public static String toString(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+    Configuration conf = new Configuration();
+    DecisionForest forest = DecisionForest.load(conf, new Path(forestPath));
+    Dataset dataset = Dataset.load(conf, new Path(datasetPath));
+    return toString(forest, dataset, attrNames);
+  }
+
+  /**
+   * Print Decision Forest
+   * @param forestPath
+   *          path to the Decision Forest
+   * @param datasetPath
+   *          dataset path
+   * @param attrNames
+   *          attribute names
+   */
+  public static void print(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+    System.out.println(toString(forestPath, datasetPath, attrNames));
+  }
+  
+  public static void main(String[] args) {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
+      .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+      .withDescription("Dataset path").create();
+
+    Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true)
+      .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+      .withDescription("Path to the Decision Forest").create();
+
+    Option attrNamesOpt = obuilder.withLongName("names").withShortName("n").withRequired(false)
+      .withArgument(abuilder.withName("names").withMinimum(1).create())
+      .withDescription("Optional, Attribute names").create();
+
+    Option helpOpt = obuilder.withLongName("help").withShortName("h")
+      .withDescription("Print out help").create();
+  
+    Group group = gbuilder.withName("Options").withOption(datasetOpt).withOption(modelOpt)
+      .withOption(attrNamesOpt).withOption(helpOpt).create();
+  
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      
+      if (cmdLine.hasOption("help")) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+  
+      String datasetName = cmdLine.getValue(datasetOpt).toString();
+      String modelName = cmdLine.getValue(modelOpt).toString();
+      String[] attrNames = null;
+      if (cmdLine.hasOption(attrNamesOpt)) {
+        Collection<String> names = (Collection<String>) cmdLine.getValues(attrNamesOpt);
+        if (!names.isEmpty()) {
+          attrNames = new String[names.size()];
+          names.toArray(attrNames);
+        }
+      }
+      
+      print(modelName, datasetName, attrNames);
+    } catch (Exception e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+    }
+  }
+}
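
A hedged example of driving the visualizer from Java rather than the command line (paths and attribute names are illustrative, not real files):

    import org.apache.mahout.classifier.df.tools.ForestVisualizer;

    public class VisualizeSketch {
      public static void main(String[] args) throws java.io.IOException {
        // Pass null for attrNames to print attribute indices instead.
        String[] attrNames = {"age", "income", "owns_home"};
        ForestVisualizer.print("output/forest.seq", "data/data.info", attrNames);
      }
    }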

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
new file mode 100644
index 0000000..c37af4e
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Computes the frequency distribution of the "class label".<br>
+ * This class can be used when the criterion variable is a categorical attribute.
+ */
+@Deprecated
+public final class Frequencies extends Configured implements Tool {
+  
+  private static final Logger log = LoggerFactory.getLogger(Frequencies.class);
+  
+  private Frequencies() { }
+  
+  @Override
+  public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
+    
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+    
+    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+      abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+    
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+      abuilder.withName("path").withMinimum(1).create()).withDescription("dataset path").create();
+    
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+    
+    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(helpOpt)
+        .create();
+    
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return 0;
+      }
+      
+      String dataPath = cmdLine.getValue(dataOpt).toString();
+      String datasetPath = cmdLine.getValue(datasetOpt).toString();
+      
+      log.debug("Data path : {}", dataPath);
+      log.debug("Dataset path : {}", datasetPath);
+      
+      runTool(dataPath, datasetPath);
+    } catch (OptionException e) {
+      log.warn(e.toString(), e);
+      CommandLineUtil.printHelp(group);
+    }
+    
+    return 0;
+  }
+  
+  private void runTool(String data, String dataset) throws IOException,
+                                                   ClassNotFoundException,
+                                                   InterruptedException {
+    
+    FileSystem fs = FileSystem.get(getConf());
+    Path workingDir = fs.getWorkingDirectory();
+    
+    Path dataPath = new Path(data);
+    Path datasetPath = new Path(dataset);
+    
+    log.info("Computing the frequencies...");
+    FrequenciesJob job = new FrequenciesJob(new Path(workingDir, "output"), dataPath, datasetPath);
+    
+    int[][] counts = job.run(getConf());
+    
+    // outputting the frequencies
+    log.info("counts[partition][class]");
+    for (int[] count : counts) {
+      log.info(Arrays.toString(count));
+    }
+  }
+  
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new Frequencies(), args);
+  }
+  
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
new file mode 100644
index 0000000..9d7e2ff
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
@@ -0,0 +1,297 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.net.URI;
+import java.util.Arrays;
+
+/**
+ * Temporary class used to compute the frequency distribution of the "class attribute".<br>
+ * This class can be used when the criterion variable is a categorical attribute.
+ */
+@Deprecated
+public class FrequenciesJob {
+  
+  private static final Logger log = LoggerFactory.getLogger(FrequenciesJob.class);
+  
+  /** directory that will hold this job's output */
+  private final Path outputPath;
+  
+  /** file that contains the serialized dataset */
+  private final Path datasetPath;
+  
+  /** directory that contains the data used in the first step */
+  private final Path dataPath;
+  
+  /**
+   * @param base
+   *          base directory
+   * @param dataPath
+   *          data used in the first step
+   */
+  public FrequenciesJob(Path base, Path dataPath, Path datasetPath) {
+    this.outputPath = new Path(base, "frequencies.output");
+    this.dataPath = dataPath;
+    this.datasetPath = datasetPath;
+  }
+  
+  /**
+   * @return counts[partition][label] = num tuples from 'partition' with class == label
+   */
+  public int[][] run(Configuration conf) throws IOException, ClassNotFoundException, InterruptedException {
+    
+    // check the output
+    FileSystem fs = outputPath.getFileSystem(conf);
+    if (fs.exists(outputPath)) {
+      throw new IOException("Output path already exists : " + outputPath);
+    }
+    
+    // put the dataset into the DistributedCache
+    URI[] files = {datasetPath.toUri()};
+    DistributedCache.setCacheFiles(files, conf);
+    
+    Job job = new Job(conf);
+    job.setJarByClass(FrequenciesJob.class);
+    
+    FileInputFormat.setInputPaths(job, dataPath);
+    FileOutputFormat.setOutputPath(job, outputPath);
+    
+    job.setMapOutputKeyClass(LongWritable.class);
+    job.setMapOutputValueClass(IntWritable.class);
+    job.setOutputKeyClass(LongWritable.class);
+    job.setOutputValueClass(Frequencies.class);
+    
+    job.setMapperClass(FrequenciesMapper.class);
+    job.setReducerClass(FrequenciesReducer.class);
+    
+    job.setInputFormatClass(TextInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    
+    // run the job
+    boolean succeeded = job.waitForCompletion(true);
+    if (!succeeded) {
+      throw new IllegalStateException("Job failed!");
+    }
+    
+    int[][] counts = parseOutput(job);
+
+    HadoopUtil.delete(conf, outputPath);
+    
+    return counts;
+  }
+  
+  /**
+   * Extracts the output and processes it
+   * 
+   * @return counts[partition][label] = num tuples from 'partition' with class == label
+   */
+  int[][] parseOutput(JobContext job) throws IOException {
+    Configuration conf = job.getConfiguration();
+    
+    int numMaps = conf.getInt("mapred.map.tasks", -1);
+    log.info("mapred.map.tasks = {}", numMaps);
+    
+    FileSystem fs = outputPath.getFileSystem(conf);
+    
+    Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+    
+    Frequencies[] values = new Frequencies[numMaps];
+    
+    // read all the outputs
+    int index = 0;
+    for (Path path : outfiles) {
+      for (Frequencies value : new SequenceFileValueIterable<Frequencies>(path, conf)) {
+        values[index++] = value;
+      }
+    }
+    
+    if (index < numMaps) {
+      throw new IllegalStateException("number of output Frequencies (" + index
+          + ") is lesser than the number of mappers!");
+    }
+    
+    // sort the frequencies using the firstIds
+    Arrays.sort(values);
+    return Frequencies.extractCounts(values);
+  }
+  
+  /**
+   * Outputs the first key and the label of each tuple
+   * 
+   */
+  private static class FrequenciesMapper extends Mapper<LongWritable,Text,LongWritable,IntWritable> {
+    
+    private LongWritable firstId;
+    
+    private DataConverter converter;
+    private Dataset dataset;
+    
+    @Override
+    protected void setup(Context context) throws IOException, InterruptedException {
+      Configuration conf = context.getConfiguration();
+      
+      dataset = Builder.loadDataset(conf);
+      setup(dataset);
+    }
+    
+    /**
+     * Useful when testing
+     */
+    void setup(Dataset dataset) {
+      converter = new DataConverter(dataset);
+    }
+    
+    @Override
+    protected void map(LongWritable key, Text value, Context context) throws IOException,
+                                                                     InterruptedException {
+      if (firstId == null) {
+        firstId = new LongWritable(key.get());
+      }
+      
+      Instance instance = converter.convert(value.toString());
+      
+      context.write(firstId, new IntWritable((int) dataset.getLabel(instance)));
+    }
+    
+  }
+  
+  private static class FrequenciesReducer extends Reducer<LongWritable,IntWritable,LongWritable,Frequencies> {
+    
+    private int nblabels;
+    
+    @Override
+    protected void setup(Context context) throws IOException, InterruptedException {
+      Configuration conf = context.getConfiguration();
+      Dataset dataset = Builder.loadDataset(conf);
+      setup(dataset.nblabels());
+    }
+    
+    /**
+     * Useful when testing
+     */
+    void setup(int nblabels) {
+      this.nblabels = nblabels;
+    }
+    
+    @Override
+    protected void reduce(LongWritable key, Iterable<IntWritable> values, Context context)
+      throws IOException, InterruptedException {
+      int[] counts = new int[nblabels];
+      for (IntWritable value : values) {
+        counts[value.get()]++;
+      }
+      
+      context.write(key, new Frequencies(key.get(), counts));
+    }
+  }
+  
+  /**
+   * Output of the job
+   * 
+   */
+  private static class Frequencies implements Writable, Comparable<Frequencies>, Cloneable {
+    
+    /** first key of the partition used to sort the partitions */
+    private long firstId;
+    
+    /** counts[c] = num tuples from the partition with label == c */
+    private int[] counts;
+    
+    Frequencies() { }
+    
+    Frequencies(long firstId, int[] counts) {
+      this.firstId = firstId;
+      this.counts = Arrays.copyOf(counts, counts.length);
+    }
+    
+    @Override
+    public void readFields(DataInput in) throws IOException {
+      firstId = in.readLong();
+      counts = DFUtils.readIntArray(in);
+    }
+    
+    @Override
+    public void write(DataOutput out) throws IOException {
+      out.writeLong(firstId);
+      DFUtils.writeArray(out, counts);
+    }
+    
+    @Override
+    public boolean equals(Object other) {
+      return other instanceof Frequencies && firstId == ((Frequencies) other).firstId;
+    }
+    
+    @Override
+    public int hashCode() {
+      return (int) firstId;
+    }
+    
+    @Override
+    protected Frequencies clone() {
+      return new Frequencies(firstId, counts);
+    }
+    
+    @Override
+    public int compareTo(Frequencies obj) {
+      return Long.compare(firstId, obj.firstId);
+    }
+    
+    public static int[][] extractCounts(Frequencies[] partitions) {
+      int[][] counts = new int[partitions.length][];
+      for (int p = 0; p < partitions.length; p++) {
+        counts[p] = partitions[p].counts;
+      }
+      return counts;
+    }
+  }
+}
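
The job's output reduces to one label histogram per input partition, keyed and sorted by the partition's first record id. A toy sketch of the reducer's tally, with plain arrays standing in for the private Frequencies writable (the sample labels are illustrative):

    import java.util.Arrays;

    public class TallySketch {
      public static void main(String[] args) {
        int nblabels = 3;
        int[] counts = new int[nblabels];
        // Labels emitted by the mapper for one partition, all keyed by its firstId:
        for (int label : new int[] {0, 2, 2, 1, 0}) {
          counts[label]++;
        }
        // [2, 1, 2] = one row of the counts[partition][label] result
        System.out.println(Arrays.toString(counts));
      }
    }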

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
new file mode 100644
index 0000000..a2a3458
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
@@ -0,0 +1,264 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.lang.reflect.Field;
+import java.text.DecimalFormat;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+
+/**
+ * This tool visualizes a decision tree
+ */
+@Deprecated
+public final class TreeVisualizer {
+  
+  private TreeVisualizer() {}
+  
+  private static String doubleToString(double value) {
+    DecimalFormat df = new DecimalFormat("0.##");
+    return df.format(value);
+  }
+  
+  private static String toStringNode(Node node, Dataset dataset,
+      String[] attrNames, Map<String,Field> fields, int layer) {
+    
+    StringBuilder buff = new StringBuilder();
+    
+    try {
+      if (node instanceof CategoricalNode) {
+        CategoricalNode cnode = (CategoricalNode) node;
+        int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
+        for (int i = 0; i < attrValues[attr].length; i++) {
+          int index = ArrayUtils.indexOf(values, i);
+          if (index < 0) {
+            continue;
+          }
+          buff.append('\n');
+          for (int j = 0; j < layer; j++) {
+            buff.append("|   ");
+          }
+          buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+              .append(attrValues[attr][i]);
+          buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1));
+        }
+      } else if (node instanceof NumericalNode) {
+        NumericalNode nnode = (NumericalNode) node;
+        int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+        double split = (Double) fields.get("NumericalNode.split").get(nnode);
+        Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+        Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+        buff.append('\n');
+        for (int j = 0; j < layer; j++) {
+          buff.append("|   ");
+        }
+        buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ")
+            .append(doubleToString(split));
+        buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
+        buff.append('\n');
+        for (int j = 0; j < layer; j++) {
+          buff.append("|   ");
+        }
+        buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ")
+            .append(doubleToString(split));
+        buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
+      } else if (node instanceof Leaf) {
+        Leaf leaf = (Leaf) node;
+        double label = (Double) fields.get("Leaf.label").get(leaf);
+        if (dataset.isNumerical(dataset.getLabelId())) {
+          buff.append(" : ").append(doubleToString(label));
+        } else {
+          buff.append(" : ").append(dataset.getLabelString(label));
+        }
+      }
+    } catch (IllegalAccessException iae) {
+      throw new IllegalStateException(iae);
+    }
+    
+    return buff.toString();
+  }
+  
+  private static Map<String,Field> getReflectMap() {
+    Map<String,Field> fields = new HashMap<>();
+    
+    try {
+      Field m = CategoricalNode.class.getDeclaredField("attr");
+      m.setAccessible(true);
+      fields.put("CategoricalNode.attr", m);
+      m = CategoricalNode.class.getDeclaredField("values");
+      m.setAccessible(true);
+      fields.put("CategoricalNode.values", m);
+      m = CategoricalNode.class.getDeclaredField("childs");
+      m.setAccessible(true);
+      fields.put("CategoricalNode.childs", m);
+      m = NumericalNode.class.getDeclaredField("attr");
+      m.setAccessible(true);
+      fields.put("NumericalNode.attr", m);
+      m = NumericalNode.class.getDeclaredField("split");
+      m.setAccessible(true);
+      fields.put("NumericalNode.split", m);
+      m = NumericalNode.class.getDeclaredField("loChild");
+      m.setAccessible(true);
+      fields.put("NumericalNode.loChild", m);
+      m = NumericalNode.class.getDeclaredField("hiChild");
+      m.setAccessible(true);
+      fields.put("NumericalNode.hiChild", m);
+      m = Leaf.class.getDeclaredField("label");
+      m.setAccessible(true);
+      fields.put("Leaf.label", m);
+      m = Dataset.class.getDeclaredField("values");
+      m.setAccessible(true);
+      fields.put("Dataset.values", m);
+    } catch (NoSuchFieldException nsfe) {
+      throw new IllegalStateException(nsfe);
+    }
+    
+    return fields;
+  }
+  
+  /**
+   * Decision tree to String
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static String toString(Node tree, Dataset dataset, String[] attrNames) {
+    return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
+  }
+  
+  /**
+   * Print Decision tree
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static void print(Node tree, Dataset dataset, String[] attrNames) {
+    System.out.println(toString(tree, dataset, attrNames));
+  }
+  
+  private static String toStringPredict(Node node, Instance instance,
+      Dataset dataset, String[] attrNames, Map<String,Field> fields) {
+    StringBuilder buff = new StringBuilder();
+    
+    try {
+      if (node instanceof CategoricalNode) {
+        CategoricalNode cnode = (CategoricalNode) node;
+        int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(
+            cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+            .get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+            dataset);
+        
+        int index = ArrayUtils.indexOf(values, instance.get(attr));
+        if (index >= 0) {
+          buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+              .append(attrValues[attr][(int) instance.get(attr)]);
+          buff.append(" -> ");
+          buff.append(toStringPredict(childs[index], instance, dataset,
+              attrNames, fields));
+        }
+      } else if (node instanceof NumericalNode) {
+        NumericalNode nnode = (NumericalNode) node;
+        int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+        double split = (Double) fields.get("NumericalNode.split").get(nnode);
+        Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+        Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+        
+        if (instance.get(attr) < split) {
+          buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+              .append(" = ").append(doubleToString(instance.get(attr)))
+              .append(") < ").append(doubleToString(split));
+          buff.append(" -> ");
+          buff.append(toStringPredict(loChild, instance, dataset, attrNames,
+              fields));
+        } else {
+          buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+              .append(" = ").append(doubleToString(instance.get(attr)))
+              .append(") >= ").append(doubleToString(split));
+          buff.append(" -> ");
+          buff.append(toStringPredict(hiChild, instance, dataset, attrNames,
+              fields));
+        }
+      } else if (node instanceof Leaf) {
+        Leaf leaf = (Leaf) node;
+        double label = (Double) fields.get("Leaf.label").get(leaf);
+        if (dataset.isNumerical(dataset.getLabelId())) {
+          buff.append(doubleToString(label));
+        } else {
+          buff.append(dataset.getLabelString(label));
+        }
+      }
+    } catch (IllegalAccessException iae) {
+      throw new IllegalStateException(iae);
+    }
+    
+    return buff.toString();
+  }
+  
+  /**
+   * Predict trace to String
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static String[] predictTrace(Node tree, Data data, String[] attrNames) {
+    Map<String,Field> reflectMap = getReflectMap();
+    String[] prediction = new String[data.size()];
+    for (int i = 0; i < data.size(); i++) {
+      prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(),
+          attrNames, reflectMap);
+    }
+    return prediction;
+  }
+  
+  /**
+   * Print predict trace
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
+   */
+  public static void predictTracePrint(Node tree, Data data, String[] attrNames) {
+    Map<String,Field> reflectMap = getReflectMap();
+    for (int i = 0; i < data.size(); i++) {
+      System.out.println(toStringPredict(tree, data.get(i), data.getDataset(),
+          attrNames, reflectMap));
+    }
+  }
+}
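
For orientation, toStringNode emits one line per branch, prefixed with one "|   " marker per tree level, and appends " : label" at each leaf. A small numeric tree over two hypothetical attributes a0 and a1 would render roughly as:

    a0 < 4.5
    |   a1 < 2.1 : yes
    |   a1 >= 2.1 : no
    a0 >= 4.5 : no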

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
new file mode 100644
index 0000000..e1b55ab
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
@@ -0,0 +1,212 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Random;
+import java.util.Scanner;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool uniformly distributes the tuples of a dataset, by class label, over a given number of
+ * partitions.<br>
+ * It can be used when the criterion variable is a categorical attribute.
+ */
+@Deprecated
+public final class UDistrib {
+  
+  private static final Logger log = LoggerFactory.getLogger(UDistrib.class);
+  
+  private UDistrib() {}
+  
+  /**
+   * Launch the uniform distribution tool. Requires the following command line arguments:<br>
+   * 
+   * <ul>
+   *   <li>data : data path</li>
+   *   <li>dataset : dataset path</li>
+   *   <li>numpartitions : number of partitions to create</li>
+   *   <li>output : output path</li>
+   * </ul>
+   *
+   * @throws java.io.IOException
+   */
+  public static void main(String[] args) throws IOException {
+    
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+    
+    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+      abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+    
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+      abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset path").create();
+    
+    Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+      abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+      "Path to generated files").create();
+    
+    Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true)
+        .withArgument(abuilder.withName("numparts").withMinimum(1).withMaximum(1).create()).withDescription(
+          "Number of partitions to create").create();
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+    
+    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption(
+      datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create();
+    
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return;
+      }
+      
+      String data = cmdLine.getValue(dataOpt).toString();
+      String dataset = cmdLine.getValue(datasetOpt).toString();
+      int numPartitions = Integer.parseInt(cmdLine.getValue(partitionsOpt).toString());
+      String output = cmdLine.getValue(outputOpt).toString();
+      
+      runTool(data, dataset, output, numPartitions);
+    } catch (OptionException e) {
+      log.warn(e.toString(), e);
+      CommandLineUtil.printHelp(group);
+    }
+    
+  }
+  
+  private static void runTool(String dataStr, String datasetStr, String output, int numPartitions) throws IOException {
+
+    Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0");
+    
+    // make sure the output file does not exist
+    Path outputPath = new Path(output);
+    Configuration conf = new Configuration();
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    Preconditions.checkArgument(!fs.exists(outputPath), "Output path already exists");
+    
+    // create a new file corresponding to each partition
+    File tempFile = FileUtil.createLocalTempFile(new File(""), "df.tools.UDistrib", true);
+    Path partsPath = new Path(tempFile.toString());
+    FileSystem pfs = partsPath.getFileSystem(conf);
+    
+    Path[] partPaths = new Path[numPartitions];
+    FSDataOutputStream[] files = new FSDataOutputStream[numPartitions];
+    for (int p = 0; p < numPartitions; p++) {
+      partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, "part.%03d", p));
+      files[p] = pfs.create(partPaths[p]);
+    }
+    
+    Path datasetPath = new Path(datasetStr);
+    Dataset dataset = Dataset.load(conf, datasetPath);
+    
+    // currents[label] = index of the next partition file in which to place a tuple with this label
+    int[] currents = new int[dataset.nblabels()];
+    
+    // each entry of currents is initialized to a random value in [0, numPartitions)
+    Random random = RandomUtils.getRandom();
+    for (int c = 0; c < currents.length; c++) {
+      currents[c] = random.nextInt(numPartitions);
+    }
+    
+    // foreach tuple of the data
+    Path dataPath = new Path(dataStr);
+    FileSystem ifs = dataPath.getFileSystem(conf);
+    FSDataInputStream input = ifs.open(dataPath);
+    Scanner scanner = new Scanner(input, "UTF-8");
+    DataConverter converter = new DataConverter(dataset);
+    
+    int id = 0;
+    while (scanner.hasNextLine()) {
+      if (id % 1000 == 0) {
+        log.info("progress : {}", id);
+      }
+      
+      String line = scanner.nextLine();
+      if (line.isEmpty()) {
+        continue; // skip empty lines
+      }
+      
+      // write the tuple in files[tuple.label]
+      Instance instance = converter.convert(line);
+      int label = (int) dataset.getLabel(instance);
+      files[currents[label]].writeBytes(line);
+      files[currents[label]].writeChar('\n');
+      
+      // update currents
+      currents[label]++;
+      if (currents[label] == numPartitions) {
+        currents[label] = 0;
+      }
+    }
+    
+    // close all the files.
+    scanner.close();
+    for (FSDataOutputStream file : files) {
+      Closeables.close(file, false);
+    }
+    
+    // merge all output files
+    FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null);
+  }
+  
+}
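
As a usage sketch, the tool can also be driven programmatically exactly as the command line would invoke it; the three paths below are placeholders:

    // Hypothetical invocation of UDistrib; all paths are placeholders.
    public class UDistribExample {
      public static void main(String[] args) throws java.io.IOException {
        UDistrib.main(new String[] {
            "--data", "/tmp/random.csv",      // input tuples, one per line
            "--dataset", "/tmp/random.info",  // Dataset descriptor
            "--output", "/tmp/uniform.data",  // must not already exist
            "--numpartitions", "4"            // must be > 0
        });
      }
    }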

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
new file mode 100644
index 0000000..049f9bf
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.evaluation;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+import com.google.common.base.Preconditions;
+
+import java.util.Random;
+
+/**
+ * Computes AUC and a few other accuracy statistics without storing huge amounts of data.  This is
+ * done by keeping uniform samples of the positive and negative scores.  Then, when AUC is to be
+ * computed, the remaining scores are sorted and a rank-sum statistic is used to compute the AUC.
+ * Since AUC is invariant with respect to down-sampling of either positives or negatives, this is
+ * close to correct and is exactly correct if maxBufferSize or fewer positive and negative scores
+ * are examined.
+ */
+public class Auc {
+
+  private int maxBufferSize = 10000;
+  private final DoubleArrayList[] scores = {new DoubleArrayList(), new DoubleArrayList()};
+  private final Random rand;
+  private int samples;
+  private final double threshold;
+  private final Matrix confusion;
+  private final DenseMatrix entropy;
+
+  private boolean probabilityScore = true;
+
+  private boolean hasScore;
+
+  /**
+   * Allocates a new data-structure for accumulating information about AUC and a few other accuracy
+   * measures.
+   * @param threshold The threshold to use in computing the confusion matrix.
+   */
+  public Auc(double threshold) {
+    confusion = new DenseMatrix(2, 2);
+    entropy = new DenseMatrix(2, 2);
+    this.rand = RandomUtils.getRandom();
+    this.threshold = threshold;
+  }
+
+  public Auc() {
+    this(0.5);
+  }
+
+  /**
+   * Adds a score to the AUC buffers.
+   *
+   * @param trueValue Whether this score is for a true-positive or a true-negative example.
+   * @param score     The score for this example.
+   */
+  public void add(int trueValue, double score) {
+    Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+    hasScore = true;
+
+    int predictedClass = score > threshold ? 1 : 0;
+    confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+
+    samples++;
+    if (isProbabilityScore()) {
+      double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20));
+      double v0 = entropy.get(trueValue, 0);
+      entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0);
+
+      double v1 = entropy.get(trueValue, 1);
+      entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1);
+    }
+
+    // add to buffers
+    DoubleArrayList buf = scores[trueValue];
+    if (buf.size() >= maxBufferSize) {
+      // but once more than maxBufferSize points have been seen, we insert
+      // into a random place and discard the predecessor.  The random place
+      // could be anywhere, possibly not even in the buffer.
+      // This is a special case of Knuth's permutation algorithm, but since
+      // we never shuffle the first maxBufferSize samples, the result isn't
+      // a fair sample of the prefixes of all permutations.  The contents of
+      // the buffer, however, are a fair and uniform sample of maxBufferSize
+      // elements chosen without replacement from all elements seen so far
+      // (i.e. reservoir sampling).
+      int index = rand.nextInt(samples);
+      if (index < buf.size()) {
+        buf.set(index, score);
+      }
+    } else {
+      // for small buffers, we collect all points without permuting
+      // since we sort the data later, permuting now would just be
+      // pedantic
+      buf.add(score);
+    }
+  }
+
+  public void add(int trueValue, int predictedClass) {
+    hasScore = false;
+    Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+    confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+  }
+
+  /**
+   * Computes the AUC of points seen so far.  This can be moderately expensive since it requires
+   * that all points that have been retained be sorted.
+   *
+   * @return The area under the receiver operating characteristic (ROC) curve.
+   */
+  public double auc() {
+    Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier without a score");
+    scores[0].sort();
+    scores[1].sort();
+
+    double n0 = scores[0].size();
+    double n1 = scores[1].size();
+
+    if (n0 == 0 || n1 == 0) {
+      return 0.5;
+    }
+
+    // scan the data
+    int i0 = 0;
+    int i1 = 0;
+    int rank = 1;
+    double rankSum = 0;
+    while (i0 < n0 && i1 < n1) {
+
+      double v0 = scores[0].get(i0);
+      double v1 = scores[1].get(i1);
+
+      if (v0 < v1) {
+        i0++;
+        rank++;
+      } else if (v1 < v0) {
+        i1++;
+        rankSum += rank;
+        rank++;
+      } else {
+        // ties have to be handled delicately
+        double tieScore = v0;
+
+        // how many negatives are tied?
+        int k0 = 0;
+        while (i0 < n0 && scores[0].get(i0) == tieScore) {
+          k0++;
+          i0++;
+        }
+
+        // and how many positives
+        int k1 = 0;
+        while (i1 < n1 && scores[1].get(i1) == tieScore) {
+          k1++;
+          i1++;
+        }
+
+        // we found k0 + k1 tied values which have
+        // ranks in the half open interval [rank, rank + k0 + k1)
+        // the average rank is assigned to all
+        rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1;
+        rank += k0 + k1;
+      }
+    }
+
+    if (i1 < n1) {
+      rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1);
+      rank += (int) (n1 - i1);
+    }
+
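+    // This is the Mann-Whitney U formulation of AUC: the expression below is
+    // algebraically equal to (rankSum - n1 * (n1 + 1) / 2) / (n0 * n1).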
+    return (rankSum / n1 - (n1 + 1) / 2) / n0;
+  }
+
+  /**
+   * Returns the confusion matrix for the classifier supposing that we were to use a particular
+   * threshold.
+   * @return The confusion matrix.
+   */
+  public Matrix confusion() {
+    return confusion;
+  }
+
+  /**
+   * Returns a matrix related to the confusion matrix and to the log-likelihood.  For a
+   * pretty accurate classifier, N + entropy is nearly the same as the confusion matrix
+   * because log(1-eps) \approx -eps if eps is small.
+   *
+   * For lower accuracy classifiers, this measure will give us a better picture of how
+ * things work out.
+   *
+   * Also, by definition, log-likelihood = sum(diag(entropy))
+   * @return Returns a cell by cell break-down of the log-likelihood
+   */
+  public Matrix entropy() {
+    if (!hasScore) {
+      // find a constant score that would optimize log-likelihood, but use a dash of Bayesian
+      // conservatism to avoid dividing by zero or taking log(0)
+      double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + confusion.get(1, 1));
+      entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p));
+      entropy.set(0, 1, confusion.get(0, 1) * Math.log(p));
+      entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p));
+      entropy.set(1, 1, confusion.get(1, 1) * Math.log(p));
+    }
+    return entropy;
+  }
+
+  public void setMaxBufferSize(int maxBufferSize) {
+    this.maxBufferSize = maxBufferSize;
+  }
+
+  public boolean isProbabilityScore() {
+    return probabilityScore;
+  }
+
+  public void setProbabilityScore(boolean probabilityScore) {
+    this.probabilityScore = probabilityScore;
+  }
+}
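
A short, self-contained usage sketch of this class; the labels and scores are synthetic:

    import java.util.Random;
    import org.apache.mahout.classifier.evaluation.Auc;

    public class AucExample {
      public static void main(String[] args) {
        Auc auc = new Auc(0.5);                // threshold for the confusion matrix
        Random r = new Random(42);
        for (int i = 0; i < 1000; i++) {
          int label = r.nextInt(2);            // synthetic true class, 0 or 1
          // mildly informative score in (0, 1): higher on average for positives
          double score = 0.3 * label + 0.7 * r.nextDouble();
          auc.add(label, score);
        }
        System.out.println("AUC = " + auc.auc());
        System.out.println("confusion:\n" + auc.confusion());
      }
    }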

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
new file mode 100644
index 0000000..f0794b3
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Class implementing the Naive Bayes Classifier Algorithm. Note that this class
+ * supports {@link #classifyFull}, but not {@code classify} or
+ * {@code classifyScalar}. These two methods are not supported because the
+ * scores computed by a NaiveBayesClassifier do not represent probabilities.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+
+  private final NaiveBayesModel model;
+  
+  protected AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+    this.model = model;
+  }
+
+  protected NaiveBayesModel getModel() {
+    return model;
+  }
+  
+  protected abstract double getScoreForLabelFeature(int label, int feature);
+
+  protected double getScoreForLabelInstance(int label, Vector instance) {
+    double result = 0.0;
+    for (Element e : instance.nonZeroes()) {
+      result += e.get() * getScoreForLabelFeature(label, e.index());
+    }
+    return result;
+  }
+  
+  @Override
+  public int numCategories() {
+    return model.numLabels();
+  }
+
+  @Override
+  public Vector classifyFull(Vector instance) {
+    return classifyFull(model.createScoringVector(), instance);
+  }
+  
+  @Override
+  public Vector classifyFull(Vector r, Vector instance) {
+    for (int label = 0; label < model.numLabels(); label++) {
+      r.setQuick(label, getScoreForLabelInstance(label, instance));
+    }
+    return r;
+  }
+
+  /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+  @Override
+  public double classifyScalar(Vector instance) {
+    throw new UnsupportedOperationException("Not supported in Naive Bayes");
+  }
+  
+  /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+  @Override
+  public Vector classify(Vector instance) {
+    throw new UnsupportedOperationException("Probabilities are not supported in Naive Bayes");
+  }
+}
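
To make the abstract contract concrete: a subclass only has to supply the per-label, per-feature score, and the summation over non-zero features plus the filling of the result vector are inherited. The stub below is illustrative only, not one of the shipped classifiers:

    // Illustrative stub: scores each (label, feature) pair directly from the
    // model's weight matrix, with no smoothing or normalization applied.
    public class RawWeightNaiveBayesClassifier extends AbstractNaiveBayesClassifier {

      public RawWeightNaiveBayesClassifier(NaiveBayesModel model) {
        super(model);
      }

      @Override
      protected double getScoreForLabelFeature(int label, int feature) {
        return getModel().weight(label, feature); // raw weight, for illustration
      }
    }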

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
new file mode 100644
index 0000000..4db8b17
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.regex.Pattern;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public final class BayesUtils {
+
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  private BayesUtils() {}
+
+  public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {
+
+    float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+    boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true);
+
+    // read feature sums and label sums
+    Vector scoresPerLabel = null;
+    Vector scoresPerFeature = null;
+    for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      String key = record.getFirst().toString();
+      VectorWritable value = record.getSecond();
+      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+        scoresPerFeature = value.get();
+      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+        scoresPerLabel = value.get();
+      }
+    }
+
+    Preconditions.checkNotNull(scoresPerFeature);
+    Preconditions.checkNotNull(scoresPerLabel);
+
+    Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
+    for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+    }
+    
+    // perLabelThetaNormalizer is only used by the complementary model; we do not instantiate it for the standard model
+    Vector perLabelThetaNormalizer = null;
+    if (isComplementary) {
+      perLabelThetaNormalizer = scoresPerLabel.like();
+      for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+          new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+        if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+          perLabelThetaNormalizer = entry.getSecond().get();
+        }
+      }
+      Preconditions.checkNotNull(perLabelThetaNormalizer);
+    }
+     
+    return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
+        alphaI, isComplementary);
+  }
+
+  /** Write the list of labels into a map file */
+  public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+    throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    int i = 0;
+    try (SequenceFile.Writer writer =
+           SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath),
+             SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))) {
+      for (String label : labels) {
+        writer.append(new Text(label), new IntWritable(i++));
+      }
+    }
+    return i;
+  }
+
+  public static int writeLabelIndex(Configuration conf, Path indexPath,
+                                    Iterable<Pair<Text,IntWritable>> labels) throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    Collection<String> seen = new HashSet<>();
+    int i = 0;
+    try (SequenceFile.Writer writer =
+           SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath),
+             SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))){
+      for (Pair<Text,IntWritable> label : labels) {
+        String theLabel = SLASH.split(label.getFirst().toString())[1];
+        if (!seen.contains(theLabel)) {
+          writer.append(new Text(theLabel), new IntWritable(i++));
+          seen.add(theLabel);
+        }
+      }
+    }
+    return i;
+  }
+
+  public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) {
+    Map<Integer, String> labelMap = new HashMap<>();
+    for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) {
+      labelMap.put(pair.getSecond().get(), pair.getFirst().toString());
+    }
+    return labelMap;
+  }
+
+  public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+    OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>();
+    for (Pair<Writable,IntWritable> entry
+        : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) {
+      index.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return index;
+  }
+
+  public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException {
+    Map<String,Vector> sumVectors = new HashMap<>();
+    for (Pair<Text,VectorWritable> entry
+        : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf),
+          PathType.LIST, PathFilters.partFilter(), conf)) {
+      sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return sumVectors;
+  }
+
+}
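
A small round-trip sketch of the label-index helpers; the path and label names are placeholders:

    import java.util.Arrays;
    import java.util.Map;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.classifier.naivebayes.BayesUtils;

    public class LabelIndexExample {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path indexPath = new Path("/tmp/labelindex");  // placeholder location
        // write "alt.atheism" -> 0, "sci.space" -> 1 into a sequence file
        int n = BayesUtils.writeLabelIndex(conf,
            Arrays.asList("alt.atheism", "sci.space"), indexPath);
        // read it back as position -> label
        Map<Integer, String> labels = BayesUtils.readLabelIndex(conf, indexPath);
        System.out.println(n + " labels: " + labels);
      }
    }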

http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
new file mode 100644
index 0000000..18bd3d6
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/** Implementation of the Naive Bayes Classifier Algorithm */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+  public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    NaiveBayesModel model = getModel();
+    double weight = computeWeight(model.featureWeight(feature), model.weight(label, feature),
+        model.totalWeightSum(), model.labelWeight(label), model.alphaI(), model.numFeatures());
+    // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+    return weight / model.thetaNormalizer(label);
+  }
+
+  // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias
+  public static double computeWeight(double featureWeight, double featureLabelWeight,
+      double totalWeight, double labelWeight, double alphaI, double numFeatures) {
+    double numerator = featureWeight - featureLabelWeight + alphaI;
+    double denominator = totalWeight - labelWeight + alphaI * numFeatures;
+    return -Math.log(numerator / denominator);
+  }
+}
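
Because computeWeight is a pure static function, its behavior is easy to check in isolation; the inputs below are arbitrary:

    // numerator   = featureWeight - featureLabelWeight + alphaI      = 100 - 40 + 1 = 61
    // denominator = totalWeight - labelWeight + alphaI * numFeatures = 1000 - 300 + 50 = 750
    // weight      = -log(61 / 750) ~= 2.509
    double w = ComplementaryNaiveBayesClassifier.computeWeight(100, 40, 1000, 300, 1.0, 50);
    System.out.println(w); // prints roughly 2.509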