You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2014/02/25 15:15:06 UTC

svn commit: r1571704 - in /mahout/trunk: ./ core/src/main/java/org/apache/mahout/classifier/df/split/ core/src/test/java/org/apache/mahout/classifier/df/split/ core/src/test/java/org/apache/mahout/classifier/df/tools/ examples/src/main/java/org/apache/...

Author: srowen
Date: Tue Feb 25 14:15:06 2014
New Revision: 1571704

URL: http://svn.apache.org/r1571704
Log:
MAHOUT-1419: Random decision forest is excessively slow on numeric features

Removed:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/split/OptIgSplitTest.java
Modified:
    mahout/trunk/CHANGELOG
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Tue Feb 25 14:15:06 2014
@@ -2,6 +2,8 @@ Mahout Change Log
 
 Release 1.0 - unreleased
 
+  MAHOUT-1419: Random decision forest is excessively slow on numeric features (srowen)
+
   MAHOUT-1329: Mahout for hadoop 2  (gcapan, Sergey Svinarchuk)
 
   MAHOUT-1417: Random decision forest implementation fails in Hadoop 2 (srowen)

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java Tue Feb 25 14:15:06 2014
@@ -1,178 +1,231 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.df.split;
-
-import org.apache.commons.lang3.ArrayUtils;
-import org.apache.mahout.classifier.df.data.Data;
-import org.apache.mahout.classifier.df.data.DataUtils;
-import org.apache.mahout.classifier.df.data.Dataset;
-import org.apache.mahout.classifier.df.data.Instance;
-
-import java.util.Arrays;
-
-/**
- * Optimized implementation of IgSplit<br>
- * This class can be used when the criterion variable is the categorical attribute.
- */
-public class OptIgSplit extends IgSplit {
-
-  // Scratch arrays for numericalSplit(): (re)allocated by initCounts() and filled by
-  // computeFrequencies() on every call. Because they are instance fields, concurrent
-  // numericalSplit() calls on the same OptIgSplit object would overwrite each other's counts.
-  private int[][] counts;
-
-  private int[] countAll;
-
-  private int[] countLess;
-
-  @Override
-  public Split computeSplit(Data data, int attr) {
-    if (data.getDataset().isNumerical(attr)) {
-      return numericalSplit(data, attr);
-    } else {
-      return categoricalSplit(data, attr);
-    }
-  }
-
-  /**
-   * Computes the split for a CATEGORICAL attribute: the information gain
-   * H(Y) - H(Y|X), with one branch per distinct attribute value.
-   */
-  private static Split categoricalSplit(Data data, int attr) {
-    double[] values = data.values(attr);
-    int[][] counts = new int[values.length][data.getDataset().nblabels()];
-    int[] countAll = new int[data.getDataset().nblabels()];
-
-    Dataset dataset = data.getDataset();
-
-    // compute frequencies: counts[i][label] for instances whose value equals values[i].
-    // NOTE(review): ArrayUtils.indexOf searches values for every instance, so this pass
-    // costs O(data.size() * values.length).
-    for (int index = 0; index < data.size(); index++) {
-      Instance instance = data.get(index);
-      counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
-      countAll[(int) dataset.getLabel(instance)]++;
-    }
-
-    int size = data.size();
-    double hy = entropy(countAll, size); // H(Y)
-    double hyx = 0.0; // H(Y|X)
-    double invDataSize = 1.0 / size;
-
-    for (int index = 0; index < values.length; index++) {
-      size = DataUtils.sum(counts[index]);
-      hyx += size * invDataSize * entropy(counts[index], size);
-    }
-
-    double ig = hy - hyx;
-    return new Split(attr, ig);
-  }
-
-  /**
-   * Return the sorted list of distinct values for the given attribute.
-   * Note: Arrays.sort sorts the array returned by data.values() in place.
-   * NOTE(review): if data.values() returns shared storage, that storage is mutated here.
-   */
-  private static double[] sortedValues(Data data, int attr) {
-    double[] values = data.values(attr);
-    Arrays.sort(values);
-
-    return values;
-  }
-
-  /**
-   * Instantiates the counting arrays: one row of per-label counts per distinct value,
-   * plus the per-label totals (countAll) and the running "lower side" totals (countLess).
-   */
-  void initCounts(Data data, double[] values) {
-    counts = new int[values.length][data.getDataset().nblabels()];
-    countAll = new int[data.getDataset().nblabels()];
-    countLess = new int[data.getDataset().nblabels()];
-  }
-
-  /**
-   * Fills counts[i][label] with the number of instances whose attribute value equals
-   * values[i], and countAll[label] with the per-label totals. Requires initCounts() first.
-   */
-  void computeFrequencies(Data data, int attr, double[] values) {
-    Dataset dataset = data.getDataset();
-
-    for (int index = 0; index < data.size(); index++) {
-      Instance instance = data.get(index);
-      counts[ArrayUtils.indexOf(values, instance.get(attr))][(int) dataset.getLabel(instance)]++;
-      countAll[(int) dataset.getLabel(instance)]++;
-    }
-  }
-
-  /**
-   * Computes the best split for a NUMERICAL attribute. Every distinct value is tried as a
-   * threshold, so the work grows with the number of distinct values (the slowness on
-   * numeric features that MAHOUT-1419 addresses).
-   *
-   * @throws IllegalStateException if no threshold was evaluated (no distinct values)
-   */
-  Split numericalSplit(Data data, int attr) {
-    double[] values = sortedValues(data, attr);
-
-    initCounts(data, values);
-
-    computeFrequencies(data, attr, values);
-
-    int size = data.size();
-    double hy = entropy(countAll, size);
-    double invDataSize = 1.0 / size;
-
-    int best = -1;
-    double bestIg = -1.0;
-
-    // try each possible split value
-    for (int index = 0; index < values.length; index++) {
-      double ig = hy;
-
-      // The gain is evaluated BEFORE counts[index] is moved below, so countLess holds
-      // exactly the instances strictly below values[index] and countAll the rest.
-      // instance with attribute value < values[index]
-      size = DataUtils.sum(countLess);
-      ig -= size * invDataSize * entropy(countLess, size);
-
-      // instance with attribute value >= values[index]
-      size = DataUtils.sum(countAll);
-      ig -= size * invDataSize * entropy(countAll, size);
-
-      if (ig > bestIg) {
-        bestIg = ig;
-        best = index;
-      }
-
-      // shift the instances equal to values[index] to the lower side for the next threshold
-      DataUtils.add(countLess, counts[index]);
-      DataUtils.dec(countAll, counts[index]);
-    }
-
-    if (best == -1) {
-      throw new IllegalStateException("no best split found !");
-    }
-    return new Split(attr, bestIg, values[best]);
-  }
-
-  /**
-   * Computes the Entropy (natural-log entropy divided by LOG2; presumably LOG2 = ln 2 from
-   * IgSplit, giving bits — NOTE(review): confirm). Returns 0 for an empty set.
-   *
-   * @param counts   counts[i] = numInstances with label i
-   * @param dataSize numInstances
-   */
-  private static double entropy(int[] counts, int dataSize) {
-    if (dataSize == 0) {
-      return 0.0;
-    }
-
-    double entropy = 0.0;
-    double invDataSize = 1.0 / dataSize;
-
-    for (int count : counts) {
-      if (count == 0) {
-        continue; // otherwise we get a NaN
-      }
-      double p = count * invDataSize;
-      entropy += -p * Math.log(p) / LOG2;
-    }
-
-    return entropy;
-  }
-
-}
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.TreeSet;
+
+/**
+ * <p>Optimized implementation of IgSplit.
+ * This class can be used when the criterion variable is the categorical attribute.</p>
+ *
+ * <p>This code was changed in MAHOUT-1419 to evaluate only a sample of candidate splits
+ * on numeric features, fixing a performance problem. To generate some synthetic data that
+ * exercises the issue, try for example generating 4 features of uniform [0,1) values with
+ * a random boolean 0/1 categorical feature. In Scala:</p>
+ *
+ * {@code
+ *  val r = new scala.util.Random()
+ *  val pw = new java.io.PrintWriter("random.csv")
+ *  (1 to 10000000).foreach(e =>
+ *    pw.println(r.nextDouble() + "," +
+ *               r.nextDouble() + "," +
+ *               r.nextDouble() + "," +
+ *               r.nextDouble() + "," +
+ *               (if (r.nextBoolean()) 1 else 0))
+ *   )
+ *   pw.close()
+ * }
+ */
+public class OptIgSplit extends IgSplit {
+
+  /** Upper bound on the number of candidate thresholds evaluated for a numeric attribute. */
+  private static final int MAX_NUMERIC_SPLITS = 16;
+
+  @Override
+  public Split computeSplit(Data data, int attr) {
+    if (data.getDataset().isNumerical(attr)) {
+      return numericalSplit(data, attr);
+    } else {
+      return categoricalSplit(data, attr);
+    }
+  }
+
+  /**
+   * Computes the split for a CATEGORICAL attribute: the information gain
+   * H(Y) - H(Y|X), with one branch per distinct attribute value.
+   */
+  private static Split categoricalSplit(Data data, int attr) {
+    double[] values = data.values(attr).clone();
+
+    // the attribute's distinct values in ascending order; each instance is tallied
+    // into the bucket of the first split point its value does not exceed
+    double[] splitPoints = chooseCategoricalSplitPoints(values);
+
+    int numLabels = data.getDataset().nblabels();
+    int[][] counts = new int[splitPoints.length][numLabels];
+    int[] countAll = new int[numLabels];
+
+    computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+    int size = data.size();
+    double hy = entropy(countAll, size); // H(Y)
+    double hyx = 0.0; // H(Y|X)
+    double invDataSize = 1.0 / size;
+
+    for (int index = 0; index < splitPoints.length; index++) {
+      size = DataUtils.sum(counts[index]);
+      hyx += size * invDataSize * entropy(counts[index], size);
+    }
+
+    double ig = hy - hyx;
+    return new Split(attr, ig);
+  }
+
+  /**
+   * Tallies label frequencies of {@code data} into buckets delimited by {@code splitPoints}.
+   * An instance is counted in {@code counts[i]} for the first {@code i} such that its
+   * attribute value is {@code <= splitPoints[i]}; instances above every split point are
+   * not counted in {@code counts}. Every instance is counted in {@code countAll}.
+   *
+   * @param data        instances to tally
+   * @param attr        attribute whose value selects the bucket
+   * @param splitPoints bucket upper bounds, which must be sorted in ascending order
+   * @param counts      receives counts[bucket][label]; needs splitPoints.length rows
+   * @param countAll    receives the per-label totals over all instances
+   */
+  static void computeFrequencies(Data data,
+                                 int attr,
+                                 double[] splitPoints,
+                                 int[][] counts,
+                                 int[] countAll) {
+    tallyFrequencies(data, attr, splitPoints, counts, countAll, true);
+  }
+
+  /**
+   * Shared implementation of the frequency tally. With {@code inclusiveUpperBound}, an
+   * instance equal to {@code splitPoints[i]} goes to bucket {@code i}; without it, it goes
+   * to a later bucket, so bucket {@code i} then holds only values {@code < splitPoints[i]}.
+   */
+  private static void tallyFrequencies(Data data,
+                                       int attr,
+                                       double[] splitPoints,
+                                       int[][] counts,
+                                       int[] countAll,
+                                       boolean inclusiveUpperBound) {
+    Dataset dataset = data.getDataset();
+
+    for (int index = 0; index < data.size(); index++) {
+      Instance instance = data.get(index);
+      int label = (int) dataset.getLabel(instance);
+      int split = bucketOf(instance.get(attr), splitPoints, inclusiveUpperBound);
+      if (split < splitPoints.length) {
+        counts[split][label]++;
+      } // Otherwise it's in the last split, which we don't need to count
+      countAll[label]++;
+    }
+  }
+
+  /**
+   * Binary search over ascending {@code splitPoints} for the first index {@code i} with
+   * {@code value <= splitPoints[i]} (when {@code inclusiveUpperBound}) or
+   * {@code value < splitPoints[i]} (otherwise). Returns {@code splitPoints.length} when no
+   * such index exists. Gives the same answer as scanning from index 0, in O(log k).
+   */
+  private static int bucketOf(double value, double[] splitPoints, boolean inclusiveUpperBound) {
+    int lo = 0;
+    int hi = splitPoints.length;
+    while (lo < hi) {
+      int mid = (lo + hi) >>> 1;
+      boolean pastMid = inclusiveUpperBound ? value > splitPoints[mid] : value >= splitPoints[mid];
+      if (pastMid) {
+        lo = mid + 1;
+      } else {
+        hi = mid;
+      }
+    }
+    return lo;
+  }
+
+  /**
+   * Computes the best split for a NUMERICAL attribute: the candidate threshold with the
+   * highest information gain when instances are partitioned into
+   * {@code value < threshold} and {@code value >= threshold}.
+   *
+   * @throws IllegalStateException if there is no candidate threshold
+   */
+  static Split numericalSplit(Data data, int attr) {
+    double[] values = data.values(attr).clone();
+    Arrays.sort(values);
+
+    double[] splitPoints = chooseNumericSplitPoints(values);
+
+    int numLabels = data.getDataset().nblabels();
+    int[][] counts = new int[splitPoints.length][numLabels];
+    int[] countAll = new int[numLabels];
+    int[] countLess = new int[numLabels];
+
+    // A numerical split sends value < threshold to one branch and value >= threshold to
+    // the other, so an instance equal to a split point must be tallied on the upper side.
+    // Otherwise the gain computed here describes a different partition than the one the
+    // tree applies (which happens whenever a percentile lands exactly on a data value).
+    tallyFrequencies(data, attr, splitPoints, counts, countAll, false);
+
+    int size = data.size();
+    double hy = entropy(countAll, size);
+    double invDataSize = 1.0 / size;
+
+    int best = -1;
+    double bestIg = -1.0;
+
+    // try each candidate threshold in ascending order
+    for (int index = 0; index < splitPoints.length; index++) {
+      double ig = hy;
+
+      // move the instances in [splitPoints[index - 1], splitPoints[index]) to the lower side
+      DataUtils.add(countLess, counts[index]);
+      DataUtils.dec(countAll, counts[index]);
+
+      // instances with attribute value < splitPoints[index]
+      size = DataUtils.sum(countLess);
+      ig -= size * invDataSize * entropy(countLess, size);
+      // instances with attribute value >= splitPoints[index]
+      size = DataUtils.sum(countAll);
+      ig -= size * invDataSize * entropy(countAll, size);
+
+      if (ig > bestIg) {
+        bestIg = ig;
+        best = index;
+      }
+    }
+
+    if (best == -1) {
+      throw new IllegalStateException("no best split found !");
+    }
+    return new Split(attr, bestIg, splitPoints[best]);
+  }
+
+  /**
+   * Chooses the candidate thresholds for a numeric attribute.
+   *
+   * @param values the attribute's values, sorted in ascending order
+   * @return an array of values to split the numeric feature's values on when building
+   *  candidate splits. When there are at most MAX_NUMERIC_SPLITS + 1 values, it returns the
+   *  midpoints between successive values (or {@code values} itself when it has fewer than
+   *  two elements). When larger, it returns MAX_NUMERIC_SPLITS approximate percentiles
+   *  through the values, in ascending order.
+   */
+  private static double[] chooseNumericSplitPoints(double[] values) {
+    if (values.length <= 1) {
+      return values;
+    }
+    if (values.length <= MAX_NUMERIC_SPLITS + 1) {
+      double[] splitPoints = new double[values.length - 1];
+      for (int i = 1; i < values.length; i++) {
+        splitPoints[i - 1] = (values[i] + values[i - 1]) / 2.0;
+      }
+      return splitPoints;
+    }
+    // NOTE(review): the caller passes data.values(attr), i.e. distinct values, so these
+    // percentiles are not weighted by how many instances share each value.
+    Percentile distribution = new Percentile();
+    distribution.setData(values);
+    double[] percentiles = new double[MAX_NUMERIC_SPLITS];
+    for (int i = 0; i < percentiles.length; i++) {
+      double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
+      percentiles[i] = distribution.evaluate(p);
+    }
+    return percentiles;
+  }
+
+  /**
+   * @return the distinct values of {@code values}, in ascending ({@link Double#compareTo})
+   *  order, used as one bucket per category.
+   */
+  private static double[] chooseCategoricalSplitPoints(double[] values) {
+    // There is no great reason to believe that categorical value order matters,
+    // but the original code worked this way, and it's not terrible in the absence
+    // of more sophisticated analysis
+    Collection<Double> uniqueOrderedCategories = new TreeSet<Double>();
+    for (double v : values) {
+      uniqueOrderedCategories.add(v);
+    }
+    double[] uniqueValues = new double[uniqueOrderedCategories.size()];
+    Iterator<Double> it = uniqueOrderedCategories.iterator();
+    for (int i = 0; i < uniqueValues.length; i++) {
+      uniqueValues[i] = it.next();
+    }
+    return uniqueValues;
+  }
+
+  /**
+   * Computes the Entropy: the natural-log entropy of the label distribution divided by
+   * LOG2 (inherited from IgSplit). Returns 0 for an empty set.
+   *
+   * @param counts   counts[i] = numInstances with label i
+   * @param dataSize numInstances
+   */
+  private static double entropy(int[] counts, int dataSize) {
+    if (dataSize == 0) {
+      return 0.0;
+    }
+
+    double entropy = 0.0;
+
+    for (int count : counts) {
+      if (count > 0) { // log(0) would produce NaN
+        double p = count / (double) dataSize;
+        entropy -= p * Math.log(p);
+      }
+    }
+
+    return entropy / LOG2;
+  }
+
+}

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java Tue Feb 25 14:15:06 2014
@@ -86,7 +86,7 @@ public final class VisualizerTest extend
     Node tree = builder.build(rng, data);
     
     assertEquals("\noutlook = rainy\n|   windy = FALSE : yes\n|   windy = TRUE : no\n"
-            + "outlook = sunny\n|   humidity < 85 : yes\n|   humidity >= 85 : no\n"
+            + "outlook = sunny\n|   humidity < 77.5 : yes\n|   humidity >= 77.5 : no\n"
             + "outlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
   }
   
