Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:08:09 UTC
[38/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
new file mode 100644
index 0000000..efd233f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
@@ -0,0 +1,248 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Defines the interface for classifiers that take a vector as input. This is
+ * implemented as an abstract class so that it can implement a number of handy
+ * convenience methods related to classification of vectors.
+ *
+ * <p>
+ * A classifier takes an input vector and calculates the scores (usually
+ * probabilities) that the input vector belongs to one of {@code n}
+ * categories. In {@code AbstractVectorClassifier} each category is denoted
+ * by an integer {@code c} between {@code 0} and {@code n-1}
+ * (inclusive).
+ *
+ * <p>
+ * New users should start by looking at {@link #classifyFull} (not {@link #classify}).
+ *
+ */
+public abstract class AbstractVectorClassifier {
+
+ /** Minimum allowable log likelihood value. */
+ public static final double MIN_LOG_LIKELIHOOD = -100.0;
+
+ /**
+ * Returns the number of categories that a target variable can be assigned to.
+ * A vector classifier will encode its output as an integer from
+ * {@code 0} to {@code numCategories()-1} (inclusive).
+ *
+ * @return The number of categories.
+ */
+ public abstract int numCategories();
+
+ /**
+ * Compute and return a vector containing {@code n-1} scores, where
+ * {@code n} is equal to {@code numCategories()}, given an input
+ * vector {@code instance}. Higher scores indicate that the input vector
+ * is more likely to belong to that category. The categories are denoted by
+ * the integers {@code 0} through {@code n-1} (inclusive), and the
+ * scores in the returned vector correspond to categories 1 through
+ * {@code n-1} (leaving out category 0). It is assumed that the score for
+ * category 0 is one minus the sum of the scores in the returned vector.
+ *
+ * @param instance A feature vector to be classified.
+ * @return A vector of probabilities in a 1-of-{@code n-1} encoding.
+ */
+ public abstract Vector classify(Vector instance);
+
+ /**
+ * Compute and return a vector of scores before applying the inverse link
+ * function. For logistic regression and other generalized linear models, this
+ * is just the linear part of the classification.
+ *
+ * <p>
+ * The implementation of this method provided by {@code AbstractVectorClassifier} throws an
+ * {@link UnsupportedOperationException}. Your subclass must explicitly override this method to support
+ * this operation.
+ *
+ * @param features A feature vector to be classified.
+ * @return A vector of scores. If transformed by the link function, these will become probabilities.
+ */
+ public Vector classifyNoLink(Vector features) {
+ throw new UnsupportedOperationException(this.getClass().getName()
+ + " doesn't support classification without a link");
+ }
+
+ /**
+ * Classifies a vector in the special case of a binary classifier where
+ * {@link #classify(Vector)} would return a vector with only one element. As
+ * such, using this method can avoid the allocation of a vector.
+ *
+ * @param instance The feature vector to be classified.
+ * @return The score for category 1.
+ *
+ * @see #classify(Vector)
+ */
+ public abstract double classifyScalar(Vector instance);
+
+ /**
+ * Computes and returns a vector containing {@code n} scores, where
+ * {@code n} is {@code numCategories()}, given an input vector
+ * {@code instance}. Higher scores indicate that the input vector is more
+ * likely to belong to the corresponding category. The categories are denoted
+ * by the integers {@code 0} through {@code n-1} (inclusive).
+ *
+ * <p>
+ * Using this method it is possible to classify an input vector, for example,
+ * by selecting the category with the largest score. If
+ * {@code classifier} is an instance of
+ * {@code AbstractVectorClassifier} and {@code input} is a
+ * {@code Vector} of features describing an element to be classified,
+ * then the following code could be used to classify {@code input}.
+ * <pre>{@code
+ * Vector scores = classifier.classifyFull(input);
+ * int assignedCategory = scores.maxValueIndex();
+ * }</pre>
+ * Here {@code assignedCategory} is the index of the category
+ * with the maximum score.
+ *
+ * <p>
+ * If an {@code n-1} encoding is acceptable, and allocation performance
+ * is an issue, then the {@link #classify(Vector)} method is probably better
+ * to use.
+ *
+ * @see #classify(Vector)
+ * @see #classifyFull(Vector r, Vector instance)
+ *
+ * @param instance A vector of features to be classified.
+ * @return A vector of probabilities, one for each category.
+ */
+ public Vector classifyFull(Vector instance) {
+ return classifyFull(new DenseVector(numCategories()), instance);
+ }
+
+ /**
+ * Computes and returns a vector containing {@code n} scores, where
+ * {@code n} is {@code numCategories()}, given an input vector
+ * {@code instance}. Higher scores indicate that the input vector is more
+ * likely to belong to the corresponding category. The categories are denoted
+ * by the integers {@code 0} through {@code n-1} (inclusive). The
+ * main difference between this method and {@link #classifyFull(Vector)} is
+ * that this method allows a user to provide a previously allocated
+ * {@code Vector r} to store the returned scores.
+ *
+ * <p>
+ * Using this method it is possible to classify an input vector, for example,
+ * by selecting the category with the largest score. If
+ * {@code classifier} is an instance of
+ * {@code AbstractVectorClassifier}, {@code result} is a non-null
+ * {@code Vector}, and {@code input} is a {@code Vector} of
+ * features describing an element to be classified, then the following code
+ * could be used to classify {@code input}.
+ * <pre>{@code
+ * Vector scores = classifier.classifyFull(result, input); // Notice that scores == result
+ * int assignedCategory = scores.maxValueIndex();
+ * }</pre>
+ * Here {@code assignedCategory} is the index of the category
+ * with the maximum score.
+ *
+ * @param r Where to put the results.
+ * @param instance A vector of features to be classified.
+ * @return A vector of scores/probabilities, one for each category.
+ */
+ public Vector classifyFull(Vector r, Vector instance) {
+ r.viewPart(1, numCategories() - 1).assign(classify(instance));
+ r.setQuick(0, 1.0 - r.zSum());
+ return r;
+ }
+
+
+ /**
+ * Returns n-1 probabilities, one for each of categories 1 through
+ * {@code n-1}, for each row of a matrix, where {@code n} is equal
+ * to {@code numCategories()}. The probability of the missing 0-th
+ * category is 1 - rowSum(this result).
+ *
+ * @param data The matrix whose rows are the input vectors to classify
+ * @return A matrix of scores, one row per row of the input matrix, one column for each category except category 0.
+ */
+ public Matrix classify(Matrix data) {
+ Matrix r = new DenseMatrix(data.numRows(), numCategories() - 1);
+ for (int row = 0; row < data.numRows(); row++) {
+ r.assignRow(row, classify(data.viewRow(row)));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a matrix where the rows of the matrix each contain {@code n} probabilities, one for each category.
+ *
+ * @param data The matrix whose rows are the input vectors to classify
+ * @return A matrix of scores, one row per row of the input matrix, one column per category.
+ */
+ public Matrix classifyFull(Matrix data) {
+ Matrix r = new DenseMatrix(data.numRows(), numCategories());
+ for (int row = 0; row < data.numRows(); row++) {
+ classifyFull(r.viewRow(row), data.viewRow(row));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a vector of probabilities of category 1, one for each row
+ * of a matrix. This only makes sense if there are exactly two categories, but
+ * calling this method in that case can save a number of vector allocations.
+ *
+ * @param data The matrix whose rows are vectors to classify
+ * @return A vector of scores, with one value per row of the input matrix.
+ */
+ public Vector classifyScalar(Matrix data) {
+ Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
+
+ Vector r = new DenseVector(data.numRows());
+ for (int row = 0; row < data.numRows(); row++) {
+ r.set(row, classifyScalar(data.viewRow(row)));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a measure of how good the classification for a particular example
+ * actually is.
+ *
+ * @param actual The correct category for the example.
+ * @param data The vector to be classified.
+ * @return The log likelihood of the correct answer as estimated by the current model. This will always be <= 0,
+ * and values closer to 0 indicate better accuracy. In order to simplify code that maintains running averages,
+ * we bound this value at -100.
+ */
+ public double logLikelihood(int actual, Vector data) {
+ if (numCategories() == 2) {
+ double p = classifyScalar(data);
+ if (actual > 0) {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p));
+ } else {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p));
+ }
+ } else {
+ Vector p = classify(data);
+ if (actual > 0) {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p.get(actual - 1)));
+ } else {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p.zSum()));
+ }
+ }
+ }
+}
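
A minimal, runnable sketch of the classifyFull idiom described in the javadoc above. The anonymous three-category classifier with fixed scores is a hypothetical stand-in; any concrete subclass would do:

    import org.apache.mahout.classifier.AbstractVectorClassifier;
    import org.apache.mahout.math.DenseVector;
    import org.apache.mahout.math.Vector;

    public class ClassifyFullSketch {
      public static void main(String[] args) {
        // Toy stand-in: fixed scores for categories 1..n-1; category 0 gets the remainder.
        AbstractVectorClassifier classifier = new AbstractVectorClassifier() {
          @Override public int numCategories() { return 3; }
          @Override public Vector classify(Vector instance) {
            return new DenseVector(new double[] {0.2, 0.5});
          }
          @Override public double classifyScalar(Vector instance) {
            return classify(instance).get(0);
          }
        };
        Vector input = new DenseVector(new double[] {1.0, 2.0});
        Vector scores = classifier.classifyFull(input); // {0.3, 0.2, 0.5}
        int assignedCategory = scores.maxValueIndex();  // 2
        System.out.println(assignedCategory + " -> " + scores.get(assignedCategory));
      }
    }
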
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java b/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
new file mode 100644
index 0000000..29eaa0d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+/**
+ * Result of a document classification: the label and the associated score (usually a probability).
+ */
+public class ClassifierResult {
+
+ private String label;
+ private double score;
+ private double logLikelihood = Double.MAX_VALUE;
+
+ public ClassifierResult() { }
+
+ public ClassifierResult(String label, double score) {
+ this.label = label;
+ this.score = score;
+ }
+
+ public ClassifierResult(String label) {
+ this.label = label;
+ }
+
+ public ClassifierResult(String label, double score, double logLikelihood) {
+ this.label = label;
+ this.score = score;
+ this.logLikelihood = logLikelihood;
+ }
+
+ public double getLogLikelihood() {
+ return logLikelihood;
+ }
+
+ public void setLogLikelihood(double logLikelihood) {
+ this.logLikelihood = logLikelihood;
+ }
+
+ public String getLabel() {
+ return label;
+ }
+
+ public double getScore() {
+ return score;
+ }
+
+ public void setLabel(String label) {
+ this.label = label;
+ }
+
+ public void setScore(double score) {
+ this.score = score;
+ }
+
+ @Override
+ public String toString() {
+ return "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}';
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java b/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
new file mode 100644
index 0000000..0baa4bf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
@@ -0,0 +1,444 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.math3.stat.descriptive.moment.Mean;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The ConfusionMatrix class stores the result of classifying a test dataset.
+ *
+ * Whether there is a default label is not stored; a row of zeros is the only indicator that there is no default.
+ *
+ * See http://en.wikipedia.org/wiki/Confusion_matrix for background
+ */
+public class ConfusionMatrix {
+ private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class);
+ private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
+ private final int[][] confusionMatrix;
+ private int samples = 0;
+ private String defaultLabel = "unknown";
+
+ public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
+ confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
+ this.defaultLabel = defaultLabel;
+ int i = 0;
+ for (String label : labels) {
+ labelMap.put(label, i++);
+ }
+ labelMap.put(defaultLabel, i);
+ }
+
+ public ConfusionMatrix(Matrix m) {
+ confusionMatrix = new int[m.numRows()][m.numRows()];
+ setMatrix(m);
+ }
+
+ public int[][] getConfusionMatrix() {
+ return confusionMatrix;
+ }
+
+ public Collection<String> getLabels() {
+ return Collections.unmodifiableCollection(labelMap.keySet());
+ }
+
+ private int numLabels() {
+ return labelMap.size();
+ }
+
+ public double getAccuracy(String label) {
+ int labelId = labelMap.get(label);
+ int labelTotal = 0;
+ int correct = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ labelTotal += confusionMatrix[labelId][i];
+ if (i == labelId) {
+ correct += confusionMatrix[labelId][i];
+ }
+ }
+ return 100.0 * correct / labelTotal;
+ }
+
+ // Producer accuracy
+ public double getAccuracy() {
+ int total = 0;
+ int correct = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ for (int j = 0; j < numLabels(); j++) {
+ total += confusionMatrix[i][j];
+ if (i == j) {
+ correct += confusionMatrix[i][j];
+ }
+ }
+ }
+ return 100.0 * correct / total;
+ }
+
+ /** Sum of true positives and false negatives */
+ private int getActualNumberOfTestExamplesForClass(String label) {
+ int labelId = labelMap.get(label);
+ int sum = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ sum += confusionMatrix[labelId][i];
+ }
+ return sum;
+ }
+
+ public double getPrecision(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falsePositives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falsePositives += confusionMatrix[i][labelId];
+ }
+
+ if (truePositives + falsePositives == 0) {
+ return 0;
+ }
+
+ return ((double) truePositives) / (truePositives + falsePositives);
+ }
+
+ public double getWeightedPrecision() {
+ double[] precisions = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ precisions[index] = getPrecision(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(precisions, weights);
+ }
+
+ public double getRecall(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falseNegatives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falseNegatives += confusionMatrix[labelId][i];
+ }
+ if (truePositives + falseNegatives == 0) {
+ return 0;
+ }
+ return ((double) truePositives) / (truePositives + falseNegatives);
+ }
+
+ public double getWeightedRecall() {
+ double[] recalls = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ recalls[index] = getRecall(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(recalls, weights);
+ }
+
+ public double getF1score(String label) {
+ double precision = getPrecision(label);
+ double recall = getRecall(label);
+ if (precision + recall == 0) {
+ return 0;
+ }
+ return 2 * precision * recall / (precision + recall);
+ }
+
+ public double getWeightedF1score() {
+ double[] f1Scores = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ f1Scores[index] = getF1score(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(f1Scores, weights);
+ }
+
+ // User accuracy
+ public double getReliability() {
+ int count = 0;
+ double accuracy = 0;
+ for (String label: labelMap.keySet()) {
+ if (!label.equals(defaultLabel)) {
+ accuracy += getAccuracy(label);
+ }
+ count++;
+ }
+ return accuracy / count;
+ }
+
+ /**
+ * Accuracy vs. randomly classifying all samples.
+ * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
+ * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
+ * Educational And Psychological Measurement 20:37-46.
+ *
+ * Formula and variable names from:
+ * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
+ *
+ * @return double
+ */
+ public double getKappa() {
+ double a = 0.0;
+ double b = 0.0;
+ for (int i = 0; i < confusionMatrix.length; i++) {
+ a += confusionMatrix[i][i];
+ double br = 0;
+ for (int j = 0; j < confusionMatrix.length; j++) {
+ br += confusionMatrix[i][j];
+ }
+ double bc = 0;
+ for (int[] vec : confusionMatrix) {
+ bc += vec[i];
+ }
+ b += br * bc;
+ }
+ return (samples * a - b) / (samples * samples - b);
+ }
+
+ /**
+ * Average and standard deviation of the normalized producer accuracy.
+ * Not a standard score.
+ * @return a RunningAverageAndStdDev over the per-class producer accuracies
+ */
+ public RunningAverageAndStdDev getNormalizedStats() {
+ RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
+ for (int d = 0; d < confusionMatrix.length; d++) {
+ double total = 0;
+ for (int j = 0; j < confusionMatrix.length; j++) {
+ total += confusionMatrix[d][j];
+ }
+ summer.addDatum(confusionMatrix[d][d] / (total + 0.000001));
+ }
+
+ return summer;
+ }
+
+ public int getCorrect(String label) {
+ int labelId = labelMap.get(label);
+ return confusionMatrix[labelId][labelId];
+ }
+
+ public int getTotal(String label) {
+ int labelId = labelMap.get(label);
+ int labelTotal = 0;
+ for (int i = 0; i < labelMap.size(); i++) {
+ labelTotal += confusionMatrix[labelId][i];
+ }
+ return labelTotal;
+ }
+
+ public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
+ samples++;
+ incrementCount(correctLabel, classifiedResult.getLabel());
+ }
+
+ public void addInstance(String correctLabel, String classifiedLabel) {
+ samples++;
+ incrementCount(correctLabel, classifiedLabel);
+ }
+
+ public int getCount(String correctLabel, String classifiedLabel) {
+ if (!labelMap.containsKey(correctLabel)) {
+ LOG.warn("Label {} did not appear in the training examples", correctLabel);
+ return 0;
+ }
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+ int correctId = labelMap.get(correctLabel);
+ int classifiedId = labelMap.get(classifiedLabel);
+ return confusionMatrix[correctId][classifiedId];
+ }
+
+ public void putCount(String correctLabel, String classifiedLabel, int count) {
+ if (!labelMap.containsKey(correctLabel)) {
+ LOG.warn("Label {} did not appear in the training examples", correctLabel);
+ return;
+ }
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+ int correctId = labelMap.get(correctLabel);
+ int classifiedId = labelMap.get(classifiedLabel);
+ if (confusionMatrix[correctId][classifiedId] == 0 && count != 0) {
+ samples++;
+ }
+ confusionMatrix[correctId][classifiedId] = count;
+ }
+
+ public String getDefaultLabel() {
+ return defaultLabel;
+ }
+
+ public void incrementCount(String correctLabel, String classifiedLabel, int count) {
+ putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
+ }
+
+ public void incrementCount(String correctLabel, String classifiedLabel) {
+ incrementCount(correctLabel, classifiedLabel, 1);
+ }
+
+ public ConfusionMatrix merge(ConfusionMatrix b) {
+ Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
+ for (String correctLabel : this.labelMap.keySet()) {
+ for (String classifiedLabel : this.labelMap.keySet()) {
+ incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
+ }
+ }
+ return this;
+ }
+
+ public Matrix getMatrix() {
+ int length = confusionMatrix.length;
+ Matrix m = new DenseMatrix(length, length);
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ m.set(r, c, confusionMatrix[r][c]);
+ }
+ }
+ Map<String,Integer> labels = Maps.newHashMap();
+ for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
+ labels.put(entry.getKey(), entry.getValue());
+ }
+ m.setRowLabelBindings(labels);
+ m.setColumnLabelBindings(labels);
+ return m;
+ }
+
+ public void setMatrix(Matrix m) {
+ int length = confusionMatrix.length;
+ if (m.numRows() != m.numCols()) {
+ throw new IllegalArgumentException(
+ "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
+ }
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+ }
+ }
+ Map<String,Integer> labels = m.getRowLabelBindings();
+ if (labels == null) {
+ labels = m.getColumnLabelBindings();
+ }
+ if (labels != null) {
+ String[] sorted = sortLabels(labels);
+ verifyLabels(length, sorted);
+ labelMap.clear();
+ for (int i = 0; i < length; i++) {
+ labelMap.put(sorted[i], i);
+ }
+ }
+ }
+
+ private static String[] sortLabels(Map<String,Integer> labels) {
+ String[] sorted = new String[labels.size()];
+ for (Map.Entry<String,Integer> entry : labels.entrySet()) {
+ sorted[entry.getValue()] = entry.getKey();
+ }
+ return sorted;
+ }
+
+ private static void verifyLabels(int length, String[] sorted) {
+ Preconditions.checkArgument(sorted.length == length, "One label, one row");
+ for (int i = 0; i < length; i++) {
+ if (sorted[i] == null) {
+ Preconditions.checkArgument(false, "One label, one row");
+ }
+ }
+ }
+
+ /**
+ * Note that this override of toString() is a diagnostic dump, not a formatted report you print for a manager :)
+ * Assume that if there are no default assignments, the default label was not used
+ */
+ @Override
+ public String toString() {
+ StringBuilder returnString = new StringBuilder(200);
+ returnString.append("=======================================================").append('\n');
+ returnString.append("Confusion Matrix\n");
+ returnString.append("-------------------------------------------------------").append('\n');
+
+ int unclassified = getTotal(defaultLabel);
+ for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+
+ returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
+ }
+
+ returnString.append("<--Classified as").append('\n');
+ for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+ String correctLabel = entry.getKey();
+ int labelTotal = 0;
+ for (String classifiedLabel : this.labelMap.keySet()) {
+ if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+ returnString.append(
+ StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
+ labelTotal += getCount(correctLabel, classifiedLabel);
+ }
+ returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
+ .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
+ .append(" = ").append(correctLabel).append('\n');
+ }
+ if (unclassified > 0) {
+ returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
+ }
+ returnString.append('\n');
+ return returnString.toString();
+ }
+
+ static String getSmallLabel(int i) {
+ int val = i;
+ StringBuilder returnString = new StringBuilder();
+ do {
+ int n = val % 26;
+ returnString.insert(0, (char) ('a' + n));
+ val /= 26;
+ } while (val > 0);
+ return returnString.toString();
+ }
+
+}
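
A short sketch of feeding the matrix and reading back the per-label metrics defined above; the labels and counts are invented:

    import java.util.Arrays;
    import org.apache.mahout.classifier.ConfusionMatrix;

    public class ConfusionMatrixSketch {
      public static void main(String[] args) {
        ConfusionMatrix cm = new ConfusionMatrix(Arrays.asList("spam", "ham"), "unknown");
        cm.addInstance("spam", "spam"); // correct
        cm.addInstance("spam", "ham");  // misclassified
        cm.addInstance("ham", "ham");   // correct
        System.out.println(cm.getAccuracy("spam")); // 50.0, per-label accuracy in percent
        System.out.println(cm.getPrecision("ham")); // 0.5, one true positive and one false positive
        System.out.println(cm);                     // the matrix dump from toString()
      }
    }
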
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java b/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
new file mode 100644
index 0000000..af1d5e7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import org.apache.mahout.math.Vector;
+
+import java.io.Closeable;
+
+/**
+ * The simplest interface for online learning algorithms.
+ */
+public interface OnlineLearner extends Closeable {
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may be an assumption that if multiple passes through the training data are necessary, then
+ * the training examples will be presented in the same order. This is because the order of
+ * training examples may be used to assign records to different data splits for evaluation by
+ * cross-validation. Without this order invariance, records might land in different training and test
+ * splits on different passes, and error estimates could be seriously affected.
+ * <p/>
+ * If re-ordering is necessary, use the alternative API that allows a tracking key to be
+ * attached to each training example.
+ *
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(int actual, Vector instance);
+
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may be an assumption that if multiple passes through the training data are necessary, the
+ * tracking key for a record will be the same for each pass, that there will be a
+ * relatively large number of distinct tracking keys, and that the low-order bits of the tracking
+ * keys will not correlate with any of the input variables. This tracking key is used to assign
+ * training examples to different test/training splits.
+ * <p/>
+ * Examples of useful tracking keys include id-numbers for the training records derived from
+ * the database id of the base table from which the record is derived, or the offset of
+ * the original data record in a data file.
+ *
+ * @param trackingKey The tracking key for this training example.
+ * @param groupKey An optional value that allows examples to be grouped in the computation of
+ * the update to the model.
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(long trackingKey, String groupKey, int actual, Vector instance);
+
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may be an assumption that if multiple passes through the training data are necessary, the
+ * tracking key for a record will be the same for each pass, that there will be a
+ * relatively large number of distinct tracking keys, and that the low-order bits of the tracking
+ * keys will not correlate with any of the input variables. This tracking key is used to assign
+ * training examples to different test/training splits.
+ * <p/>
+ * Examples of useful tracking keys include id-numbers for the training records derived from
+ * the database id of the base table from which the record is derived, or the offset of
+ * the original data record in a data file.
+ *
+ * @param trackingKey The tracking key for this training example.
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(long trackingKey, int actual, Vector instance);
+
+ /**
+ * Prepares the classifier for classification and deallocates any temporary data structures.
+ *
+ * An online classifier should be able to accept more training after being closed, but
+ * closing the classifier may make classification more efficient.
+ */
+ @Override
+ void close();
+}
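
A hedged sketch of a multi-pass training loop against this interface. The Example holder and the learner argument are assumptions; any OnlineLearner implementation would fit:

    import org.apache.mahout.classifier.OnlineLearner;
    import org.apache.mahout.math.Vector;

    final class Example {
      final long rowId;      // stable id, identical across passes
      final int target;      // category in [0..n)
      final Vector features;
      Example(long rowId, int target, Vector features) {
        this.rowId = rowId; this.target = target; this.features = features;
      }
    }

    final class TrainingLoop {
      // The tracking-key overload keeps test/training splits stable even if
      // the iteration order changes between passes.
      static void run(OnlineLearner learner, Iterable<Example> examples, int passes) {
        for (int pass = 0; pass < passes; pass++) {
          for (Example e : examples) {
            learner.train(e.rowId, e.target, e.features);
          }
        }
        learner.close(); // may make subsequent classification more efficient
      }
    }
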
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java b/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
new file mode 100644
index 0000000..5d8b9ed
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.text.DecimalFormat;
+import java.text.NumberFormat;
+import java.util.List;
+import java.util.Locale;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.lang3.StringUtils;
+
+/**
+ * RegressionResultAnalyzer captures the regression statistics and displays them in a tabular manner.
+ */
+public class RegressionResultAnalyzer {
+
+ private static class Result {
+ private final double actual;
+ private final double result;
+ Result(double actual, double result) {
+ this.actual = actual;
+ this.result = result;
+ }
+ double getActual() {
+ return actual;
+ }
+ double getResult() {
+ return result;
+ }
+ }
+
+ private List<Result> results;
+
+ /**
+ *
+ * @param actual
+ * The actual answer
+ * @param result
+ * The regression result
+ */
+ public void addInstance(double actual, double result) {
+ if (results == null) {
+ results = Lists.newArrayList();
+ }
+ results.add(new Result(actual, result));
+ }
+
+ /**
+ *
+ * @param results
+ * The results table
+ */
+ public void setInstances(double[][] results) {
+ for (double[] res : results) {
+ addInstance(res[0], res[1]);
+ }
+ }
+
+ @Override
+ public String toString() {
+ double sumActual = 0.0;
+ double sumActualSquared = 0.0;
+ double sumResult = 0.0;
+ double sumResultSquared = 0.0;
+ double sumActualResult = 0.0;
+ double sumAbsolute = 0.0;
+ double sumAbsoluteSquared = 0.0;
+ int predictable = 0;
+ int unpredictable = 0;
+
+ for (Result res : results) {
+ double actual = res.getActual();
+ double result = res.getResult();
+ if (Double.isNaN(result)) {
+ unpredictable++;
+ } else {
+ sumActual += actual;
+ sumActualSquared += actual * actual;
+ sumResult += result;
+ sumResultSquared += result * result;
+ sumActualResult += actual * result;
+ double absolute = Math.abs(actual - result);
+ sumAbsolute += absolute;
+ sumAbsoluteSquared += absolute * absolute;
+ predictable++;
+ }
+ }
+
+ StringBuilder returnString = new StringBuilder();
+
+ returnString.append("=======================================================\n");
+ returnString.append("Summary\n");
+ returnString.append("-------------------------------------------------------\n");
+
+ if (predictable > 0) {
+ double varActual = sumActualSquared - sumActual * sumActual / predictable;
+ double varResult = sumResultSquared - sumResult * sumResult / predictable;
+ double varCo = sumActualResult - sumActual * sumResult / predictable;
+
+ double correlation;
+ if (varActual * varResult <= 0) {
+ correlation = 0.0;
+ } else {
+ correlation = varCo / Math.sqrt(varActual * varResult);
+ }
+
+ Locale.setDefault(Locale.US);
+ NumberFormat decimalFormatter = new DecimalFormat("0.####");
+
+ returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(sumAbsolute / predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / predictable)),
+ 10)).append('\n');
+ }
+ returnString.append(StringUtils.rightPad("Predictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Unpredictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(unpredictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Total Regressed Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(results.size()), 10)).append('\n');
+ returnString.append('\n');
+
+ return returnString.toString();
+ }
+}
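
The intended call pattern, as a small sketch with invented values:

    import org.apache.mahout.classifier.RegressionResultAnalyzer;

    public class RegressionSketch {
      public static void main(String[] args) {
        RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
        analyzer.setInstances(new double[][] {
            {1.0, 1.1},        // {actual, predicted}
            {2.0, 1.8},
            {3.0, Double.NaN}, // NaN predictions are counted as unpredictable
        });
        System.out.println(analyzer); // correlation, MAE, RMSE, instance counts
      }
    }
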
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java b/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
new file mode 100644
index 0000000..1711f19
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.text.DecimalFormat;
+import java.text.NumberFormat;
+import java.util.Collection;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/** ResultAnalyzer captures the classification statistics and displays them in a tabular manner */
+public class ResultAnalyzer {
+
+ private final ConfusionMatrix confusionMatrix;
+ private final OnlineSummarizer summarizer;
+ private boolean hasLL;
+
+ /*
+ * === Summary ===
+ *
+ * Correctly Classified Instances        635     92.9722 %
+ * Incorrectly Classified Instances       48      7.0278 %
+ * Kappa statistic                         0.923
+ * Mean absolute error                     0.0096
+ * Root mean squared error                 0.0817
+ * Relative absolute error                 9.9344 %
+ * Root relative squared error            37.2742 %
+ * Total Number of Instances             683
+ */
+ private int correctlyClassified;
+ private int incorrectlyClassified;
+
+ public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
+ confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
+ summarizer = new OnlineSummarizer();
+ }
+
+ public ConfusionMatrix getConfusionMatrix() {
+ return this.confusionMatrix;
+ }
+
+ /**
+ *
+ * @param correctLabel
+ * The correct label
+ * @param classifiedResult
+ * The classified result
+ * @return whether the instance was correct or not
+ */
+ public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
+ boolean result = correctLabel.equals(classifiedResult.getLabel());
+ if (result) {
+ correctlyClassified++;
+ } else {
+ incorrectlyClassified++;
+ }
+ confusionMatrix.addInstance(correctLabel, classifiedResult);
+ if (classifiedResult.getLogLikelihood() != Double.MAX_VALUE) {
+ summarizer.add(classifiedResult.getLogLikelihood());
+ hasLL = true;
+ }
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder returnString = new StringBuilder();
+
+ returnString.append('\n');
+ returnString.append("=======================================================\n");
+ returnString.append("Summary\n");
+ returnString.append("-------------------------------------------------------\n");
+ int totalClassified = correctlyClassified + incorrectlyClassified;
+ double percentageCorrect = (double) 100 * correctlyClassified / totalClassified;
+ double percentageIncorrect = (double) 100 * incorrectlyClassified / totalClassified;
+ NumberFormat decimalFormatter = new DecimalFormat("0.####");
+
+ returnString.append(StringUtils.rightPad("Correctly Classified Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(correctlyClassified), 10)).append('\t').append(
+ StringUtils.leftPad(decimalFormatter.format(percentageCorrect), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Incorrectly Classified Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(incorrectlyClassified), 10)).append('\t').append(
+ StringUtils.leftPad(decimalFormatter.format(percentageIncorrect), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Total Classified Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(totalClassified), 10)).append('\n');
+ returnString.append('\n');
+
+ returnString.append(confusionMatrix);
+ returnString.append("=======================================================\n");
+ returnString.append("Statistics\n");
+ returnString.append("-------------------------------------------------------\n");
+
+ RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats();
+ returnString.append(StringUtils.rightPad("Kappa", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Accuracy", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Reliability", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Reliability (standard deviation)", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted precision", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedPrecision()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted recall", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedRecall()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted F1 score", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedF1score()), 10)).append('\n');
+
+ if (hasLL) {
+ returnString.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean : ").append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getMean()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("25%-ile : ", 10)).append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(1)), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("75%-ile : ", 10)).append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(3)), 10)).append('\n');
+ }
+
+ return returnString.toString();
+ }
+}
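
The classification counterpart, again with invented labels and scores; passing a log-likelihood as the third ClassifierResult argument enables the log-likelihood statistics:

    import java.util.Arrays;
    import org.apache.mahout.classifier.ClassifierResult;
    import org.apache.mahout.classifier.ResultAnalyzer;

    public class ResultAnalyzerSketch {
      public static void main(String[] args) {
        ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList("spam", "ham"), "unknown");
        analyzer.addInstance("spam", new ClassifierResult("spam", 0.9, -0.105)); // correct
        analyzer.addInstance("ham",  new ClassifierResult("spam", 0.6, -0.916)); // wrong
        System.out.println(analyzer); // summary, confusion matrix, kappa, precision/recall
      }
    }
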
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java b/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
new file mode 100644
index 0000000..0ec5b55
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.node.Node;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Builds a tree using bagging
+ */
+public class Bagging {
+
+ private static final Logger log = LoggerFactory.getLogger(Bagging.class);
+
+ private final TreeBuilder treeBuilder;
+
+ private final Data data;
+
+ private final boolean[] sampled;
+
+ public Bagging(TreeBuilder treeBuilder, Data data) {
+ this.treeBuilder = treeBuilder;
+ this.data = data;
+ sampled = new boolean[data.size()];
+ }
+
+ /**
+ * Builds one tree
+ */
+ public Node build(Random rng) {
+ log.debug("Bagging...");
+ Arrays.fill(sampled, false);
+ Data bag = data.bagging(rng, sampled);
+
+ log.debug("Building...");
+ return treeBuilder.build(rng, bag);
+ }
+
+}
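
A sketch of how Bagging might feed a DecisionForest; the TreeBuilder and Data arguments come from elsewhere in the df package and are assumed in scope, and the tree count is arbitrary:

    import java.util.List;
    import java.util.Random;
    import com.google.common.collect.Lists;
    import org.apache.mahout.classifier.df.Bagging;
    import org.apache.mahout.classifier.df.DecisionForest;
    import org.apache.mahout.classifier.df.builder.TreeBuilder;
    import org.apache.mahout.classifier.df.data.Data;
    import org.apache.mahout.classifier.df.node.Node;

    final class ForestSketch {
      static DecisionForest grow(TreeBuilder treeBuilder, Data data, int numTrees) {
        Random rng = new Random();
        Bagging bagging = new Bagging(treeBuilder, data);
        List<Node> trees = Lists.newArrayList();
        for (int i = 0; i < numTrees; i++) {
          trees.add(bagging.build(rng)); // each tree trains on its own bootstrap sample
        }
        return new DecisionForest(trees);
      }
    }
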
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java b/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
new file mode 100644
index 0000000..137b174
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
@@ -0,0 +1,181 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+
+/**
+ * Utility class that contains various helper methods
+ */
+public final class DFUtils {
+
+ private DFUtils() {}
+
+ /**
+ * Writes a Node[] into a DataOutput
+ * @throws java.io.IOException
+ */
+ public static void writeArray(DataOutput out, Node[] array) throws IOException {
+ out.writeInt(array.length);
+ for (Node w : array) {
+ w.write(out);
+ }
+ }
+
+ /**
+ * Reads a Node[] from a DataInput
+ * @throws java.io.IOException
+ */
+ public static Node[] readNodeArray(DataInput in) throws IOException {
+ int length = in.readInt();
+ Node[] nodes = new Node[length];
+ for (int index = 0; index < length; index++) {
+ nodes[index] = Node.read(in);
+ }
+
+ return nodes;
+ }
+
+ /**
+ * Writes a double[] into a DataOutput
+ * @throws java.io.IOException
+ */
+ public static void writeArray(DataOutput out, double[] array) throws IOException {
+ out.writeInt(array.length);
+ for (double value : array) {
+ out.writeDouble(value);
+ }
+ }
+
+ /**
+ * Reads a double[] from a DataInput
+ * @throws java.io.IOException
+ */
+ public static double[] readDoubleArray(DataInput in) throws IOException {
+ int length = in.readInt();
+ double[] array = new double[length];
+ for (int index = 0; index < length; index++) {
+ array[index] = in.readDouble();
+ }
+
+ return array;
+ }
+
+ /**
+ * Writes an int[] into a DataOutput
+ * @throws java.io.IOException
+ */
+ public static void writeArray(DataOutput out, int[] array) throws IOException {
+ out.writeInt(array.length);
+ for (int value : array) {
+ out.writeInt(value);
+ }
+ }
+
+ /**
+ * Reads an int[] from a DataInput
+ * @throws java.io.IOException
+ */
+ public static int[] readIntArray(DataInput in) throws IOException {
+ int length = in.readInt();
+ int[] array = new int[length];
+ for (int index = 0; index < length; index++) {
+ array[index] = in.readInt();
+ }
+
+ return array;
+ }
+
+ /**
+ * Returns all output files in the given directory, skipping hidden and underscore-prefixed files
+ * @throws IOException if no file is found
+ */
+ public static Path[] listOutputFiles(FileSystem fs, Path outputPath) throws IOException {
+ List<Path> outputFiles = Lists.newArrayList();
+ for (FileStatus s : fs.listStatus(outputPath, PathFilters.logsCRCFilter())) {
+ if (!s.isDir() && !s.getPath().getName().startsWith("_")) {
+ outputFiles.add(s.getPath());
+ }
+ }
+ if (outputFiles.isEmpty()) {
+ throw new IOException("No output found !");
+ }
+ return outputFiles.toArray(new Path[outputFiles.size()]);
+ }
+
+ /**
+ * Formats a time interval in milliseconds to a String of the form "<hours>h <minutes>m <seconds>s <millis>"
+ */
+ public static String elapsedTime(long milli) {
+ long seconds = milli / 1000;
+ milli %= 1000;
+
+ long minutes = seconds / 60;
+ seconds %= 60;
+
+ long hours = minutes / 60;
+ minutes %= 60;
+
+ return hours + "h " + minutes + "m " + seconds + "s " + milli;
+ }
+
+ public static void storeWritable(Configuration conf, Path path, Writable writable) throws IOException {
+ FileSystem fs = path.getFileSystem(conf);
+
+ FSDataOutputStream out = fs.create(path);
+ try {
+ writable.write(out);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ /**
+ * Write a string to a path.
+ * @param conf From which the file system will be picked
+ * @param path Where the string will be written
+ * @param string The string to write
+ * @throws IOException if things go poorly
+ */
+ public static void storeString(Configuration conf, Path path, String string) throws IOException {
+ DataOutputStream out = null;
+ try {
+ out = path.getFileSystem(conf).create(path);
+ out.write(string.getBytes(Charset.defaultCharset()));
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+}
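
An in-memory round trip through the array helpers above, plus a sample elapsedTime call:

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.util.Arrays;
    import org.apache.mahout.classifier.df.DFUtils;

    public class DFUtilsSketch {
      public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DFUtils.writeArray(new DataOutputStream(bytes), new int[] {3, 1, 4});
        // The length prefix written by writeArray makes the read side symmetrical.
        int[] back = DFUtils.readIntArray(
            new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        System.out.println(Arrays.toString(back));         // [3, 1, 4]
        System.out.println(DFUtils.elapsedTime(3723456L)); // 1h 2m 3s 456
      }
    }
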
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java b/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
new file mode 100644
index 0000000..1b47ec7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
@@ -0,0 +1,244 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.Node;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Represents a forest of decision trees.
+ */
+public class DecisionForest implements Writable {
+
+ private final List<Node> trees;
+
+ private DecisionForest() {
+ trees = Lists.newArrayList();
+ }
+
+ public DecisionForest(List<Node> trees) {
+ Preconditions.checkArgument(trees != null && !trees.isEmpty(), "trees argument must not be null or empty");
+
+ this.trees = trees;
+ }
+
+ List<Node> getTrees() {
+ return trees;
+ }
+
+ /**
+ * Classifies the data, filling predictions with one row per instance and one column per tree
+ */
+ public void classify(Data data, double[][] predictions) {
+ Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()");
+
+ if (data.isEmpty()) {
+ return; // nothing to classify
+ }
+
+ int treeId = 0;
+ for (Node tree : trees) {
+ for (int index = 0; index < data.size(); index++) {
+ if (predictions[index] == null) {
+ predictions[index] = new double[trees.size()];
+ }
+ predictions[index][treeId] = tree.classify(data.get(index));
+ }
+ treeId++;
+ }
+ }
+
+ /**
+ * Predicts the label for the instance.
+ *
+ * @param rng
+ * Random number generator, used to break ties randomly
+ * @return NaN if the label cannot be predicted
+ */
+ public double classify(Dataset dataset, Random rng, Instance instance) {
+ if (dataset.isNumerical(dataset.getLabelId())) {
+ double sum = 0;
+ int cnt = 0;
+ for (Node tree : trees) {
+ double prediction = tree.classify(instance);
+ if (!Double.isNaN(prediction)) {
+ sum += prediction;
+ cnt++;
+ }
+ }
+
+ if (cnt > 0) {
+ return sum / cnt;
+ } else {
+ return Double.NaN;
+ }
+ } else {
+ int[] predictions = new int[dataset.nblabels()];
+ for (Node tree : trees) {
+ double prediction = tree.classify(instance);
+ if (!Double.isNaN(prediction)) {
+ predictions[(int) prediction]++;
+ }
+ }
+
+ if (DataUtils.sum(predictions) == 0) {
+ return Double.NaN; // no prediction available
+ }
+
+ return DataUtils.maxindex(rng, predictions);
+ }
+ }
+
+ /**
+ * @return Mean number of nodes per tree
+ */
+ public long meanNbNodes() {
+ long sum = 0;
+
+ for (Node tree : trees) {
+ sum += tree.nbNodes();
+ }
+
+ return sum / trees.size();
+ }
+
+ /**
+ * @return Total number of nodes in all the trees
+ */
+ public long nbNodes() {
+ long sum = 0;
+
+ for (Node tree : trees) {
+ sum += tree.nbNodes();
+ }
+
+ return sum;
+ }
+
+ /**
+ * @return Mean maximum depth per tree
+ */
+ public long meanMaxDepth() {
+ long sum = 0;
+
+ for (Node tree : trees) {
+ sum += tree.maxDepth();
+ }
+
+ return sum / trees.size();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof DecisionForest)) {
+ return false;
+ }
+
+ DecisionForest rf = (DecisionForest) obj;
+
+ return trees.size() == rf.getTrees().size() && trees.containsAll(rf.getTrees());
+ }
+
+ @Override
+ public int hashCode() {
+ return trees.hashCode();
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) throws IOException {
+ dataOutput.writeInt(trees.size());
+ for (Node tree : trees) {
+ tree.write(dataOutput);
+ }
+ }
+
+ /**
+ * Reads the trees from the input and adds them to the existing trees
+ */
+ @Override
+ public void readFields(DataInput dataInput) throws IOException {
+ int size = dataInput.readInt();
+ for (int i = 0; i < size; i++) {
+ trees.add(Node.read(dataInput));
+ }
+ }
+
+ /**
+ * Reads a forest from the given input.
+ * @param dataInput input stream containing a serialized forest
+ * @return the deserialized {@link org.apache.mahout.classifier.df.DecisionForest}
+ * @throws IOException if an I/O error occurs while reading
+ */
+ public static DecisionForest read(DataInput dataInput) throws IOException {
+ DecisionForest forest = new DecisionForest();
+ forest.readFields(dataInput);
+ return forest;
+ }
+
+ /**
+ * Load the forest from a single file or a directory of files
+ * @throws java.io.IOException
+ */
+ public static DecisionForest load(Configuration conf, Path forestPath) throws IOException {
+ FileSystem fs = forestPath.getFileSystem(conf);
+ Path[] files;
+ if (fs.getFileStatus(forestPath).isDir()) {
+ files = DFUtils.listOutputFiles(fs, forestPath);
+ } else {
+ files = new Path[]{forestPath};
+ }
+
+ DecisionForest forest = null;
+ for (Path path : files) {
+ FSDataInputStream dataInput = new FSDataInputStream(fs.open(path));
+ try {
+ if (forest == null) {
+ forest = read(dataInput);
+ } else {
+ forest.readFields(dataInput);
+ }
+ } finally {
+ Closeables.close(dataInput, true);
+ }
+ }
+
+ return forest;
+
+ }
+
+}
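
For orientation, here is a minimal usage sketch for the class above. It is a
sketch, not part of the commit: the model paths are hypothetical, and it assumes
a forest and dataset previously serialized by Mahout's df tooling (Dataset.load
is the companion loader in org.apache.mahout.classifier.df.data).

    // imports assumed: org.apache.hadoop.conf.Configuration, org.apache.hadoop.fs.Path,
    // org.apache.mahout.classifier.df.data.Dataset, org.apache.mahout.classifier.df.data.Instance,
    // java.io.IOException, java.util.Random
    public static double classifyOne(Configuration conf, Instance instance) throws IOException {
      Dataset dataset = Dataset.load(conf, new Path("/models/dataset.info"));            // hypothetical path
      DecisionForest forest = DecisionForest.load(conf, new Path("/models/forest.seq")); // hypothetical path
      // The Random is only used to break ties in the majority vote.
      return forest.classify(dataset, new Random(), instance); // NaN if no tree can predict
    }
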
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java b/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
new file mode 100644
index 0000000..2a7facc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Utility methods to compute error estimates from the output of a random forest.
+ */
+public final class ErrorEstimate {
+
+ private ErrorEstimate() {
+ }
+
+ public static double errorRate(double[] labels, double[] predictions) {
+ Preconditions.checkArgument(labels.length == predictions.length, "labels.length != predictions.length");
+ double nberrors = 0; // number of instances that got bad predictions
+ double datasize = 0; // number of classified instances
+
+ for (int index = 0; index < labels.length; index++) {
+ if (predictions[index] == -1) {
+ continue; // instance not classified
+ }
+
+ if (predictions[index] != labels[index]) {
+ nberrors++;
+ }
+
+ datasize++;
+ }
+
+ return nberrors / datasize;
+ }
+
+}
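
As a quick worked check of the convention above (a prediction of -1 marks an
unclassified instance, which is skipped rather than counted as an error):

    double[] labels      = {0, 1, 1, 0};
    double[] predictions = {0, 1, 0, -1};  // the last instance was never classified
    // Three instances are counted and one of them is wrong, so the result is 1/3.
    double err = ErrorEstimate.errorRate(labels, predictions);  // ~0.3333
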
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
new file mode 100644
index 0000000..895188b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
@@ -0,0 +1,421 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import com.google.common.collect.Sets;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+import org.apache.mahout.classifier.df.split.IgSplit;
+import org.apache.mahout.classifier.df.split.OptIgSplit;
+import org.apache.mahout.classifier.df.split.RegressionSplit;
+import org.apache.mahout.classifier.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.Random;
+
+/**
+ * Builds a classification tree or a regression tree.<br>
+ * A classification tree is built when the criterion (label) attribute is categorical.<br>
+ * A regression tree is built when the criterion (label) attribute is numerical.
+ */
+public class DecisionTreeBuilder implements TreeBuilder {
+
+ private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class);
+
+ private static final int[] NO_ATTRIBUTES = new int[0];
+ private static final double EPSILON = 1.0e-6;
+
+ /**
+ * indicates which CATEGORICAL attributes have already been selected in the parent nodes
+ */
+ private boolean[] selected;
+ /**
+ * number of attributes to select randomly at each node
+ */
+ private int m;
+ /**
+ * IgSplit implementation
+ */
+ private IgSplit igSplit;
+ /**
+ * whether the tree is complemented: if true, categorical splits branch on every value seen in the full data set, and values absent from the current subset become default-label leaves
+ */
+ private boolean complemented = true;
+ /**
+ * minimum subset size required to split a branch
+ */
+ private double minSplitNum = 2.0;
+ /**
+ * minimum proportion of the total variance for split
+ */
+ private double minVarianceProportion = 1.0e-3;
+ /**
+ * the full training set, kept to build complemented categorical splits
+ */
+ private Data fullSet;
+ /**
+ * minimum variance for split
+ */
+ private double minVariance = Double.NaN;
+
+ public void setM(int m) {
+ this.m = m;
+ }
+
+ public void setIgSplit(IgSplit igSplit) {
+ this.igSplit = igSplit;
+ }
+
+ public void setComplemented(boolean complemented) {
+ this.complemented = complemented;
+ }
+
+ public void setMinSplitNum(int minSplitNum) {
+ this.minSplitNum = minSplitNum;
+ }
+
+ public void setMinVarianceProportion(double minVarianceProportion) {
+ this.minVarianceProportion = minVarianceProportion;
+ }
+
+ @Override
+ public Node build(Random rng, Data data) {
+ if (selected == null) {
+ selected = new boolean[data.getDataset().nbAttributes()];
+ selected[data.getDataset().getLabelId()] = true; // never select the label
+ }
+ if (m == 0) {
+ // set default m
+ double e = data.getDataset().nbAttributes() - 1;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ m = (int) Math.ceil(e / 3.0);
+ } else {
+ // classification
+ m = (int) Math.ceil(Math.sqrt(e));
+ }
+ }
+
+ if (data.isEmpty()) {
+ return new Leaf(Double.NaN);
+ }
+
+ double sum = 0.0;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ // compute the sum and the sum of squares of the labels
+ double sumSquared = 0.0;
+ for (int i = 0; i < data.size(); i++) {
+ double label = data.getDataset().getLabel(data.get(i));
+ sum += label;
+ sumSquared += label * label;
+ }
+
+ // total squared deviation from the mean: data.size() times the variance
+ double var = sumSquared - (sum * sum) / data.size();
+
+ // computes the minimum variance
+ if (Double.isNaN(minVariance)) {
+ minVariance = var / data.size() * minVarianceProportion;
+ log.debug("minVariance:{}", minVariance);
+ }
+
+ // stop splitting when the variance falls below the minimum
+ if ((var / data.size()) < minVariance) {
+ log.debug("variance({}) < minVariance({}) Leaf({})", var / data.size(), minVariance, sum / data.size());
+ return new Leaf(sum / data.size());
+ }
+ } else {
+ // classification
+ if (isIdentical(data)) {
+ return new Leaf(data.majorityLabel(rng));
+ }
+ if (data.identicalLabel()) {
+ return new Leaf(data.getDataset().getLabel(data.get(0)));
+ }
+ }
+
+ // remember the full data set (first call only)
+ if (fullSet == null) {
+ fullSet = data;
+ }
+
+ int[] attributes = randomAttributes(rng, selected, m);
+ if (attributes == null || attributes.length == 0) {
+ // we tried all the attributes and could not split the data anymore
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ label = sum / data.size();
+ } else {
+ // classification
+ label = data.majorityLabel(rng);
+ }
+ log.warn("attribute which can be selected is not found Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ if (igSplit == null) {
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ igSplit = new RegressionSplit();
+ } else {
+ // classification
+ igSplit = new OptIgSplit();
+ }
+ }
+
+ // find the best split
+ Split best = null;
+ for (int attr : attributes) {
+ Split split = igSplit.computeSplit(data, attr);
+ if (best == null || best.getIg() < split.getIg()) {
+ best = split;
+ }
+ }
+
+ // the information gain is near zero: stop splitting.
+ if (best.getIg() < EPSILON) {
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("ig is near to zero Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg());
+
+ boolean alreadySelected = selected[best.getAttr()];
+ if (alreadySelected) {
+ // attribute already selected
+ log.warn("attribute {} already selected in a parent node", best.getAttr());
+ }
+
+ Node childNode;
+ if (data.getDataset().isNumerical(best.getAttr())) {
+ boolean[] temp = null;
+
+ Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
+ Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
+
+ if (loSubset.isEmpty() || hiSubset.isEmpty()) {
+ // the selected attribute did not change the data, avoid using it in the child nodes
+ selected[best.getAttr()] = true;
+ } else {
+ // the data changed, so we can unselect all previously selected NUMERICAL attributes
+ temp = selected;
+ selected = cloneCategoricalAttributes(data.getDataset(), selected);
+ }
+
+ // a subset is smaller than minSplitNum
+ if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
+ // branch is not split
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("branch is not split Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ Node loChild = build(rng, loSubset);
+ Node hiChild = build(rng, hiSubset);
+
+ // restore the selection state of the attributes
+ if (temp != null) {
+ selected = temp;
+ } else {
+ selected[best.getAttr()] = alreadySelected;
+ }
+
+ childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
+ } else { // CATEGORICAL attribute
+ double[] values = data.values(best.getAttr());
+
+ // complemented tree: remember this subset's values and split on all values from the full set
+ Collection<Double> subsetValues = null;
+ if (complemented) {
+ subsetValues = Sets.newHashSet();
+ for (double value : values) {
+ subsetValues.add(value);
+ }
+ values = fullSet.values(best.getAttr());
+ }
+
+ int cnt = 0;
+ Data[] subsets = new Data[values.length];
+ for (int index = 0; index < values.length; index++) {
+ if (complemented && !subsetValues.contains(values[index])) {
+ continue;
+ }
+ subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
+ if (subsets[index].size() >= minSplitNum) {
+ cnt++;
+ }
+ }
+
+ // fewer than two subsets reached minSplitNum
+ if (cnt < 2) {
+ // branch is not split
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("branch is not split Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ selected[best.getAttr()] = true;
+
+ Node[] children = new Node[values.length];
+ for (int index = 0; index < values.length; index++) {
+ if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
+ // value absent from this subset: create a default-label leaf
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("complemented Leaf({})", label);
+ children[index] = new Leaf(label);
+ continue;
+ }
+ children[index] = build(rng, subsets[index]);
+ }
+
+ selected[best.getAttr()] = alreadySelected;
+
+ childNode = new CategoricalNode(best.getAttr(), values, children);
+ }
+
+ return childNode;
+ }
+
+ /**
+ * Checks whether all the vectors have identical attribute values, ignoring already-selected attributes.
+ *
+ * @return true if all the vectors are identical or the data is empty<br>
+ * false otherwise
+ */
+ private boolean isIdentical(Data data) {
+ if (data.isEmpty()) {
+ return true;
+ }
+
+ Instance instance = data.get(0);
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (selected[attr]) {
+ continue;
+ }
+
+ for (int index = 1; index < data.size(); index++) {
+ if (data.get(index).get(attr) != instance.get(attr)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * Make a copy of the selection state of the attributes, unselect all numerical attributes
+ *
+ * @param selected selection state to clone
+ * @return cloned selection state
+ */
+ private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+ boolean[] cloned = new boolean[selected.length];
+
+ for (int i = 0; i < selected.length; i++) {
+ cloned[i] = !dataset.isNumerical(i) && selected[i];
+ }
+ cloned[dataset.getLabelId()] = true;
+
+ return cloned;
+ }
+
+ /**
+ * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
+ *
+ * @param rng random-numbers generator
+ * @param selected attributes' state (selected or not)
+ * @param m number of attributes to choose
+ * @return list of selected attributes' indices, or an empty array if all attributes have already been selected
+ */
+ private static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+ int nbNonSelected = 0; // number of non selected attributes
+ for (boolean sel : selected) {
+ if (!sel) {
+ nbNonSelected++;
+ }
+ }
+
+ if (nbNonSelected == 0) {
+ log.warn("All attributes are selected !");
+ return NO_ATTRIBUTES;
+ }
+
+ int[] result;
+ if (nbNonSelected <= m) {
+ // return all non selected attributes
+ result = new int[nbNonSelected];
+ int index = 0;
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (!selected[attr]) {
+ result[index++] = attr;
+ }
+ }
+ } else {
+ result = new int[m];
+ for (int index = 0; index < m; index++) {
+ // randomly choose a "non selected" attribute
+ int rind;
+ do {
+ rind = rng.nextInt(selected.length);
+ } while (selected[rind]);
+
+ result[index] = rind;
+ selected[rind] = true; // temporarily set the chosen attribute to be selected
+ }
+
+ // reset the temporary flags: the chosen attributes remain unselected
+ for (int attr : result) {
+ selected[attr] = false;
+ }
+ }
+
+ return result;
+ }
+}
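
A sketch of growing a single tree with this builder; 'data' is assumed to have
been loaded elsewhere (e.g. with Mahout's data loaders), and the calls shown
are the setters and build() defined above.

    DecisionTreeBuilder builder = new DecisionTreeBuilder();
    // If m is left at 0, build() picks a default: ceil(sqrt(nbAttributes)) for
    // classification, ceil(nbAttributes / 3) for regression.
    builder.setComplemented(true);  // categorical splits branch on all values from the full set
    builder.setMinSplitNum(2);      // subsets smaller than this become leaves
    Node tree = builder.build(new Random(), data);
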
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
new file mode 100644
index 0000000..f03698d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
@@ -0,0 +1,252 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+import org.apache.mahout.classifier.df.split.IgSplit;
+import org.apache.mahout.classifier.df.split.OptIgSplit;
+import org.apache.mahout.classifier.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/**
+ * Builds a Decision Tree <br>
+ * Based on the algorithm described in the "Decision Trees" tutorials by Andrew W. Moore, available at:<br>
+ * <br>
+ * http://www.cs.cmu.edu/~awm/tutorials
+ * <br><br>
+ * This class can be used when the criterion variable is a categorical attribute.
+ */
+public class DefaultTreeBuilder implements TreeBuilder {
+
+ private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class);
+
+ private static final int[] NO_ATTRIBUTES = new int[0];
+
+ /**
+ * indicates which CATEGORICAL attributes have already been selected in the parent nodes
+ */
+ private boolean[] selected;
+ /**
+ * number of attributes to select randomly at each node
+ */
+ private int m = 1;
+ /**
+ * IgSplit implementation
+ */
+ private final IgSplit igSplit;
+
+ public DefaultTreeBuilder() {
+ igSplit = new OptIgSplit();
+ }
+
+ public void setM(int m) {
+ this.m = m;
+ }
+
+ @Override
+ public Node build(Random rng, Data data) {
+
+ if (selected == null) {
+ selected = new boolean[data.getDataset().nbAttributes()];
+ selected[data.getDataset().getLabelId()] = true; // never select the label
+ }
+
+ if (data.isEmpty()) {
+ return new Leaf(-1);
+ }
+ if (isIdentical(data)) {
+ return new Leaf(data.majorityLabel(rng));
+ }
+ if (data.identicalLabel()) {
+ return new Leaf(data.getDataset().getLabel(data.get(0)));
+ }
+
+ int[] attributes = randomAttributes(rng, selected, m);
+ if (attributes == null || attributes.length == 0) {
+ // we tried all the attributes and could not split the data anymore
+ return new Leaf(data.majorityLabel(rng));
+ }
+
+ // find the best split
+ Split best = null;
+ for (int attr : attributes) {
+ Split split = igSplit.computeSplit(data, attr);
+ if (best == null || best.getIg() < split.getIg()) {
+ best = split;
+ }
+ }
+
+ boolean alreadySelected = selected[best.getAttr()];
+ if (alreadySelected) {
+ // attribute already selected
+ log.warn("attribute {} already selected in a parent node", best.getAttr());
+ }
+
+ Node childNode;
+ if (data.getDataset().isNumerical(best.getAttr())) {
+ boolean[] temp = null;
+
+ Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
+ Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
+
+ if (loSubset.isEmpty() || hiSubset.isEmpty()) {
+ // the selected attribute did not change the data, avoid using it in the child nodes
+ selected[best.getAttr()] = true;
+ } else {
+ // the data changed, so we can unselect all previously selected NUMERICAL attributes
+ temp = selected;
+ selected = cloneCategoricalAttributes(data.getDataset(), selected);
+ }
+
+ Node loChild = build(rng, loSubset);
+ Node hiChild = build(rng, hiSubset);
+
+ // restore the selection state of the attributes
+ if (temp != null) {
+ selected = temp;
+ } else {
+ selected[best.getAttr()] = alreadySelected;
+ }
+
+ childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
+ } else { // CATEGORICAL attribute
+ selected[best.getAttr()] = true;
+
+ double[] values = data.values(best.getAttr());
+ Node[] children = new Node[values.length];
+
+ for (int index = 0; index < values.length; index++) {
+ Data subset = data.subset(Condition.equals(best.getAttr(), values[index]));
+ children[index] = build(rng, subset);
+ }
+
+ selected[best.getAttr()] = alreadySelected;
+
+ childNode = new CategoricalNode(best.getAttr(), values, children);
+ }
+
+ return childNode;
+ }
+
+ /**
+ * Checks whether all the vectors have identical attribute values, ignoring already-selected attributes.
+ *
+ * @return true if all the vectors are identical or the data is empty<br>
+ * false otherwise
+ */
+ private boolean isIdentical(Data data) {
+ if (data.isEmpty()) {
+ return true;
+ }
+
+ Instance instance = data.get(0);
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (selected[attr]) {
+ continue;
+ }
+
+ for (int index = 1; index < data.size(); index++) {
+ if (data.get(index).get(attr) != instance.get(attr)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+
+ /**
+ * Make a copy of the selection state of the attributes, unselect all numerical attributes
+ *
+ * @param selected selection state to clone
+ * @return cloned selection state
+ */
+ private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+ boolean[] cloned = new boolean[selected.length];
+
+ for (int i = 0; i < selected.length; i++) {
+ cloned[i] = !dataset.isNumerical(i) && selected[i];
+ }
+
+ return cloned;
+ }
+
+ /**
+ * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
+ *
+ * @param rng random-numbers generator
+ * @param selected attributes' state (selected or not)
+ * @param m number of attributes to choose
+ * @return list of selected attributes' indices, or an empty array if all attributes have already been selected
+ */
+ protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+ int nbNonSelected = 0; // number of non selected attributes
+ for (boolean sel : selected) {
+ if (!sel) {
+ nbNonSelected++;
+ }
+ }
+
+ if (nbNonSelected == 0) {
+ log.warn("All attributes are selected !");
+ return NO_ATTRIBUTES;
+ }
+
+ int[] result;
+ if (nbNonSelected <= m) {
+ // return all non selected attributes
+ result = new int[nbNonSelected];
+ int index = 0;
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (!selected[attr]) {
+ result[index++] = attr;
+ }
+ }
+ } else {
+ result = new int[m];
+ for (int index = 0; index < m; index++) {
+ // randomly choose a "non selected" attribute
+ int rind;
+ do {
+ rind = rng.nextInt(selected.length);
+ } while (selected[rind]);
+
+ result[index] = rind;
+ selected[rind] = true; // temporarily set the chosen attribute to be selected
+ }
+
+ // reset the temporary flags: the chosen attributes remain unselected
+ for (int attr : result) {
+ selected[attr] = false;
+ }
+ }
+
+ return result;
+ }
+}
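
By contrast with DecisionTreeBuilder, this builder handles only categorical
labels and defaults to m = 1. A sketch, again assuming 'data' was loaded
elsewhere; the sqrt heuristic is a common random-forest choice, not something
this class enforces.

    DefaultTreeBuilder builder = new DefaultTreeBuilder();
    int nbPredictors = data.getDataset().nbAttributes() - 1;  // exclude the label
    builder.setM((int) Math.ceil(Math.sqrt(nbPredictors)));
    Node tree = builder.build(new Random(), data);
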