You are viewing a plain-text version of this content. The canonical link to it did not survive the conversion to plain text.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/01/14 11:56:26 UTC

svn commit: r899156 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/df/builder/ core/src/test/java/org/apache/mahout/df/builder/ examples/src/main/java/org/apache/mahout/df/

Author: adeneche
Date: Thu Jan 14 10:56:25 2010
New Revision: 899156

URL: http://svn.apache.org/viewvc?rev=899156&view=rev
Log:
MAHOUT-245 related modifications

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java?rev=899156&r1=899155&r2=899156&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java Thu Jan 14 10:56:25 2010
@@ -23,6 +23,7 @@
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.mahout.df.data.Data;
 import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Instance;
 import org.apache.mahout.df.data.conditions.Condition;
 import org.apache.mahout.df.node.CategoricalNode;
 import org.apache.mahout.df.node.Leaf;
@@ -31,6 +32,8 @@
 import org.apache.mahout.df.split.IgSplit;
 import org.apache.mahout.df.split.OptIgSplit;
 import org.apache.mahout.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * Builds a Decision Tree <br>
@@ -41,6 +44,11 @@
  */
 public class DefaultTreeBuilder implements TreeBuilder {
 
+  private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class);
+
+  /** indicates which CATEGORICAL attributes have already been selected in the parent nodes */
+  private boolean[] selected;
+
   /** number of attributes to select randomly at each node */
   private int m = 1;
 
@@ -62,14 +70,18 @@
   @Override
   public Node build(Random rng, Data data) {
 
+    if (selected == null) {
+      selected = new boolean[data.getDataset().nbAttributes()];
+    }
+
     if (data.isEmpty())
       return new Leaf(-1);
-    if (data.isIdentical())
+    if (isIdentical(data))
       return new Leaf(data.majorityLabel(rng));
     if (data.identicalLabel())
       return new Leaf(data.get(0).label);
 
-    int[] attributes = randomAttributes(data.getDataset(), rng, m);
+    int[] attributes = randomAttributes(rng, selected, m);
 
     // find the best split
     Split best = null;
@@ -79,6 +91,14 @@
         best = split;
     }
 
+    boolean alreadySelected = selected[best.attr];
+    Node childNode = null;
+
+    if (alreadySelected) {
+      // attribute already selected
+      log.warn("attribute " + best.attr + " already selected in a parent node");
+    }
+    
     if (data.getDataset().isNumerical(best.attr)) {
       Data loSubset = data.subset(Condition.lesser(best.attr, best.split));
       Node loChild = build(rng, loSubset);
@@ -87,8 +107,10 @@
           best.split));
       Node hiChild = build(rng, hiSubset);
 
-      return new NumericalNode(best.attr, best.split, loChild, hiChild);
+      childNode = new NumericalNode(best.attr, best.split, loChild, hiChild);
     } else { // CATEGORICAL attribute
+      selected[best.attr] = true;
+      
       double[] values = data.values(best.attr);
       Node[] childs = new Node[values.length];
 
@@ -97,35 +119,82 @@
         childs[index] = build(rng, subset);
       }
 
-      return new CategoricalNode(best.attr, values, childs);
+      childNode = new CategoricalNode(best.attr, values, childs);
+
+      if (!alreadySelected) {
+        selected[best.attr] = false;
+      }
     }
+
+    return childNode;
+  }
+
+  /**
+   * Checks whether all the vectors have identical attribute values, ignoring the already-selected attributes.
+   *
+   * @return true if all the vectors are identical or the data is empty<br>
+   *         false otherwise
+   */
+  private boolean isIdentical(Data data) {
+    if (data.isEmpty()) return true;
+
+    Instance instance = data.get(0);
+    for (int attr = 0; attr < selected.length; attr++) {
+    if (selected[attr]) continue;
+
+    for (int index = 1; index < data.size(); index++) {
+      if (data.get(index).get(attr) != instance.get(attr))
+        return false;
+      }
+    }
+
+    return true;
   }
 
   /**
    * Randomly selects m attributes to consider for split, excludes IGNORED and
    * LABEL attributes
    * 
-   * @param dataset
-   * @param rng
-   * @param m number of attributes to select
+   * @param rng random-numbers generator
+   * @param selected attributes' state (selected or not)
+   * @param m number of attributes to choose
    * @return
    */