@@ -101,7 +101,7 @@ public final class VisualizerTest extend
         ATTR_NAMES);
     Assert.assertArrayEquals(new String[] {
         "outlook = rainy -> windy = TRUE -> no", "outlook = overcast -> yes",
-        "outlook = sunny -> (humidity = 90) >= 85 -> no"}, prediction);
+        "outlook = sunny -> (humidity = 90) >= 77.5 -> no"}, prediction);
   }
   
   @Test
@@ -142,7 +142,7 @@ public final class VisualizerTest extend
     builder.setComplemented(false);
     Node tree = builder.build(rng, lessData);
 
-    assertEquals("\noutlook = sunny\n|   humidity < 85 : yes\n|   humidity >= 85 : no\noutlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
+    assertEquals("\noutlook = sunny\n|   humidity < 77.5 : yes\n|   humidity >= 77.5 : no\noutlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
   }
   
   @Test

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java?rev=1571704&r1=1571703&r2=1571704&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java Tue Feb 25 14:15:06 2014
@@ -226,6 +226,9 @@ public class BuildForest extends Configu
     long time = System.currentTimeMillis();
     
     DecisionForest forest = forestBuilder.build(nbTrees);
+    if (forest == null) {
+      return;
+    }
     
     time = System.currentTimeMillis() - time;
     log.info("Build Time: {}", DFUtils.elapsedTime(time));