You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/01/14 11:56:26 UTC
svn commit: r899156 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/df/builder/
core/src/test/java/org/apache/mahout/df/builder/
examples/src/main/java/org/apache/mahout/df/
Author: adeneche
Date: Thu Jan 14 10:56:25 2010
New Revision: 899156
URL: http://svn.apache.org/viewvc?rev=899156&view=rev
Log:
MAHOUT-245 related modifications
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java?rev=899156&r1=899155&r2=899156&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java Thu Jan 14 10:56:25 2010
@@ -23,6 +23,7 @@
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.data.conditions.Condition;
import org.apache.mahout.df.node.CategoricalNode;
import org.apache.mahout.df.node.Leaf;
@@ -31,6 +32,8 @@
import org.apache.mahout.df.split.IgSplit;
import org.apache.mahout.df.split.OptIgSplit;
import org.apache.mahout.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* Builds a Decision Tree <br>
@@ -41,6 +44,11 @@
*/
public class DefaultTreeBuilder implements TreeBuilder {
+ private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class);
+
+ /** indicates which CATEGORICAL attributes have already been selected in the parent nodes */
+ private boolean[] selected;
+
/** number of attributes to select randomly at each node */
private int m = 1;
@@ -62,14 +70,18 @@
@Override
public Node build(Random rng, Data data) {
+ if (selected == null) {
+ selected = new boolean[data.getDataset().nbAttributes()];
+ }
+
if (data.isEmpty())
return new Leaf(-1);
- if (data.isIdentical())
+ if (isIdentical(data))
return new Leaf(data.majorityLabel(rng));
if (data.identicalLabel())
return new Leaf(data.get(0).label);
- int[] attributes = randomAttributes(data.getDataset(), rng, m);
+ int[] attributes = randomAttributes(rng, selected, m);
// find the best split
Split best = null;
@@ -79,6 +91,14 @@
best = split;
}
+ boolean alreadySelected = selected[best.attr];
+ Node childNode = null;
+
+ if (alreadySelected) {
+ // attribute already selected
+ log.warn("attribute " + best.attr + " already selected in a parent node");
+ }
+
if (data.getDataset().isNumerical(best.attr)) {
Data loSubset = data.subset(Condition.lesser(best.attr, best.split));
Node loChild = build(rng, loSubset);
@@ -87,8 +107,10 @@
best.split));
Node hiChild = build(rng, hiSubset);
- return new NumericalNode(best.attr, best.split, loChild, hiChild);
+ childNode = new NumericalNode(best.attr, best.split, loChild, hiChild);
} else { // CATEGORICAL attribute
+ selected[best.attr] = true;
+
double[] values = data.values(best.attr);
Node[] childs = new Node[values.length];
@@ -97,35 +119,82 @@
childs[index] = build(rng, subset);
}
- return new CategoricalNode(best.attr, values, childs);
+ childNode = new CategoricalNode(best.attr, values, childs);
+
+ if (!alreadySelected) {
+ selected[best.attr] = false;
+ }
}
+
+ return childNode;
+ }
+
+ /**
+ * Checks whether all the vectors have identical attribute values, ignoring the selected attributes.
+ *
+ * @return true if all the vectors are identical or the data is empty<br>
+ * false otherwise
+ */
+ private boolean isIdentical(Data data) {
+ if (data.isEmpty()) return true;
+
+ Instance instance = data.get(0);
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (selected[attr]) continue;
+
+ for (int index = 1; index < data.size(); index++) {
+ if (data.get(index).get(attr) != instance.get(attr))
+ return false;
+ }
+ }
+
+ return true;
}
/**
* Randomly selects m attributes to consider for split, excludes IGNORED and
* LABEL attributes
*
- * @param dataset
- * @param rng
- * @param m number of attributes to select
+ * @param rng random-numbers generator
+ * @param selected attributes' state (selected or not)
+ * @param m number of attributes to choose
* @return
*/
- protected static int[] randomAttributes(Dataset dataset, Random rng, int m) {
- if (m > dataset.nbAttributes()) {
- throw new IllegalArgumentException("m > num attributes");
+ protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+ int nbNonSelected = 0; // number of non selected attributes
+ for (boolean sel : selected) {
+ if (!sel) nbNonSelected++;
}
- int[] result = new int[m];
+ if (nbNonSelected == 0) {
+ log.warn("All attributes are selected !");
+ }
- Arrays.fill(result, -1);
+ int[] result;
+ if (nbNonSelected <= m) {
+ // return all non selected attributes
+ result = new int[nbNonSelected];
+ int index = 0;
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (!selected[attr]) result[index++] = attr;
+ }
+ } else {
+ result = new int[m];
+ for (int index = 0; index < m; index++) {
+ // randomly choose a "non selected" attribute
+ int rind;
+ do {
+ rind = rng.nextInt(selected.length);
+ } while (selected[rind]);
- for (int index = 0; index < m; index++) {
- int rvalue;
- do {
- rvalue = rng.nextInt(dataset.nbAttributes());
- } while (ArrayUtils.contains(result, rvalue));
+ result[index] = rind;
+ selected[rind] = true; // temporarily mark the chosen attribute as selected
+ }
- result[index] = rvalue;
+ // the chosen attributes are not yet selected
+ for (int attr : result) {
+ selected[attr] = false;
+ }
}
return result;
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java?rev=899156&r1=899155&r2=899156&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/DefaultTreeBuilderTest.java Thu Jan 14 10:56:25 2010
@@ -18,6 +18,7 @@
package org.apache.mahout.df.builder;
import java.util.Random;
+import java.util.Arrays;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.common.RandomUtils;
@@ -28,34 +29,52 @@
public class DefaultTreeBuilderTest extends TestCase {
+ @Override
+ protected void setUp() throws Exception {
+ RandomUtils.useTestSeed();
+ }
+
+ /**
+ * make sure that DefaultTreeBuilder.randomAttributes() returns the correct number of attributes, that have not been
+ * selected yet
+ *
+ * @throws Exception
+ */
public void testRandomAttributes() throws Exception {
Random rng = RandomUtils.getRandom();
- int maxNbAttributes = 100;
- int n = 100;
+ int nbAttributes = rng.nextInt(100) + 1;
+ boolean[] selected = new boolean[nbAttributes];
- for (int nloop = 0; nloop < n; nloop++) {
- int nbAttributes = rng.nextInt(maxNbAttributes) + 1;
+ for (int nloop = 0; nloop < 100; nloop++) {
+ Arrays.fill(selected, false);
- // generate a small data, only to get the dataset
- Data data = Utils.randomData(rng, nbAttributes, 1);
- if (data.getDataset().nbAttributes() == 0)
- continue;
+ // randomly select some attributes
+ int nbSelected = rng.nextInt(nbAttributes - 1);
+ for (int index = 0; index < nbSelected; index++) {
+ int attr;
+ do {
+ attr = rng.nextInt(nbAttributes);
+ } while (selected[attr]);
+
+ selected[attr] = true;
+ }
- int m = rng.nextInt(data.getDataset().nbAttributes()) + 1;
+ int m = rng.nextInt(nbAttributes);
- int[] attrs = DefaultTreeBuilder.randomAttributes(data.getDataset(), rng, m);
+ int[] attrs = DefaultTreeBuilder.randomAttributes(rng, selected, m);
- assertEquals(m, attrs.length);
+ assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length);
- for (int index = 0; index < m; index++) {
- int attr = attrs[index];
+ for (int attr : attrs) {
+ // the attribute should not be already selected
+ assertFalse("an attribute has already been selected", selected[attr]);
// each attribute should be in the range [0, nbAttributes[
assertTrue(attr >= 0);
assertTrue(attr < nbAttributes);
// each attribute should appear only once
- assertEquals(index, ArrayUtils.lastIndexOf(attrs, attr));
+ assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr));
}
}
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=899156&r1=899155&r2=899156&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Thu Jan 14 10:56:25 2010
@@ -71,20 +71,26 @@
/** total time spent building forests with m=1 (divided by the number of iterations when reported) */
private static long sumTimeOne;
+ /** total number of nodes of the forests grown with m=log2(M)+1 (divided by the number of iterations when reported) */
+ protected static long numNodesM;
+
+ /** total number of nodes of the forests grown with m=1 (divided by the number of iterations when reported) */
+ protected static long numNodesOne;
+
/**
* runs one iteration of the procedure.
*
+ * @param rng random numbers generator
* @param data training data
* @param m number of random variables to select at each tree-node
* @param nbtrees number of trees to grow
* @throws Exception if an error occurred while growing the trees
*/
- protected static void runIteration(Data data, int m, int nbtrees) {
+ protected static void runIteration(Random rng, Data data, int m, int nbtrees) {
int nblabels = data.getDataset().nblabels();
- Random rng = RandomUtils.getRandom();
-
+ log.info("Splitting the data");
Data train = data.clone();
Data test = train.rsplit(rng, (int) (data.size() * 0.1));
@@ -103,6 +109,7 @@
log.info("Growing a forest with m=" + m);
DecisionForest forestM = forestBuilder.build(nbtrees, errorM);
sumTimeM += System.currentTimeMillis() - time;
+ numNodesM += forestM.nbNodes();
double oobM = ErrorEstimate.errorRate(trainLabels, errorM.computePredictions(rng)); // oob error estimate when m = log2(M)+1
@@ -114,7 +121,8 @@
log.info("Growing a forest with m=1");
DecisionForest forestOne = forestBuilder.build(nbtrees, errorOne);
sumTimeOne += System.currentTimeMillis() - time;
-
+ numNodesOne += forestOne.nbNodes();
+
double oobOne = ErrorEstimate.errorRate(trainLabels, errorOne.computePredictions(rng)); // oob error estimate when m = 1
// compute the test set error (Selection Error), and mean tree error (One Tree Error),
@@ -208,9 +216,10 @@
// number of inputs
int m = (int) Math.floor(Maths.log(2, data.getDataset().nbAttributes()) + 1);
+ Random rng = RandomUtils.getRandom();
for (int iteration = 0; iteration < nbIterations; iteration++) {
log.info("Iteration " + iteration);
- runIteration(data, m, nbTrees);
+ runIteration(rng, data, m, nbTrees);
}
log.info("********************************************");
@@ -220,6 +229,9 @@
log.info("");
log.info("Mean Random Input Time : " + DFUtils.elapsedTime(sumTimeM / nbIterations));
log.info("Mean Single Input Time : " + DFUtils.elapsedTime(sumTimeOne / nbIterations));
+ log.info("");
+ log.info("Mean Random Input Num Nodes : " + numNodesM / nbIterations);
+ log.info("Mean Single Input Num Nodes : " + numNodesOne / nbIterations);
return 0;
}