You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:41 UTC
[13/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
new file mode 100644
index 0000000..56b1a04
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.TreeSet;
+
+/**
+ * <p>Optimized implementation of IgSplit.
+ * This class can be used when the criterion variable is the categorical attribute.</p>
+ *
+ * <p>This code was changed in MAHOUT-1419 to deal in sampled splits among numeric
+ * features to fix a performance problem. To generate some synthetic data that exercises
+ * the issue, try for example generating 4 features of Uniform(0,1) values with a random
+ * boolean 0/1 categorical feature. In Scala:</p>
+ *
+ * {@code
+ * val r = new scala.util.Random()
+ * val pw = new java.io.PrintWriter("random.csv")
+ * (1 to 10000000).foreach(e =>
+ *   pw.println(r.nextDouble() + "," +
+ *              r.nextDouble() + "," +
+ *              r.nextDouble() + "," +
+ *              r.nextDouble() + "," +
+ *              (if (r.nextBoolean()) 1 else 0))
+ *   )
+ * pw.close()
+ * }
+ */
+@Deprecated
+public class OptIgSplit extends IgSplit {
+
+  /**
+   * Maximum number of candidate thresholds evaluated for a numerical attribute. Above
+   * MAX_NUMERIC_SPLITS + 1 values, the candidates are approximate percentiles of the data.
+   */
+  private static final int MAX_NUMERIC_SPLITS = 16;
+
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      return numericalSplit(data, attr);
+    } else {
+      return categoricalSplit(data, attr);
+    }
+  }
+
+  /**
+   * Computes the split for a CATEGORICAL attribute: the information gain H(Y) - H(Y|X)
+   * obtained by partitioning the instances on each distinct value of {@code attr}.
+   *
+   * @return a {@link Split} holding the attribute and its gain (no split value)
+   */
+  private static Split categoricalSplit(Data data, int attr) {
+    // chooseCategoricalSplitPoints only reads the array, so no defensive copy is needed
+    double[] splitPoints = chooseCategoricalSplitPoints(data.values(attr));
+
+    int numLabels = data.getDataset().nblabels();
+    int[][] counts = new int[splitPoints.length][numLabels];
+    int[] countAll = new int[numLabels];
+
+    computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+    int size = data.size();
+    double hy = entropy(countAll, size); // H(Y)
+    double hyx = 0.0; // H(Y|X)
+    double invDataSize = 1.0 / size;
+
+    for (int index = 0; index < splitPoints.length; index++) {
+      size = DataUtils.sum(counts[index]);
+      hyx += size * invDataSize * entropy(counts[index], size);
+    }
+
+    double ig = hy - hyx;
+    return new Split(attr, ig);
+  }
+
+  /**
+   * Counts, per label, how many instances fall into each bucket delimited by {@code splitPoints}.
+   * An instance is put in bucket {@code i}, the first index with {@code value <= splitPoints[i]};
+   * instances greater than every split point are only counted in {@code countAll}.
+   *
+   * @param data        instances to count
+   * @param attr        attribute whose value selects the bucket
+   * @param splitPoints bucket upper bounds; the linear scan assumes they are in ascending order
+   * @param counts      output, {@code counts[bucket][label]}; expected zero-initialized
+   * @param countAll    output, per-label totals over all instances; expected zero-initialized
+   */
+  static void computeFrequencies(Data data,
+                                 int attr,
+                                 double[] splitPoints,
+                                 int[][] counts,
+                                 int[] countAll) {
+    Dataset dataset = data.getDataset();
+
+    for (int index = 0; index < data.size(); index++) {
+      Instance instance = data.get(index);
+      int label = (int) dataset.getLabel(instance);
+      double value = instance.get(attr);
+      int split = 0;
+      while (split < splitPoints.length && value > splitPoints[split]) {
+        split++;
+      }
+      if (split < splitPoints.length) {
+        counts[split][label]++;
+      } // Otherwise it's in the last split, which we don't need to count
+      countAll[label]++;
+    }
+  }
+
+  /**
+   * Computes the best split for a NUMERICAL attribute: every candidate threshold from
+   * {@link #chooseNumericSplitPoints(double[])} is evaluated and the one with the highest
+   * information gain is kept (the first one on ties).
+   *
+   * @throws IllegalStateException if there is no candidate threshold (e.g. {@code data} is empty)
+   */
+  static Split numericalSplit(Data data, int attr) {
+    double[] values = data.values(attr).clone(); // copied because it is sorted in place
+    Arrays.sort(values);
+
+    double[] splitPoints = chooseNumericSplitPoints(values);
+
+    int numLabels = data.getDataset().nblabels();
+    int[][] counts = new int[splitPoints.length][numLabels];
+    int[] countAll = new int[numLabels];
+    int[] countLess = new int[numLabels];
+
+    computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+    int size = data.size();
+    double hy = entropy(countAll, size);
+    double invDataSize = 1.0 / size;
+
+    int best = -1;
+    double bestIg = -1.0;
+
+    // try each possible split value; countLess/countAll are updated incrementally
+    for (int index = 0; index < splitPoints.length; index++) {
+      double ig = hy;
+
+      DataUtils.add(countLess, counts[index]);
+      DataUtils.dec(countAll, counts[index]);
+
+      // instances with attribute value <= splitPoints[index]
+      size = DataUtils.sum(countLess);
+      ig -= size * invDataSize * entropy(countLess, size);
+      // instances with attribute value > splitPoints[index]
+      size = DataUtils.sum(countAll);
+      ig -= size * invDataSize * entropy(countAll, size);
+
+      if (ig > bestIg) {
+        bestIg = ig;
+        best = index;
+      }
+    }
+
+    if (best == -1) {
+      throw new IllegalStateException("no best split found !");
+    }
+    return new Split(attr, bestIg, splitPoints[best]);
+  }
+
+  /**
+   * @param values the attribute's values, sorted in ascending order
+   * @return an array of values to split the numeric feature's values on when
+   * building candidate splits. When input size is <= MAX_NUMERIC_SPLITS + 1, it will
+   * return the averages between successive distinct values as split points (or the single
+   * common value when all values are equal). When larger, it will return MAX_NUMERIC_SPLITS
+   * approximate percentiles through the data.
+   */
+  private static double[] chooseNumericSplitPoints(double[] values) {
+    if (values.length <= 1) {
+      return values;
+    }
+    if (values.length <= MAX_NUMERIC_SPLITS + 1) {
+      double[] splitPoints = new double[values.length - 1];
+      int count = 0;
+      for (int i = 1; i < values.length; i++) {
+        // The midpoint of two equal values is the value itself: it does not separate anything
+        // and only duplicates a neighbouring candidate's partition, yet being first it could win
+        // the tie-break. If callers apply the threshold as "value < split" (TODO confirm against
+        // the tree builder) such a threshold would leave the low side empty.
+        if (values[i] != values[i - 1]) {
+          splitPoints[count++] = (values[i] + values[i - 1]) / 2.0;
+        }
+      }
+      if (count == 0) {
+        // all values are identical: keep one candidate so a (zero-gain) split is still returned
+        return new double[] {values[0]};
+      }
+      return Arrays.copyOf(splitPoints, count);
+    }
+    // NOTE(review): percentile thresholds may coincide with data values; computeFrequencies
+    // counts such values on the low side — confirm this matches how the split is applied.
+    Percentile distribution = new Percentile();
+    distribution.setData(values);
+    double[] percentiles = new double[MAX_NUMERIC_SPLITS];
+    for (int i = 0; i < percentiles.length; i++) {
+      double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
+      percentiles[i] = distribution.evaluate(p);
+    }
+    return percentiles;
+  }
+
+  /**
+   * @return the distinct values of a categorical attribute in ascending order, used as bucket
+   * bounds so that each distinct value gets its own bucket in computeFrequencies
+   */
+  private static double[] chooseCategoricalSplitPoints(double[] values) {
+    // There is no great reason to believe that categorical value order matters,
+    // but the original code worked this way, and it's not terrible in the absence
+    // of more sophisticated analysis
+    Collection<Double> uniqueOrderedCategories = new TreeSet<>();
+    for (double v : values) {
+      uniqueOrderedCategories.add(v);
+    }
+    double[] uniqueValues = new double[uniqueOrderedCategories.size()];
+    Iterator<Double> it = uniqueOrderedCategories.iterator();
+    for (int i = 0; i < uniqueValues.length; i++) {
+      uniqueValues[i] = it.next();
+    }
+    return uniqueValues;
+  }
+
+  /**
+   * Computes the entropy of a label distribution. The natural-log entropy is divided by LOG2,
+   * which is not defined in this file (presumably Math.log(2) from IgSplit, giving bits).
+   *
+   * @param counts counts[i] = numInstances with label i
+   * @param dataSize numInstances
+   * @return the entropy, or 0 when {@code dataSize} is 0
+   */
+  private static double entropy(int[] counts, int dataSize) {
+    if (dataSize == 0) {
+      return 0.0;
+    }
+
+    double entropy = 0.0;
+
+    for (int count : counts) {
+      if (count > 0) {
+        double p = count / (double) dataSize;
+        entropy -= p * Math.log(p);
+      }
+    }
+
+    return entropy / LOG2;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
new file mode 100644
index 0000000..38695a3
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
@@ -0,0 +1,177 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Comparator;
+
+/**
+ * Regression problem implementation of IgSplit. This class can be used when the criterion variable is the numerical
+ * attribute.
+ *
+ * <p>The reported "information gain" is a variance reduction: the labels' sum of squared deviations over all
+ * instances minus the sum of the per-partition sums of squared deviations.</p>
+ */
+@Deprecated
+public class RegressionSplit extends IgSplit {
+
+  /**
+   * Orders instances by the value of a single attribute.
+   */
+  private static class InstanceComparator implements Comparator<Instance>, Serializable {
+    private final int attr;
+
+    InstanceComparator(int attr) {
+      this.attr = attr;
+    }
+
+    @Override
+    public int compare(Instance arg0, Instance arg1) {
+      return Double.compare(arg0.get(attr), arg1.get(attr));
+    }
+  }
+
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      return numericalSplit(data, attr);
+    } else {
+      return categoricalSplit(data, attr);
+    }
+  }
+
+  /**
+   * Computes the split for a CATEGORICAL attribute: the reduction of the labels' sum of squared deviations
+   * obtained by grouping the instances on the value of {@code attr}.
+   */
+  private static Split categoricalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[data.getDataset().nbValues(attr)];
+    double[] sk = new double[data.getDataset().nbValues(attr)];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+    FullRunningAverage totalRa = new FullRunningAverage();
+    double totalSk = 0.0;
+
+    for (int i = 0; i < data.size(); i++) {
+      // Welford update of the per-category sum of squared deviations. The attribute value is used
+      // directly as an index (presumably categories are encoded as 0..nbValues-1).
+      Instance instance = data.get(i);
+      int value = (int) instance.get(attr);
+      double xk = data.getDataset().getLabel(instance);
+      if (ra[value].getCount() == 0) {
+        ra[value].addDatum(xk);
+        sk[value] = 0.0;
+      } else {
+        double mk = ra[value].getAverage();
+        ra[value].addDatum(xk);
+        sk[value] += (xk - mk) * (xk - ra[value].getAverage());
+      }
+
+      // same update over all instances
+      if (i == 0) {
+        totalRa.addDatum(xk);
+        totalSk = 0.0;
+      } else {
+        double mk = totalRa.getAverage();
+        totalRa.addDatum(xk);
+        totalSk += (xk - mk) * (xk - totalRa.getAverage());
+      }
+    }
+
+    // computes the variance gain
+    double ig = totalSk;
+    for (double aSk : sk) {
+      ig -= aSk;
+    }
+
+    return new Split(attr, ig);
+  }
+
+  /**
+   * Computes the best split for a NUMERICAL attribute. Instances are sorted on {@code attr} and moved one by one
+   * from the right partition (index 1) to the left one (index 0); between each pair of distinct attribute values
+   * the candidate minimizing {@code sk[0]/n0 + sk[1]/n1} is kept, with the midpoint as split value.
+   *
+   * @return the best split; when no threshold exists (fewer than two distinct values) the gain is 0 and the split
+   *         value is NaN
+   */
+  private static Split numericalSplit(Data data, int attr) {
+    FullRunningAverage[] ra = new FullRunningAverage[2];
+    for (int i = 0; i < ra.length; i++) {
+      ra[i] = new FullRunningAverage();
+    }
+
+    // Instance sort
+    Instance[] instances = new Instance[data.size()];
+    for (int i = 0; i < data.size(); i++) {
+      instances[i] = data.get(i);
+    }
+    Arrays.sort(instances, new InstanceComparator(attr));
+
+    // initially every instance is in the right partition: sk[1] becomes the total sum of squared deviations
+    double[] sk = new double[2];
+    for (Instance instance : instances) {
+      double xk = data.getDataset().getLabel(instance);
+      if (ra[1].getCount() == 0) {
+        ra[1].addDatum(xk);
+        sk[1] = 0.0;
+      } else {
+        double mk = ra[1].getAverage();
+        ra[1].addDatum(xk);
+        sk[1] += (xk - mk) * (xk - ra[1].getAverage());
+      }
+    }
+    double totalSk = sk[1];
+
+    // find the best split point. bestSk starts at totalSk so that, when no split point exists
+    // (all attribute values equal), the reported gain is 0 rather than the full totalSk.
+    double split = Double.NaN;
+    double preSplit = Double.NaN;
+    double bestVal = Double.MAX_VALUE;
+    double bestSk = totalSk;
+
+    // sweep the sorted instances, evaluating a split before each new distinct attribute value
+    for (Instance instance : instances) {
+      double xk = data.getDataset().getLabel(instance);
+
+      // false on the first instance (preSplit is NaN) and on repeated attribute values
+      if (instance.get(attr) > preSplit) {
+        double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount();
+        if (curVal < bestVal) {
+          bestVal = curVal;
+          bestSk = sk[0] + sk[1];
+          split = (instance.get(attr) + preSplit) / 2.0;
+        }
+      }
+
+      // computes the variance: move the instance to the left partition
+      if (ra[0].getCount() == 0) {
+        ra[0].addDatum(xk);
+        sk[0] = 0.0;
+      } else {
+        double mk = ra[0].getAverage();
+        ra[0].addDatum(xk);
+        sk[0] += (xk - mk) * (xk - ra[0].getAverage());
+      }
+
+      double mk = ra[1].getAverage();
+      ra[1].removeDatum(xk);
+      sk[1] -= (xk - mk) * (xk - ra[1].getAverage());
+
+      preSplit = instance.get(attr);
+    }
+
+    // computes the variance gain
+    double ig = totalSk - bestSk;
+
+    return new Split(attr, ig, split);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
new file mode 100644
index 0000000..2a6a322
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import java.util.Locale;
+
+/**
+ * Contains enough information to identify each split
+ */
+@Deprecated
+public final class Split {
+
+ private final int attr;
+ private final double ig;
+ private final double split;
+
+ public Split(int attr, double ig, double split) {
+ this.attr = attr;
+ this.ig = ig;
+ this.split = split;
+ }
+
+ public Split(int attr, double ig) {
+ this(attr, ig, Double.NaN);
+ }
+
+ /**
+ * @return attribute to split for
+ */
+ public int getAttr() {
+ return attr;
+ }
+
+ /**
+ * @return Information Gain of the split
+ */
+ public double getIg() {
+ return ig;
+ }
+
+ /**
+ * @return split value for NUMERICAL attributes
+ */
+ public double getSplit() {
+ return split;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "attr: %d, ig: %f, split: %f", attr, ig, split);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
new file mode 100644
index 0000000..f29faed
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
@@ -0,0 +1,166 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.DescriptorUtils;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Command-line tool that generates the description of a dataset: it turns the attribute descriptor into a
+ * descriptor string, loads the data at the given path into a {@link Dataset} and stores that dataset as JSON
+ * in a file that must not exist yet.
+ */
+public final class Describe implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(Describe.class);
+
+  private Configuration conf;
+
+  private Describe() {}
+
+  /**
+   * Runs the tool through {@link ToolRunner}.
+   *
+   * <p>NOTE(review): this returns {@code int} instead of {@code void}, so it is not a standard JVM entry
+   * point; it is presumably invoked reflectively by a driver. Confirm before changing the signature.</p>
+   *
+   * @return the status returned by {@link #run(String[])}
+   */
+  public static int main(String[] args) throws Exception {
+    return ToolRunner.run(new Describe(), args);
+  }
+
+  /**
+   * Parses the command line and generates the descriptor file.
+   *
+   * @return 0 on success, -1 if help was requested or the command line could not be parsed
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option pathOpt = obuilder.withLongName("path").withShortName("p").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+    Option descriptorOpt = obuilder.withLongName("descriptor").withShortName("d").withRequired(true)
+        .withArgument(abuilder.withName("descriptor").withMinimum(1).create()).withDescription(
+            "data descriptor").create();
+
+    Option descPathOpt = obuilder.withLongName("file").withShortName("f").withRequired(true).withArgument(
+        abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription(
+            "Path to generated descriptor file").create();
+
+    Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r")
+        .create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+
+    Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption(
+        descriptorOpt).withOption(regOpt).withOption(helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return -1;
+      }
+
+      String dataPath = cmdLine.getValue(pathOpt).toString();
+      String descPath = cmdLine.getValue(descPathOpt).toString();
+      List<String> descriptor = convert(cmdLine.getValues(descriptorOpt));
+      boolean regression = cmdLine.hasOption(regOpt);
+
+      log.debug("Data path : {}", dataPath);
+      log.debug("Descriptor path : {}", descPath);
+      log.debug("Descriptor : {}", descriptor);
+      log.debug("Regression : {}", regression);
+
+      runTool(dataPath, descriptor, descPath, regression);
+    } catch (OptionException e) {
+      log.warn(e.toString());
+      CommandLineUtil.printHelp(group);
+      // an invalid command line must not be reported as success
+      return -1;
+    }
+    return 0;
+  }
+
+  /**
+   * Generates the dataset description and stores it as JSON at {@code filePath}.
+   *
+   * @throws IllegalStateException if {@code filePath} already exists
+   */
+  private void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression)
+    throws DescriptorException, IOException {
+    log.info("Generating the descriptor...");
+    String descriptor = DescriptorUtils.generateDescriptor(description);
+
+    // fail fast, before scanning the data, if the output file is already there
+    Path fPath = validateOutput(filePath);
+
+    log.info("generating the dataset...");
+    Dataset dataset = generateDataset(descriptor, dataPath, regression);
+
+    log.info("storing the dataset description");
+    String json = dataset.toJSON();
+    DFUtils.storeString(conf, fPath, json);
+  }
+
+  /** Loads the data at {@code dataPath} and builds the corresponding {@link Dataset}. */
+  private Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException,
+    DescriptorException {
+    Path path = new Path(dataPath);
+    FileSystem fs = path.getFileSystem(conf);
+
+    return DataLoader.generateDataset(descriptor, regression, fs, path);
+  }
+
+  /**
+   * @return {@code filePath} as a {@link Path}
+   * @throws IllegalStateException if the file already exists
+   */
+  private Path validateOutput(String filePath) throws IOException {
+    Path path = new Path(filePath);
+    FileSystem fs = path.getFileSystem(conf);
+    if (fs.exists(path)) {
+      throw new IllegalStateException("Descriptor's file already exists");
+    }
+
+    return path;
+  }
+
+  /** Converts the parsed option values to their string form. */
+  private static List<String> convert(Collection<?> values) {
+    List<String> list = new ArrayList<>(values.size());
+    for (Object value : values) {
+      list.add(value.toString());
+    }
+    return list;
+  }
+
+  @Override
+  public void setConf(Configuration entries) {
+    this.conf = entries;
+  }
+
+  @Override
+  public Configuration getConf() {
+    return conf;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
new file mode 100644
index 0000000..b421c4e
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool is to visualize the Decision Forest
+ */
+@Deprecated
+public final class ForestVisualizer {
+
+ private static final Logger log = LoggerFactory.getLogger(ForestVisualizer.class);
+
+ private ForestVisualizer() {
+ }
+
+ public static String toString(DecisionForest forest, Dataset dataset, String[] attrNames) {
+
+ List<Node> trees;
+ try {
+ Method getTrees = forest.getClass().getDeclaredMethod("getTrees");
+ getTrees.setAccessible(true);
+ trees = (List<Node>) getTrees.invoke(forest);
+ } catch (IllegalAccessException e) {
+ throw new IllegalStateException(e);
+ } catch (InvocationTargetException e) {
+ throw new IllegalStateException(e);
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(e);
+ }
+
+ int cnt = 1;
+ StringBuilder buff = new StringBuilder();
+ for (Node tree : trees) {
+ buff.append("Tree[").append(cnt).append("]:");
+ buff.append(TreeVisualizer.toString(tree, dataset, attrNames));
+ buff.append('\n');
+ cnt++;
+ }
+ return buff.toString();
+ }
+
+ /**
+ * Decision Forest to String
+ * @param forestPath
+ * path to the Decision Forest
+ * @param datasetPath
+ * dataset path
+ * @param attrNames
+ * attribute names
+ */
+ public static String toString(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+ Configuration conf = new Configuration();
+ DecisionForest forest = DecisionForest.load(conf, new Path(forestPath));
+ Dataset dataset = Dataset.load(conf, new Path(datasetPath));
+ return toString(forest, dataset, attrNames);
+ }
+
+ /**
+ * Print Decision Forest
+ * @param forestPath
+ * path to the Decision Forest
+ * @param datasetPath
+ * dataset path
+ * @param attrNames
+ * attribute names
+ */
+ public static void print(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+ System.out.println(toString(forestPath, datasetPath, attrNames));
+ }
+
+ public static void main(String[] args) {
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
+ .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+ .withDescription("Dataset path").create();
+
+ Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true)
+ .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+ .withDescription("Path to the Decision Forest").create();
+
+ Option attrNamesOpt = obuilder.withLongName("names").withShortName("n").withRequired(false)
+ .withArgument(abuilder.withName("names").withMinimum(1).create())
+ .withDescription("Optional, Attribute names").create();
+
+ Option helpOpt = obuilder.withLongName("help").withShortName("h")
+ .withDescription("Print out help").create();
+
+ Group group = gbuilder.withName("Options").withOption(datasetOpt).withOption(modelOpt)
+ .withOption(attrNamesOpt).withOption(helpOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption("help")) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ String datasetName = cmdLine.getValue(datasetOpt).toString();
+ String modelName = cmdLine.getValue(modelOpt).toString();
+ String[] attrNames = null;
+ if (cmdLine.hasOption(attrNamesOpt)) {
+ Collection<String> names = (Collection<String>) cmdLine.getValues(attrNamesOpt);
+ if (!names.isEmpty()) {
+ attrNames = new String[names.size()];
+ names.toArray(attrNames);
+ }
+ }
+
+ print(modelName, datasetName, attrNames);
+ } catch (Exception e) {
+ log.error("Exception", e);
+ CommandLineUtil.printHelp(group);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
new file mode 100644
index 0000000..c37af4e
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Compute the frequency distribution of the "class label"<br>
+ * This class can be used when the criterion variable is the categorical attribute.
+ * The counts are computed by {@link FrequenciesJob} and logged as {@code counts[partition][class]}.
+ */
+@Deprecated
+public final class Frequencies extends Configured implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(Frequencies.class);
+
+  private Frequencies() { }
+
+  /**
+   * Parses the command line and computes the label frequencies.
+   *
+   * @return 0 on success or when help was requested, -1 if the command line could not be parsed
+   */
+  @Override
+  public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).create()).withDescription("dataset path").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+
+    Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(helpOpt)
+        .create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return 0;
+      }
+
+      String dataPath = cmdLine.getValue(dataOpt).toString();
+      String datasetPath = cmdLine.getValue(datasetOpt).toString();
+
+      log.debug("Data path : {}", dataPath);
+      log.debug("Dataset path : {}", datasetPath);
+
+      runTool(dataPath, datasetPath);
+    } catch (OptionException e) {
+      log.warn(e.toString(), e);
+      CommandLineUtil.printHelp(group);
+      // an invalid command line must not be reported as success
+      return -1;
+    }
+
+    return 0;
+  }
+
+  /**
+   * Runs {@link FrequenciesJob} with its output under {@code <working directory>/output} and logs the
+   * resulting counts, one line per partition.
+   */
+  private void runTool(String data, String dataset) throws IOException,
+                                                           ClassNotFoundException,
+                                                           InterruptedException {
+
+    FileSystem fs = FileSystem.get(getConf());
+    Path workingDir = fs.getWorkingDirectory();
+
+    Path dataPath = new Path(data);
+    Path datasetPath = new Path(dataset);
+
+    log.info("Computing the frequencies...");
+    FrequenciesJob job = new FrequenciesJob(new Path(workingDir, "output"), dataPath, datasetPath);
+
+    int[][] counts = job.run(getConf());
+
+    // outputting the frequencies
+    log.info("counts[partition][class]");
+    for (int[] count : counts) {
+      log.info(Arrays.toString(count));
+    }
+  }
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new Frequencies(), args);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
new file mode 100644
index 0000000..9d7e2ff
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
@@ -0,0 +1,297 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+
+/**
+ * Temporary class used to compute the frequency distribution of the "class attribute".<br>
+ * This class can be used when the criterion variable is the categorical attribute.
+ */
+@Deprecated
+public class FrequenciesJob {
+
+ private static final Logger log = LoggerFactory.getLogger(FrequenciesJob.class);
+
+ /** directory that will hold this job's output */
+ private final Path outputPath;
+
+ /** file that contains the serialized dataset */
+ private final Path datasetPath;
+
+ /** directory that contains the data used in the first step */
+ private final Path dataPath;
+
+ /**
+ * @param base
+ * base directory
+ * @param dataPath
+ * data used in the first step
+ */
+ public FrequenciesJob(Path base, Path dataPath, Path datasetPath) {
+ this.outputPath = new Path(base, "frequencies.output");
+ this.dataPath = dataPath;
+ this.datasetPath = datasetPath;
+ }
+
+ /**
+ * @return counts[partition][label] = num tuples from 'partition' with class == label
+ */
+ public int[][] run(Configuration conf) throws IOException, ClassNotFoundException, InterruptedException {
+
+ // check the output
+ FileSystem fs = outputPath.getFileSystem(conf);
+ if (fs.exists(outputPath)) {
+ throw new IOException("Output path already exists : " + outputPath);
+ }
+
+ // put the dataset into the DistributedCache
+ URI[] files = {datasetPath.toUri()};
+ DistributedCache.setCacheFiles(files, conf);
+
+ Job job = new Job(conf);
+ job.setJarByClass(FrequenciesJob.class);
+
+ FileInputFormat.setInputPaths(job, dataPath);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ job.setMapOutputKeyClass(LongWritable.class);
+ job.setMapOutputValueClass(IntWritable.class);
+ job.setOutputKeyClass(LongWritable.class);
+ job.setOutputValueClass(Frequencies.class);
+
+ job.setMapperClass(FrequenciesMapper.class);
+ job.setReducerClass(FrequenciesReducer.class);
+
+ job.setInputFormatClass(TextInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ // run the job
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ int[][] counts = parseOutput(job);
+
+ HadoopUtil.delete(conf, outputPath);
+
+ return counts;
+ }
+
+ /**
+ * Extracts the output and processes it
+ *
+ * @return counts[partition][label] = num tuples from 'partition' with class == label
+ */
+ int[][] parseOutput(JobContext job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ int numMaps = conf.getInt("mapred.map.tasks", -1);
+ log.info("mapred.map.tasks = {}", numMaps);
+
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+
+ Frequencies[] values = new Frequencies[numMaps];
+
+ // read all the outputs
+ int index = 0;
+ for (Path path : outfiles) {
+ for (Frequencies value : new SequenceFileValueIterable<Frequencies>(path, conf)) {
+ values[index++] = value;
+ }
+ }
+
+ if (index < numMaps) {
+ throw new IllegalStateException("number of output Frequencies (" + index
+ + ") is lesser than the number of mappers!");
+ }
+
+ // sort the frequencies using the firstIds
+ Arrays.sort(values);
+ return Frequencies.extractCounts(values);
+ }
+
+ /**
+ * Outputs the first key and the label of each tuple
+ *
+ */
+ private static class FrequenciesMapper extends Mapper<LongWritable,Text,LongWritable,IntWritable> {
+
+ private LongWritable firstId;
+
+ private DataConverter converter;
+ private Dataset dataset;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+
+ dataset = Builder.loadDataset(conf);
+ setup(dataset);
+ }
+
+ /**
+ * Useful when testing
+ */
+ void setup(Dataset dataset) {
+ converter = new DataConverter(dataset);
+ }
+
+ @Override
+ protected void map(LongWritable key, Text value, Context context) throws IOException,
+ InterruptedException {
+ if (firstId == null) {
+ firstId = new LongWritable(key.get());
+ }
+
+ Instance instance = converter.convert(value.toString());
+
+ context.write(firstId, new IntWritable((int) dataset.getLabel(instance)));
+ }
+
+ }
+
+ private static class FrequenciesReducer extends Reducer<LongWritable,IntWritable,LongWritable,Frequencies> {
+
+ private int nblabels;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ Dataset dataset = Builder.loadDataset(conf);
+ setup(dataset.nblabels());
+ }
+
+ /**
+ * Useful when testing
+ */
+ void setup(int nblabels) {
+ this.nblabels = nblabels;
+ }
+
+ @Override
+ protected void reduce(LongWritable key, Iterable<IntWritable> values, Context context)
+ throws IOException, InterruptedException {
+ int[] counts = new int[nblabels];
+ for (IntWritable value : values) {
+ counts[value.get()]++;
+ }
+
+ context.write(key, new Frequencies(key.get(), counts));
+ }
+ }
+
+ /**
+ * Output of the job
+ *
+ */
+ private static class Frequencies implements Writable, Comparable<Frequencies>, Cloneable {
+
+ /** first key of the partition used to sort the partitions */
+ private long firstId;
+
+ /** counts[c] = num tuples from the partition with label == c */
+ private int[] counts;
+
+ Frequencies() { }
+
+ Frequencies(long firstId, int[] counts) {
+ this.firstId = firstId;
+ this.counts = Arrays.copyOf(counts, counts.length);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ firstId = in.readLong();
+ counts = DFUtils.readIntArray(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeLong(firstId);
+ DFUtils.writeArray(out, counts);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ return other instanceof Frequencies && firstId == ((Frequencies) other).firstId;
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) firstId;
+ }
+
+ @Override
+ protected Frequencies clone() {
+ return new Frequencies(firstId, counts);
+ }
+
+ @Override
+ public int compareTo(Frequencies obj) {
+ if (firstId < obj.firstId) {
+ return -1;
+ } else if (firstId > obj.firstId) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ public static int[][] extractCounts(Frequencies[] partitions) {
+ int[][] counts = new int[partitions.length][];
+ for (int p = 0; p < partitions.length; p++) {
+ counts[p] = partitions[p].counts;
+ }
+ return counts;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
new file mode 100644
index 0000000..a2a3458
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
@@ -0,0 +1,264 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.lang.reflect.Field;
+import java.text.DecimalFormat;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+
+/**
+ * This tool is to visualize the Decision tree
+ */
+@Deprecated
+public final class TreeVisualizer {
+
+ private TreeVisualizer() {}
+
+ private static String doubleToString(double value) {
+ DecimalFormat df = new DecimalFormat("0.##");
+ return df.format(value);
+ }
+
+ private static String toStringNode(Node node, Dataset dataset,
+ String[] attrNames, Map<String,Field> fields, int layer) {
+
+ StringBuilder buff = new StringBuilder();
+
+ try {
+ if (node instanceof CategoricalNode) {
+ CategoricalNode cnode = (CategoricalNode) node;
+ int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+ double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
+ Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
+ String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
+ for (int i = 0; i < attrValues[attr].length; i++) {
+ int index = ArrayUtils.indexOf(values, i);
+ if (index < 0) {
+ continue;
+ }
+ buff.append('\n');
+ for (int j = 0; j < layer; j++) {
+ buff.append("| ");
+ }
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+ .append(attrValues[attr][i]);
+ buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1));
+ }
+ } else if (node instanceof NumericalNode) {
+ NumericalNode nnode = (NumericalNode) node;
+ int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+ double split = (Double) fields.get("NumericalNode.split").get(nnode);
+ Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+ Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+ buff.append('\n');
+ for (int j = 0; j < layer; j++) {
+ buff.append("| ");
+ }
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ")
+ .append(doubleToString(split));
+ buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
+ buff.append('\n');
+ for (int j = 0; j < layer; j++) {
+ buff.append("| ");
+ }
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ")
+ .append(doubleToString(split));
+ buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
+ } else if (node instanceof Leaf) {
+ Leaf leaf = (Leaf) node;
+ double label = (Double) fields.get("Leaf.label").get(leaf);
+ if (dataset.isNumerical(dataset.getLabelId())) {
+ buff.append(" : ").append(doubleToString(label));
+ } else {
+ buff.append(" : ").append(dataset.getLabelString(label));
+ }
+ }
+ } catch (IllegalAccessException iae) {
+ throw new IllegalStateException(iae);
+ }
+
+ return buff.toString();
+ }
+
+ private static Map<String,Field> getReflectMap() {
+ Map<String,Field> fields = new HashMap<>();
+
+ try {
+ Field m = CategoricalNode.class.getDeclaredField("attr");
+ m.setAccessible(true);
+ fields.put("CategoricalNode.attr", m);
+ m = CategoricalNode.class.getDeclaredField("values");
+ m.setAccessible(true);
+ fields.put("CategoricalNode.values", m);
+ m = CategoricalNode.class.getDeclaredField("childs");
+ m.setAccessible(true);
+ fields.put("CategoricalNode.childs", m);
+ m = NumericalNode.class.getDeclaredField("attr");
+ m.setAccessible(true);
+ fields.put("NumericalNode.attr", m);
+ m = NumericalNode.class.getDeclaredField("split");
+ m.setAccessible(true);
+ fields.put("NumericalNode.split", m);
+ m = NumericalNode.class.getDeclaredField("loChild");
+ m.setAccessible(true);
+ fields.put("NumericalNode.loChild", m);
+ m = NumericalNode.class.getDeclaredField("hiChild");
+ m.setAccessible(true);
+ fields.put("NumericalNode.hiChild", m);
+ m = Leaf.class.getDeclaredField("label");
+ m.setAccessible(true);
+ fields.put("Leaf.label", m);
+ m = Dataset.class.getDeclaredField("values");
+ m.setAccessible(true);
+ fields.put("Dataset.values", m);
+ } catch (NoSuchFieldException nsfe) {
+ throw new IllegalStateException(nsfe);
+ }
+
+ return fields;
+ }
+
+ /**
+ * Decision tree to String
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static String toString(Node tree, Dataset dataset, String[] attrNames) {
+ return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
+ }
+
+ /**
+ * Print Decision tree
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static void print(Node tree, Dataset dataset, String[] attrNames) {
+ System.out.println(toString(tree, dataset, attrNames));
+ }
+
+ private static String toStringPredict(Node node, Instance instance,
+ Dataset dataset, String[] attrNames, Map<String,Field> fields) {
+ StringBuilder buff = new StringBuilder();
+
+ try {
+ if (node instanceof CategoricalNode) {
+ CategoricalNode cnode = (CategoricalNode) node;
+ int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+ double[] values = (double[]) fields.get("CategoricalNode.values").get(
+ cnode);
+ Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+ .get(cnode);
+ String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+ dataset);
+
+ int index = ArrayUtils.indexOf(values, instance.get(attr));
+ if (index >= 0) {
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+ .append(attrValues[attr][(int) instance.get(attr)]);
+ buff.append(" -> ");
+ buff.append(toStringPredict(childs[index], instance, dataset,
+ attrNames, fields));
+ }
+ } else if (node instanceof NumericalNode) {
+ NumericalNode nnode = (NumericalNode) node;
+ int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+ double split = (Double) fields.get("NumericalNode.split").get(nnode);
+ Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+ Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+
+ if (instance.get(attr) < split) {
+ buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+ .append(" = ").append(doubleToString(instance.get(attr)))
+ .append(") < ").append(doubleToString(split));
+ buff.append(" -> ");
+ buff.append(toStringPredict(loChild, instance, dataset, attrNames,
+ fields));
+ } else {
+ buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+ .append(" = ").append(doubleToString(instance.get(attr)))
+ .append(") >= ").append(doubleToString(split));
+ buff.append(" -> ");
+ buff.append(toStringPredict(hiChild, instance, dataset, attrNames,
+ fields));
+ }
+ } else if (node instanceof Leaf) {
+ Leaf leaf = (Leaf) node;
+ double label = (Double) fields.get("Leaf.label").get(leaf);
+ if (dataset.isNumerical(dataset.getLabelId())) {
+ buff.append(doubleToString(label));
+ } else {
+ buff.append(dataset.getLabelString(label));
+ }
+ }
+ } catch (IllegalAccessException iae) {
+ throw new IllegalStateException(iae);
+ }
+
+ return buff.toString();
+ }
+
+ /**
+ * Predict trace to String
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static String[] predictTrace(Node tree, Data data, String[] attrNames) {
+ Map<String,Field> reflectMap = getReflectMap();
+ String[] prediction = new String[data.size()];
+ for (int i = 0; i < data.size(); i++) {
+ prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(),
+ attrNames, reflectMap);
+ }
+ return prediction;
+ }
+
+ /**
+ * Print predict trace
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static void predictTracePrint(Node tree, Data data, String[] attrNames) {
+ Map<String,Field> reflectMap = getReflectMap();
+ for (int i = 0; i < data.size(); i++) {
+ System.out.println(toStringPredict(tree, data.get(i), data.getDataset(),
+ attrNames, reflectMap));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
new file mode 100644
index 0000000..e1b55ab
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
@@ -0,0 +1,212 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Random;
+import java.util.Scanner;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool is used to uniformly distribute the class of all the tuples of the dataset over a given number of
+ * partitions.<br>
+ * This class can be used when the criterion variable is the categorical attribute.
+ */
+@Deprecated
+public final class UDistrib {
+
+ private static final Logger log = LoggerFactory.getLogger(UDistrib.class);
+
+ private UDistrib() {}
+
+ /**
+ * Launch the uniform distribution tool. Requires the following command line arguments:<br>
+ *
+ * data : data path dataset : dataset path numpartitions : num partitions output : output path
+ *
+ * @throws java.io.IOException
+ */
+ public static void main(String[] args) throws IOException {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+ abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+ abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset path").create();
+
+ Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Path to generated files").create();
+
+ Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true)
+ .withArgument(abuilder.withName("numparts").withMinimum(1).withMinimum(1).create()).withDescription(
+ "Number of partitions to create").create();
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption(
+ datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ String data = cmdLine.getValue(dataOpt).toString();
+ String dataset = cmdLine.getValue(datasetOpt).toString();
+ int numPartitions = Integer.parseInt(cmdLine.getValue(partitionsOpt).toString());
+ String output = cmdLine.getValue(outputOpt).toString();
+
+ runTool(data, dataset, output, numPartitions);
+ } catch (OptionException e) {
+ log.warn(e.toString(), e);
+ CommandLineUtil.printHelp(group);
+ }
+
+ }
+
+ private static void runTool(String dataStr, String datasetStr, String output, int numPartitions) throws IOException {
+
+ Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0");
+
+ // make sure the output file does not exist
+ Path outputPath = new Path(output);
+ Configuration conf = new Configuration();
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ Preconditions.checkArgument(!fs.exists(outputPath), "Output path already exists");
+
+ // create a new file corresponding to each partition
+ // Path workingDir = fs.getWorkingDirectory();
+ // FileSystem wfs = workingDir.getFileSystem(conf);
+ // File parentFile = new File(workingDir.toString());
+ // File tempFile = FileUtil.createLocalTempFile(parentFile, "Parts", true);
+ // File tempFile = File.createTempFile("df.tools.UDistrib","");
+ // tempFile.deleteOnExit();
+ File tempFile = FileUtil.createLocalTempFile(new File(""), "df.tools.UDistrib", true);
+ Path partsPath = new Path(tempFile.toString());
+ FileSystem pfs = partsPath.getFileSystem(conf);
+
+ Path[] partPaths = new Path[numPartitions];
+ FSDataOutputStream[] files = new FSDataOutputStream[numPartitions];
+ for (int p = 0; p < numPartitions; p++) {
+ partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, "part.%03d", p));
+ files[p] = pfs.create(partPaths[p]);
+ }
+
+ Path datasetPath = new Path(datasetStr);
+ Dataset dataset = Dataset.load(conf, datasetPath);
+
+ // currents[label] = next partition file where to place the tuple
+ int[] currents = new int[dataset.nblabels()];
+
+ // currents is initialized randomly in the range [0, numpartitions[
+ Random random = RandomUtils.getRandom();
+ for (int c = 0; c < currents.length; c++) {
+ currents[c] = random.nextInt(numPartitions);
+ }
+
+ // foreach tuple of the data
+ Path dataPath = new Path(dataStr);
+ FileSystem ifs = dataPath.getFileSystem(conf);
+ FSDataInputStream input = ifs.open(dataPath);
+ Scanner scanner = new Scanner(input, "UTF-8");
+ DataConverter converter = new DataConverter(dataset);
+
+ int id = 0;
+ while (scanner.hasNextLine()) {
+ if (id % 1000 == 0) {
+ log.info("progress : {}", id);
+ }
+
+ String line = scanner.nextLine();
+ if (line.isEmpty()) {
+ continue; // skip empty lines
+ }
+
+ // write the tuple in files[tuple.label]
+ Instance instance = converter.convert(line);
+ int label = (int) dataset.getLabel(instance);
+ files[currents[label]].writeBytes(line);
+ files[currents[label]].writeChar('\n');
+
+ // update currents
+ currents[label]++;
+ if (currents[label] == numPartitions) {
+ currents[label] = 0;
+ }
+ }
+
+ // close all the files.
+ scanner.close();
+ for (FSDataOutputStream file : files) {
+ Closeables.close(file, false);
+ }
+
+ // merge all output files
+ FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null);
+ /*
+ * FSDataOutputStream joined = fs.create(new Path(outputPath, "uniform.data")); for (int p = 0; p <
+ * numPartitions; p++) {log.info("Joining part : {}", p); FSDataInputStream partStream =
+ * fs.open(partPaths[p]);
+ *
+ * IOUtils.copyBytes(partStream, joined, conf, false);
+ *
+ * partStream.close(); }
+ *
+ * joined.close();
+ *
+ * fs.delete(partsPath, true);
+ */
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
new file mode 100644
index 0000000..049f9bf
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.evaluation;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+import com.google.common.base.Preconditions;
+
+import java.util.Random;
+
+/**
+ * Computes AUC and a few other accuracy statistics without storing huge amounts of data. This is
+ * done by keeping uniform samples of the positive and negative scores. Then, when AUC is to be
+ * computed, the remaining scores are sorted and a rank-sum statistic is used to compute the AUC.
+ * Since AUC is invariant with respect to down-sampling of either positives or negatives, this is
+ * close to correct and is exactly correct if maxBufferSize or fewer positive and negative scores
+ * are examined.
+ */
+public class Auc {
+
+ private int maxBufferSize = 10000;
+ private final DoubleArrayList[] scores = {new DoubleArrayList(), new DoubleArrayList()};
+ private final Random rand;
+ private int samples;
+ private final double threshold;
+ private final Matrix confusion;
+ private final DenseMatrix entropy;
+
+ private boolean probabilityScore = true;
+
+ private boolean hasScore;
+
+ /**
+ * Allocates a new data-structure for accumulating information about AUC and a few other accuracy
+ * measures.
+ * @param threshold The threshold to use in computing the confusion matrix.
+ */
+ public Auc(double threshold) {
+ confusion = new DenseMatrix(2, 2);
+ entropy = new DenseMatrix(2, 2);
+ this.rand = RandomUtils.getRandom();
+ this.threshold = threshold;
+ }
+
+ public Auc() {
+ this(0.5);
+ }
+
+ /**
+ * Adds a score to the AUC buffers.
+ *
+ * @param trueValue Whether this score is for a true-positive or a true-negative example.
+ * @param score The score for this example.
+ */
+ public void add(int trueValue, double score) {
+ Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+ hasScore = true;
+
+ int predictedClass = score > threshold ? 1 : 0;
+ confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+
+ samples++;
+ if (isProbabilityScore()) {
+ double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20));
+ double v0 = entropy.get(trueValue, 0);
+ entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0);
+
+ double v1 = entropy.get(trueValue, 1);
+ entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1);
+ }
+
+ // add to buffers
+ DoubleArrayList buf = scores[trueValue];
+ if (buf.size() >= maxBufferSize) {
+ // but if too many points are seen, we insert into a random
+ // place and discard the predecessor. The random place could
+ // be anywhere, possibly not even in the buffer.
+ // this is a special case of Knuth's permutation algorithm
+ // but since we don't ever shuffle the first maxBufferSize
+ // samples, the result isn't just a fair sample of the prefixes
+ // of all permutations. The CONTENTs of the result, however,
+ // will be a fair and uniform sample of maxBufferSize elements
+ // chosen from all elements without replacement
+ int index = rand.nextInt(samples);
+ if (index < buf.size()) {
+ buf.set(index, score);
+ }
+ } else {
+ // for small buffers, we collect all points without permuting
+ // since we sort the data later, permuting now would just be
+ // pedantic
+ buf.add(score);
+ }
+ }
+
+ public void add(int trueValue, int predictedClass) {
+ hasScore = false;
+ Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+ confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+ }
+
+ /**
+ * Computes the AUC of points seen so far. This can be moderately expensive since it requires
+ * that all points that have been retained be sorted.
+ *
+ * @return The value of the Area Under the receiver operating Curve.
+ */
+ public double auc() {
+ Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier without a score");
+ scores[0].sort();
+ scores[1].sort();
+
+ double n0 = scores[0].size();
+ double n1 = scores[1].size();
+
+ if (n0 == 0 || n1 == 0) {
+ return 0.5;
+ }
+
+ // scan the data
+ int i0 = 0;
+ int i1 = 0;
+ int rank = 1;
+ double rankSum = 0;
+ while (i0 < n0 && i1 < n1) {
+
+ double v0 = scores[0].get(i0);
+ double v1 = scores[1].get(i1);
+
+ if (v0 < v1) {
+ i0++;
+ rank++;
+ } else if (v1 < v0) {
+ i1++;
+ rankSum += rank;
+ rank++;
+ } else {
+ // ties have to be handled delicately
+ double tieScore = v0;
+
+ // how many negatives are tied?
+ int k0 = 0;
+ while (i0 < n0 && scores[0].get(i0) == tieScore) {
+ k0++;
+ i0++;
+ }
+
+ // and how many positives
+ int k1 = 0;
+ while (i1 < n1 && scores[1].get(i1) == tieScore) {
+ k1++;
+ i1++;
+ }
+
+ // we found k0 + k1 tied values which have
+ // ranks in the half open interval [rank, rank + k0 + k1)
+ // the average rank is assigned to all
+ rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1;
+ rank += k0 + k1;
+ }
+ }
+
+ if (i1 < n1) {
+ rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1);
+ rank += (int) (n1 - i1);
+ }
+
+ return (rankSum / n1 - (n1 + 1) / 2) / n0;
+ }
+
+ /**
+ * Returns the confusion matrix for the classifier supposing that we were to use a particular
+ * threshold.
+ * @return The confusion matrix.
+ */
+ public Matrix confusion() {
+ return confusion;
+ }
+
+ /**
+ * Returns a matrix related to the confusion matrix and to the log-likelihood. For a
+ * pretty accurate classifier, N + entropy is nearly the same as the confusion matrix
+ * because log(1-eps) \approx -eps if eps is small.
+ *
+ * For lower accuracy classifiers, this measure will give us a better picture of how
+ * things work our.
+ *
+ * Also, by definition, log-likelihood = sum(diag(entropy))
+ * @return Returns a cell by cell break-down of the log-likelihood
+ */
+ public Matrix entropy() {
+ if (!hasScore) {
+ // find a constant score that would optimize log-likelihood, but use a dash of Bayesian
+ // conservatism to avoid dividing by zero or taking log(0)
+ double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + confusion.get(1, 1));
+ entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p));
+ entropy.set(0, 1, confusion.get(0, 1) * Math.log(p));
+ entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p));
+ entropy.set(1, 1, confusion.get(1, 1) * Math.log(p));
+ }
+ return entropy;
+ }
+
+ public void setMaxBufferSize(int maxBufferSize) {
+ this.maxBufferSize = maxBufferSize;
+ }
+
+ public boolean isProbabilityScore() {
+ return probabilityScore;
+ }
+
+ public void setProbabilityScore(boolean probabilityScore) {
+ this.probabilityScore = probabilityScore;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
new file mode 100644
index 0000000..f0794b3
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Class implementing the Naive Bayes Classifier Algorithm. Note that this class
+ * supports {@link #classifyFull}, but not {@code classify} or
+ * {@code classifyScalar}. The reason that these two methods are not
+ * supported is because the scores computed by a NaiveBayesClassifier do not
+ * represent probabilities.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+
+ private final NaiveBayesModel model;
+
+ protected AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+ this.model = model;
+ }
+
+ protected NaiveBayesModel getModel() {
+ return model;
+ }
+
+ protected abstract double getScoreForLabelFeature(int label, int feature);
+
+ protected double getScoreForLabelInstance(int label, Vector instance) {
+ double result = 0.0;
+ for (Element e : instance.nonZeroes()) {
+ result += e.get() * getScoreForLabelFeature(label, e.index());
+ }
+ return result;
+ }
+
+ @Override
+ public int numCategories() {
+ return model.numLabels();
+ }
+
+ @Override
+ public Vector classifyFull(Vector instance) {
+ return classifyFull(model.createScoringVector(), instance);
+ }
+
+ @Override
+ public Vector classifyFull(Vector r, Vector instance) {
+ for (int label = 0; label < model.numLabels(); label++) {
+ r.setQuick(label, getScoreForLabelInstance(label, instance));
+ }
+ return r;
+ }
+
+ /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+ @Override
+ public double classifyScalar(Vector instance) {
+ throw new UnsupportedOperationException("Not supported in Naive Bayes");
+ }
+
+ /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+ @Override
+ public Vector classify(Vector instance) {
+ throw new UnsupportedOperationException("probabilites not supported in Naive Bayes");
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
new file mode 100644
index 0000000..4db8b17
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.regex.Pattern;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+/**
+ * Static helpers for the Naive Bayes jobs. They read a trained model back from a training
+ * output directory, write and read the label index, and read side files from the Hadoop
+ * distributed cache.
+ */
+public final class BayesUtils {
+
+  /** Splits keys on {@code /}; used to pull the label token out of input keys. */
+  private static final Pattern SLASH = Pattern.compile("/");
+
+  private BayesUtils() {}
+
+  /**
+   * Reads a Naive Bayes model from a training output directory.
+   *
+   * <p>The model is assembled from the following sources:
+   * <ul>
+   *   <li>{@code conf}: {@link ThetaMapper#ALPHA_I} (default 1.0) and
+   *       {@link NaiveBayesModel#COMPLEMENTARY_MODEL} (default {@code true});</li>
+   *   <li>the {@link TrainNaiveBayesJob#WEIGHTS} directory: the per-feature and per-label
+   *       weight vectors;</li>
+   *   <li>the {@link TrainNaiveBayesJob#SUMMED_OBSERVATIONS} directory: one row per label of
+   *       the label-by-feature weight matrix;</li>
+   *   <li>the {@link TrainNaiveBayesJob#THETAS} directory, for the complementary model only:
+   *       the per-label theta normalizer.</li>
+   * </ul>
+   *
+   * @param base directory holding the training job's output
+   * @param conf configuration used for the model settings and for file access
+   * @return the assembled model
+   * @throws NullPointerException if the per-feature or per-label weight vector is missing
+   */
+  public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {
+
+    float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+    boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true);
+
+    // read feature sums and label sums
+    Vector scoresPerLabel = null;
+    Vector scoresPerFeature = null;
+    for (Pair<Text, VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      String key = record.getFirst().toString();
+      VectorWritable value = record.getSecond();
+      if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+        scoresPerFeature = value.get();
+      } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+        scoresPerLabel = value.get();
+      }
+    }
+
+    Preconditions.checkNotNull(scoresPerFeature, "no %s vector found in the weights under %s",
+        TrainNaiveBayesJob.WEIGHTS_PER_FEATURE, base);
+    Preconditions.checkNotNull(scoresPerLabel, "no %s vector found in the weights under %s",
+        TrainNaiveBayesJob.WEIGHTS_PER_LABEL, base);
+
+    Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
+    for (Pair<IntWritable, VectorWritable> entry : new SequenceFileDirIterable<IntWritable, VectorWritable>(
+        new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+      scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+    }
+
+    // perLabelThetaNormalizer is only used by the complementary model, we do not instantiate it for the standard model
+    Vector perLabelThetaNormalizer = null;
+    if (isComplementary) {
+      // NOTE(review): the normalizer starts as scoresPerLabel.like() rather than null. If THETAS
+      // holds no LABEL_THETA_NORMALIZER entry, the model therefore silently keeps that default
+      // vector, and the checkNotNull below only guards against a null read from the file.
+      // Confirm that this fallback is intended.
+      perLabelThetaNormalizer = scoresPerLabel.like();
+      for (Pair<Text, VectorWritable> entry : new SequenceFileDirIterable<Text, VectorWritable>(
+          new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+        if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+          perLabelThetaNormalizer = entry.getSecond().get();
+        }
+      }
+      Preconditions.checkNotNull(perLabelThetaNormalizer);
+    }
+
+    return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
+        alphaI, isComplementary);
+  }
+
+  /**
+   * Writes the list of labels into a {@link SequenceFile} at {@code indexPath}. Each label
+   * ({@link Text}) is mapped to its iteration position ({@link IntWritable}, starting at 0).
+   * Duplicate labels are not filtered out.
+   *
+   * @return the number of labels written
+   * @throws IOException if the file cannot be created or written
+   */
+  public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+      throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    int i = 0;
+    try (SequenceFile.Writer writer =
+        SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath),
+            SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))) {
+      for (String label : labels) {
+        writer.append(new Text(label), new IntWritable(i++));
+      }
+    }
+    return i;
+  }
+
+  /**
+   * Writes a label index to a {@link SequenceFile} at {@code indexPath}.
+   *
+   * <p>Each label is taken from the key of a pair: it is the second {@code /}-separated token of
+   * the key. For a key such as {@code /label/doc}, that token is {@code label}. Each distinct
+   * label is written once, in first-seen order, mapped to a consecutive index starting at 0.
+   * The pair values are ignored.
+   *
+   * @return the number of distinct labels written
+   * @throws IOException if the file cannot be created or written
+   * @throws IllegalArgumentException if a key has no second {@code /}-separated token
+   */
+  public static int writeLabelIndex(Configuration conf, Path indexPath,
+                                    Iterable<Pair<Text, IntWritable>> labels) throws IOException {
+    FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+    Collection<String> seen = new HashSet<>();
+    int i = 0;
+    try (SequenceFile.Writer writer =
+        SequenceFile.createWriter(fs.getConf(), SequenceFile.Writer.file(indexPath),
+            SequenceFile.Writer.keyClass(Text.class), SequenceFile.Writer.valueClass(IntWritable.class))) {
+      for (Pair<Text, IntWritable> label : labels) {
+        String theLabel = extractLabel(label.getFirst().toString());
+        // add() returns true only the first time a label is seen, which is a single hash lookup.
+        if (seen.add(theLabel)) {
+          writer.append(new Text(theLabel), new IntWritable(i++));
+        }
+      }
+    }
+    return i;
+  }
+
+  /**
+   * Returns the second {@code /}-separated token of {@code key}. When the key has a leading
+   * slash, the first token is empty.
+   *
+   * @throws IllegalArgumentException if {@code key} has fewer than two tokens
+   */
+  private static String extractLabel(String key) {
+    String[] parts = SLASH.split(key);
+    Preconditions.checkArgument(parts.length > 1,
+        "Expected a key with at least two '/'-separated tokens (e.g. /label/doc), got: %s", key);
+    return parts[1];
+  }
+
+  /**
+   * Reads a label index, in the {@code Text -> IntWritable} format produced by
+   * {@code writeLabelIndex}, into a map from index to label.
+   */
+  public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) {
+    Map<Integer, String> labelMap = new HashMap<>();
+    for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) {
+      labelMap.put(pair.getSecond().get(), pair.getFirst().toString());
+    }
+    return labelMap;
+  }
+
+  /**
+   * Reads a string-to-int index from the single file in the distributed cache, as located by
+   * {@link HadoopUtil#getSingleCachedFile}. Keys are converted with {@code toString()}.
+   *
+   * @throws IOException if the cached file cannot be located
+   */
+  public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+    OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>();
+    for (Pair<Writable, IntWritable> entry
+        : new SequenceFileIterable<Writable, IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) {
+      index.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return index;
+  }
+
+  /**
+   * Reads named vectors from the part files under the single cached path, as located by
+   * {@link HadoopUtil#getSingleCachedFile}, into a map keyed by the {@link Text} key.
+   *
+   * @throws IOException if the cached path cannot be located
+   */
+  public static Map<String, Vector> readScoresFromCache(Configuration conf) throws IOException {
+    Map<String, Vector> sumVectors = new HashMap<>();
+    for (Pair<Text, VectorWritable> entry
+        : new SequenceFileDirIterable<Text, VectorWritable>(HadoopUtil.getSingleCachedFile(conf),
+        PathType.LIST, PathFilters.partFilter(), conf)) {
+      sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+    }
+    return sumVectors;
+  }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
new file mode 100644
index 0000000..18bd3d6
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/**
+ * Complementary Naive Bayes classifier.
+ *
+ * <p>A label's per-feature score is derived from the feature's weight outside that label (its
+ * complement), and is then divided by the label's theta normalizer. See
+ * http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf for the method.
+ */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  /**
+   * Scores {@code feature} for {@code label}. The score is the complement weight from
+   * {@link #computeWeight} divided by {@code thetaNormalizer(label)}.
+   *
+   * <p>NOTE(review): a zero normalizer for {@code label} would produce an infinite or NaN score.
+   * Confirm that the model guarantees non-zero normalizers.
+   */
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    NaiveBayesModel nb = getModel();
+    double complementWeight = computeWeight(nb.featureWeight(feature), nb.weight(label, feature),
+        nb.totalWeightSum(), nb.labelWeight(label), nb.alphaI(), nb.numFeatures());
+    // Per-label normalization counters weight-magnitude errors; see
+    // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+    double normalizer = nb.thetaNormalizer(label);
+    return complementWeight / normalizer;
+  }
+
+  /**
+   * Computes the complement-class weight of a feature for a label:
+   * {@code -log((featureWeight - featureLabelWeight + alphaI)
+   *             / (totalWeight - labelWeight + alphaI * numFeatures))}.
+   * The {@code alphaI} terms act as additive smoothing on both numerator and denominator. See
+   * http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias.
+   *
+   * @param featureWeight      weight of the feature (presumably summed over all labels)
+   * @param featureLabelWeight weight of the feature within the label being scored
+   * @param totalWeight        total weight (presumably over all labels and features)
+   * @param labelWeight        weight of the label being scored
+   * @param alphaI             smoothing parameter
+   * @param numFeatures        number of features, used to scale the denominator's smoothing
+   * @return the negative log of the smoothed complement ratio
+   */
+  public static double computeWeight(double featureWeight, double featureLabelWeight,
+      double totalWeight, double labelWeight, double alphaI, double numFeatures) {
+    double complementFeatureWeight = featureWeight - featureLabelWeight + alphaI;
+    double complementTotalWeight = totalWeight - labelWeight + alphaI * numFeatures;
+    return -Math.log(complementFeatureWeight / complementTotalWeight);
+  }
+}