-  protected static int[] randomAttributes(Dataset dataset, Random rng, int m) {
-    if (m > dataset.nbAttributes()) {
-      throw new IllegalArgumentException("m > num attributes");
+  protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+    int nbNonSelected = 0; // number of non selected attributes
+    for (boolean sel : selected) {
+      if (!sel) nbNonSelected++;
     }
 
-    int[] result = new int[m];
+    if (nbNonSelected == 0) {
+      log.warn("All attributes are selected !");
+    }
 
-    Arrays.fill(result, -1);
+    int[] result;
+    if (nbNonSelected <= m) {
+      // return all non selected attributes
+      result = new int[nbNonSelected];
+      int index = 0;
+      for (int attr = 0; attr < selected.length; attr++) {
+        if (!selected[attr]) result[index++] = attr;
+      }
+    } else {
+      result = new int[m];
+      for (int index = 0; index < m; index++) {
+        // randomly choose a "non selected" attribute
+        int rind;
+        do {
+          rind = rng.nextInt(selected.length);
+        } while (selected[rind]);
 
-    for (int index = 0; index < m; index++) {
-      int rvalue;
-      do {
-        rvalue = rng.nextInt(dataset.nbAttributes());
-      } while (ArrayUtils.contains(result, rvalue));
+        result[index] = rind;
+        selected[rind] = true; // temporarely set the choosen attribute to be selected
+      }
 
-      result[index] = rvalue;
+      // undo the temporary marking: the chosen attributes are not actually selected yet
+      for (int attr : result) {
+        selected[attr] = false;
+      }
     }
 
     return result;

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java?rev=899156&r1=899155&r2=899156&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java Thu Jan 14 10:56:25 2010
@@ -18,6 +18,7 @@
 package org.apache.mahout.df.builder;
 
 import java.util.Random;
+import java.util.Arrays;
 
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.mahout.common.RandomUtils;
@@ -28,34 +29,52 @@
 
 public class DefaultTreeBuilderTest extends TestCase {
 
+  @Override
+  protected void setUp() throws Exception {
+    RandomUtils.useTestSeed();
+  }
+
+  /**
+   * make sure that DefaultTreeBuilder.randomAttributes() returns the correct number of attributes, that have not been
+   * selected yet
+   *
+   * @throws Exception
+   */
   public void testRandomAttributes() throws Exception {
     Random rng = RandomUtils.getRandom();
-    int maxNbAttributes = 100;
-    int n = 100;
+    int nbAttributes = rng.nextInt(100) + 1;
+    boolean[] selected = new boolean[nbAttributes];
 
-    for (int nloop = 0; nloop < n; nloop++) {
-      int nbAttributes = rng.nextInt(maxNbAttributes) + 1;
+    for (int nloop = 0; nloop < 100; nloop++) {
+      Arrays.fill(selected, false);
 
-      // generate a small data, only to get the dataset
-      Data data = Utils.randomData(rng, nbAttributes, 1);
-      if (data.getDataset().nbAttributes() == 0)
-        continue;
+      // randomly select some attributes
+      int nbSelected = rng.nextInt(nbAttributes - 1);
+      for (int index = 0; index < nbSelected; index++) {
+        int attr;
+        do {
+          attr = rng.nextInt(nbAttributes);
+        } while (selected[attr]);
+
+        selected[attr] = true;
+      }
 
-      int m = rng.nextInt(data.getDataset().nbAttributes()) + 1;
+      int m = rng.nextInt(nbAttributes);
 
-      int[] attrs = DefaultTreeBuilder.randomAttributes(data.getDataset(), rng, m);
+      int[] attrs = DefaultTreeBuilder.randomAttributes(rng, selected, m);
 
-      assertEquals(m, attrs.length);
+      assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length);
 
-      for (int index = 0; index < m; index++) {
-        int attr = attrs[index];
+      for (int attr : attrs) {
+        // the attribute should not be already selected
+        assertFalse("an attribute has already been selected", selected[attr]);
 
         // each attribute should be in the range [0, nbAttributes[
         assertTrue(attr >= 0);
         assertTrue(attr < nbAttributes);
 
         // each attribute should appear only once
-        assertEquals(index, ArrayUtils.lastIndexOf(attrs, attr));
+        assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr));
       }
     }
   }

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=899156&r1=899155&r2=899156&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Thu Jan 14 10:56:25 2010
@@ -71,20 +71,26 @@
   /** mean time to build a forest with m=1 */
   private static long sumTimeOne;
 
+  /** number of nodes of the forests grown with m=log2(M)+1, summed over all iterations (averaged when logged) */
+  protected static long numNodesM;
+
+  /** number of nodes of the forests grown with m=1, summed over all iterations (averaged when logged) */
+  protected static long numNodesOne;
+
   /**
    * runs one iteration of the procedure.
    *
+   * @param rng random-number generator
    * @param data training data
    * @param m number of random variables to select at each tree-node
    * @param nbtrees number of trees to grow
    * @throws Exception if an error occured while growing the trees
    */
-  protected static void runIteration(Data data, int m, int nbtrees) {
+  protected static void runIteration(Random rng, Data data, int m, int nbtrees) {
 
     int nblabels = data.getDataset().nblabels();
 
-    Random rng = RandomUtils.getRandom();
-
+    log.info("Splitting the data");
     Data train = data.clone();
     Data test = train.rsplit(rng, (int) (data.size() * 0.1));
     
@@ -103,6 +109,7 @@
     log.info("Growing a forest with m=" + m);
     DecisionForest forestM = forestBuilder.build(nbtrees, errorM);
     sumTimeM += System.currentTimeMillis() - time;
+    numNodesM += forestM.nbNodes();
 
     double oobM = ErrorEstimate.errorRate(trainLabels, errorM.computePredictions(rng)); // oob error estimate when m = log2(M)+1
 
@@ -114,7 +121,8 @@
     log.info("Growing a forest with m=1");
     DecisionForest forestOne = forestBuilder.build(nbtrees, errorOne);
     sumTimeOne += System.currentTimeMillis() - time;
-
+    numNodesOne += forestOne.nbNodes();
+    
     double oobOne = ErrorEstimate.errorRate(trainLabels, errorOne.computePredictions(rng)); // oob error estimate when m = 1
 
     // compute the test set error (Selection Error), and mean tree error (One Tree Error),
@@ -208,9 +216,10 @@
     // number of inputs
     int m = (int) Math.floor(Maths.log(2, data.getDataset().nbAttributes()) + 1);
 
+    Random rng = RandomUtils.getRandom();
     for (int iteration = 0; iteration < nbIterations; iteration++) {
       log.info("Iteration " + iteration);
-      runIteration(data, m, nbTrees);
+      runIteration(rng, data, m, nbTrees);
     }
 
     log.info("********************************************");
@@ -220,6 +229,9 @@
     log.info("");
     log.info("Mean Random Input Time : " + DFUtils.elapsedTime(sumTimeM / nbIterations));
     log.info("Mean Single Input Time : " + DFUtils.elapsedTime(sumTimeOne / nbIterations));
+    log.info("");
+    log.info("Mean Random Input Num Nodes : " + numNodesM / nbIterations);
+    log.info("Mean Single Input Num Nodes : " + numNodesOne / nbIterations);
 
     return 0;
   }