You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2014/02/25 15:15:06 UTC
svn commit: r1571704 - in /mahout/trunk: ./
core/src/main/java/org/apache/mahout/classifier/df/split/
core/src/test/java/org/apache/mahout/classifier/df/split/
core/src/test/java/org/apache/mahout/classifier/df/tools/
examples/src/main/java/org/apache/...
Author: srowen
Date: Tue Feb 25 14:15:06 2014
New Revision: 1571704
URL: http://svn.apache.org/r1571704
Log:
MAHOUT-1419: Random decision forest is excessively slow on numeric features
Removed:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/split/OptIgSplitTest.java
Modified:
mahout/trunk/CHANGELOG
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java
Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Tue Feb 25 14:15:06 2014
@@ -2,6 +2,8 @@ Mahout Change Log
Release 1.0 - unreleased
+ MAHOUT-1419: Random decision forest is excessively slow on numeric features (srowen)
+
MAHOUT-1329: Mahout for hadoop 2 (gcapan, Sergey Svinarchuk)
MAHOUT-1417: Random decision forest implementation fails in Hadoop 2 (srowen)
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java Tue Feb 25 14:15:06 2014
@@ -1,178 +1,231 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.df.split;
-
-import org.apache.commons.lang3.ArrayUtils;
-import org.apache.mahout.classifier.df.data.Data;
-import org.apache.mahout.classifier.df.data.DataUtils;
-import org.apache.mahout.classifier.df.data.Dataset;
-import org.apache.mahout.classifier.df.data.Instance;
-
-import java.util.Arrays;
-
-/**
- * Optimized implementation of IgSplit<br>
- * This class can be used when the criterion variable is the categorical attribute.
- */
-public class OptIgSplit extends IgSplit {
-
- private int[][] counts;
-
- private int[] countAll;
-
- private int[] countLess;
-
- @Override
- public Split computeSplit(Data data, int attr) {
- if (data.getDataset().isNumerical(attr)) {
- return numericalSplit(data, attr);
- } else {
- return categoricalSplit(data, attr);
- }
- }
-
- /**
- * Computes the split for a CATEGORICAL attribute
- */
- private static Split categoricalSplit(Data data, int attr) {
- double[] values = data.values(attr);
- int[][] counts = new int[values.length][data.getDataset().nblabels()];
- int[] countAll = new int[data.getDataset().nblabels()];
-
- Dataset dataset = data.getDataset();
-
- // compute frequencies
- for (int index = 0; index < data.size(); index++) {
- Instance instance = data.get(index);
- counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
- countAll[(int) dataset.getLabel(instance)]++;
- }
-
- int size = data.size();
- double hy = entropy(countAll, size); // H(Y)
- double hyx = 0.0; // H(Y|X)
- double invDataSize = 1.0 / size;
-
- for (int index = 0; index < values.length; index++) {
- size = DataUtils.sum(counts[index]);
- hyx += size * invDataSize * entropy(counts[index], size);
- }
-
- double ig = hy - hyx;
- return new Split(attr, ig);
- }
-
- /**
- * Return the sorted list of distinct values for the given attribute
- */
- private static double[] sortedValues(Data data, int attr) {
- double[] values = data.values(attr);
- Arrays.sort(values);
-
- return values;
- }
-
- /**
- * Instantiates the counting arrays
- */
- void initCounts(Data data, double[] values) {
- counts = new int[values.length][data.getDataset().nblabels()];
- countAll = new int[data.getDataset().nblabels()];
- countLess = new int[data.getDataset().nblabels()];
- }
-
- void computeFrequencies(Data data, int attr, double[] values) {
- Dataset dataset = data.getDataset();
-
- for (int index = 0; index < data.size(); index++) {
- Instance instance = data.get(index);
- counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
- countAll[(int) dataset.getLabel(instance)]++;
- }
- }
-
- /**
- * Computes the best split for a NUMERICAL attribute
- */
- Split numericalSplit(Data data, int attr) {
- double[] values = sortedValues(data, attr);
-
- initCounts(data, values);
-
- computeFrequencies(data, attr, values);
-
- int size = data.size();
- double hy = entropy(countAll, size);
- double invDataSize = 1.0 / size;
-
- int best = -1;
- double bestIg = -1.0;
-
- // try each possible split value
- for (int index = 0; index < values.length; index++) {
- double ig = hy;
-
- // instance with attribute value < values[index]
- size = DataUtils.sum(countLess);
- ig -= size * invDataSize * entropy(countLess, size);
-
- // instance with attribute value >= values[index]
- size = DataUtils.sum(countAll);
- ig -= size * invDataSize * entropy(countAll, size);
-
- if (ig > bestIg) {
- bestIg = ig;
- best = index;
- }
-
- DataUtils.add(countLess, counts[index]);
- DataUtils.dec(countAll, counts[index]);
- }
-
- if (best == -1) {
- throw new IllegalStateException("no best split found !");
- }
- return new Split(attr, bestIg, values[best]);
- }
-
- /**
- * Computes the Entropy
- *
- * @param counts counts[i] = numInstances with label i
- * @param dataSize numInstances
- */
- private static double entropy(int[] counts, int dataSize) {
- if (dataSize == 0) {
- return 0.0;
- }
-
- double entropy = 0.0;
- double invDataSize = 1.0 / dataSize;
-
- for (int count : counts) {
- if (count == 0) {
- continue; // otherwise we get a NaN
- }
- double p = count * invDataSize;
- entropy += -p * Math.log(p) / LOG2;
- }
-
- return entropy;
- }
-
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.TreeSet;
+
+/**
+ * <p>Optimized implementation of IgSplit.
+ * This class can be used when the criterion variable is the categorical attribute.</p>
+ *
+ * <p>This code was changed in MAHOUT-1419 to deal in sampled splits among numeric
+ * features to fix a performance problem. To generate some synthetic data that exercises
+ * the issue, try for example generating 4 features of Uniform(0,1) values with a random
+ * boolean 0/1 categorical feature. In Scala:</p>
+ *
+ * {@code
+ * val r = new scala.util.Random()
+ * val pw = new java.io.PrintWriter("random.csv")
+ * (1 to 10000000).foreach(e =>
+ * pw.println(r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * (if (r.nextBoolean()) 1 else 0))
+ * )
+ * pw.close()
+ * }
+ */
+public class OptIgSplit extends IgSplit {
+
+ private static final int MAX_NUMERIC_SPLITS = 16;
+
+ @Override
+ public Split computeSplit(Data data, int attr) { // chooses the split routine from the attribute's type
+ if (data.getDataset().isNumerical(attr)) {
+ return numericalSplit(data, attr); // numerical: the returned Split carries the chosen threshold value
+ } else {
+ return categoricalSplit(data, attr); // otherwise: the returned Split carries only the information gain
+ }
+ }
+
+ /**
+ * Computes the split for a CATEGORICAL attribute: its information gain H(Y) - H(Y|X), with one bucket per distinct value
+ */
+ private static Split categoricalSplit(Data data, int attr) {
+ double[] values = data.values(attr).clone();
+
+ double[] splitPoints = chooseCategoricalSplitPoints(values);
+
+ int numLabels = data.getDataset().nblabels();
+ int[][] counts = new int[splitPoints.length][numLabels];
+ int[] countAll = new int[numLabels];
+
+ computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+ int size = data.size();
+ double hy = entropy(countAll, size); // H(Y): entropy of the labels over all instances
+ double hyx = 0.0; // H(Y|X): accumulated in the loop below, one weighted term per category
+ double invDataSize = 1.0 / size; // NOTE(review): Infinity when data.size() is 0 -- confirm callers never pass empty data
+
+ for (int index = 0; index < splitPoints.length; index++) {
+ size = DataUtils.sum(counts[index]); // presumably the number of instances counted for this category
+ hyx += size * invDataSize * entropy(counts[index], size);
+ }
+
+ double ig = hy - hyx; // information gain of splitting on this attribute
+ return new Split(attr, ig);
+ }
+
+ static void computeFrequencies(Data data,
+ int attr,
+ double[] splitPoints,
+ int[][] counts,
+ int[] countAll) { // fills counts[bucket][label] and countAll[label] in place; nothing is returned
+ Dataset dataset = data.getDataset();
+
+ for (int index = 0; index < data.size(); index++) {
+ Instance instance = data.get(index);
+ int label = (int) dataset.getLabel(instance);
+ double value = instance.get(attr);
+ int split = 0;
+ while (split < splitPoints.length && value > splitPoints[split]) { // linear scan for the first split point >= value
+ split++;
+ }
+ if (split < splitPoints.length) {
+ counts[split][label]++;
+ } // Otherwise value exceeds every split point: it falls in the last, open-ended bucket, which is tracked only via countAll
+ countAll[label]++;
+ }
+ }
+
+ /**
+ * Computes the best split for a NUMERICAL attribute: the candidate split point with the highest information gain
+ */
+ static Split numericalSplit(Data data, int attr) {
+ double[] values = data.values(attr).clone();
+ Arrays.sort(values);
+
+ double[] splitPoints = chooseNumericSplitPoints(values);
+
+ int numLabels = data.getDataset().nblabels();
+ int[][] counts = new int[splitPoints.length][numLabels];
+ int[] countAll = new int[numLabels];
+ int[] countLess = new int[numLabels];
+
+ computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+ int size = data.size();
+ double hy = entropy(countAll, size);
+ double invDataSize = 1.0 / size;
+
+ int best = -1;
+ double bestIg = -1.0;
+
+ // try each candidate split point; each step presumably moves one bucket of label counts from countAll into countLess
+ for (int index = 0; index < splitPoints.length; index++) {
+ double ig = hy;
+
+ DataUtils.add(countLess, counts[index]);
+ DataUtils.dec(countAll, counts[index]);
+
+ // instances with attribute value <= splitPoints[index] (computeFrequencies buckets on value > splitPoint)
+ size = DataUtils.sum(countLess);
+ ig -= size * invDataSize * entropy(countLess, size);
+ // instances with attribute value > splitPoints[index], including those beyond the last split point
+ size = DataUtils.sum(countAll);
+ ig -= size * invDataSize * entropy(countAll, size);
+
+ if (ig > bestIg) {
+ bestIg = ig;
+ best = index;
+ }
+ }
+
+ if (best == -1) { // no candidate was accepted, e.g. splitPoints is empty so the loop never ran
+ throw new IllegalStateException("no best split found !");
+ }
+ return new Split(attr, bestIg, splitPoints[best]);
+ }
+
+ /**
+ * @return an array of values to split the numeric feature's values on when building candidate splits.
+ * An input of 0 or 1 values is returned as-is. When input size is at most MAX_NUMERIC_SPLITS + 1, it will
+ * return the averages between successive values as split points (the caller passes sorted values). When larger,
+ * it will return MAX_NUMERIC_SPLITS percentiles of the data at evenly spaced ranks.
+ */
+ private static double[] chooseNumericSplitPoints(double[] values) {
+ if (values.length <= 1) {
+ return values; // nothing to average: the input array itself (possibly empty) is returned, not a copy
+ }
+ if (values.length <= MAX_NUMERIC_SPLITS + 1) {
+ double[] splitPoints = new double[values.length - 1];
+ for (int i = 1; i < values.length; i++) {
+ splitPoints[i-1] = (values[i] + values[i-1]) / 2.0; // midpoint of each adjacent pair
+ }
+ return splitPoints;
+ }
+ Percentile distribution = new Percentile();
+ distribution.setData(values);
+ double[] percentiles = new double[MAX_NUMERIC_SPLITS];
+ for (int i = 0 ; i < percentiles.length; i++) {
+ double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0)); // evenly spaced, strictly between 0 and 100
+ percentiles[i] = distribution.evaluate(p);
+ }
+ return percentiles;
+ }
+
+ private static double[] chooseCategoricalSplitPoints(double[] values) {
+ // There is no great reason to believe that categorical value order matters,
+ // but the original code worked this way, and it's not terrible in the absence
+ // of more sophisticated analysis. The result is the distinct input values in ascending order.
+ Collection<Double> uniqueOrderedCategories = new TreeSet<Double>(); // TreeSet drops duplicates and keeps values sorted
+ for (double v : values) {
+ uniqueOrderedCategories.add(v);
+ }
+ double[] uniqueValues = new double[uniqueOrderedCategories.size()];
+ Iterator<Double> it = uniqueOrderedCategories.iterator();
+ for (int i = 0; i < uniqueValues.length; i++) {
+ uniqueValues[i] = it.next(); // unboxed in the set's ascending iteration order
+ }
+ return uniqueValues;
+ }
+
+ /**
+ * Computes the Entropy of a label distribution: -sum(p * ln p) divided by LOG2 (declared outside this class; presumably ln 2)
+ * Returns 0.0 when dataSize is 0.
+ * @param counts counts[i] = numInstances with label i
+ * @param dataSize numInstances, used as the denominator for each label's probability
+ */
+ private static double entropy(int[] counts, int dataSize) {
+ if (dataSize == 0) {
+ return 0.0;
+ }
+
+ double entropy = 0.0;
+
+ for (int count : counts) {
+ if (count > 0) { // skip empty labels: p = 0 would give 0 * log(0), i.e. NaN
+ double p = count / (double) dataSize;
+ entropy -= p * Math.log(p);
+ }
+ }
+
+ return entropy / LOG2; // a single division here converts the natural-log sum to the LOG2 base
+ }
+
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java Tue Feb 25 14:15:06 2014
@@ -86,7 +86,7 @@ public final class VisualizerTest extend
Node tree = builder.build(rng, data);
assertEquals("\noutlook = rainy\n| windy = FALSE : yes\n| windy = TRUE : no\n"
- + "outlook = sunny\n| humidity < 85 : yes\n| humidity >= 85 : no\n"
+ + "outlook = sunny\n| humidity < 77.5 : yes\n| humidity >= 77.5 : no\n"
+ "outlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
}
@@ -101,7 +101,7 @@ public final class VisualizerTest extend
ATTR_NAMES);
Assert.assertArrayEquals(new String[] {
"outlook = rainy -> windy = TRUE -> no", "outlook = overcast -> yes",
- "outlook = sunny -> (humidity = 90) >= 85 -> no"}, prediction);
+ "outlook = sunny -> (humidity = 90) >= 77.5 -> no"}, prediction);
}
@Test
@@ -142,7 +142,7 @@ public final class VisualizerTest extend
builder.setComplemented(false);
Node tree = builder.build(rng, lessData);
- assertEquals("\noutlook = sunny\n| humidity < 85 : yes\n| humidity >= 85 : no\noutlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
+ assertEquals("\noutlook = sunny\n| humidity < 77.5 : yes\n| humidity >= 77.5 : no\noutlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
}
@Test
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java Tue Feb 25 14:15:06 2014
@@ -226,6 +226,9 @@ public class BuildForest extends Configu
long time = System.currentTimeMillis();
DecisionForest forest = forestBuilder.build(nbTrees);
+ if (forest == null) {
+ return;
+ }
time = System.currentTimeMillis() - time;
log.info("Build Time: {}", DFUtils.elapsedTime(time));