You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/11/01 17:43:36 UTC
svn commit: r1029738 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/df/builder/
core/src/test/java/org/apache/mahout/df/builder/
examples/src/main/java/org/apache/mahout/df/mapreduce/
Author: adeneche
Date: Mon Nov 1 16:43:36 2010
New Revision: 1029738
URL: http://svn.apache.org/viewvc?rev=1029738&view=rev
Log:
MAHOUT-526 Fixed the Infinite Recursion in Decision Forests
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java?rev=1029738&r1=1029737&r2=1029738&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/builder/DefaultTreeBuilder.java Mon Nov 1 16:43:36 2010
@@ -20,6 +20,7 @@ package org.apache.mahout.df.builder;
import java.util.Random;
import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.data.conditions.Condition;
import org.apache.mahout.df.node.CategoricalNode;
@@ -81,7 +82,10 @@ public class DefaultTreeBuilder implemen
}
int[] attributes = randomAttributes(rng, selected, m);
-
+ if (attributes == null) { // we tried all the attributes and could not split the data anymore
+ return new Leaf(data.majorityLabel(rng));
+ }
+
// find the best split
Split best = null;
for (int attr : attributes) {
@@ -92,7 +96,6 @@ public class DefaultTreeBuilder implemen
}
boolean alreadySelected = selected[best.getAttr()];
-
if (alreadySelected) {
// attribute already selected
log.warn("attribute {} already selected in a parent node", best.getAttr());
@@ -100,12 +103,30 @@ public class DefaultTreeBuilder implemen
Node childNode;
if (data.getDataset().isNumerical(best.getAttr())) {
+ boolean[] temp = null;
+
Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
- Node loChild = build(rng, loSubset);
-
Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
+
+ if (loSubset.isEmpty() || hiSubset.isEmpty()) {
+ // the selected attribute did not change the data, avoid using it in the child nodes
+ selected[best.getAttr()] = true;
+ } else {
+ // the data changed, so we can unselect all previously selected NUMERICAL attributes
+ temp = selected;
+ selected = cloneCategoricalAttributes(data.getDataset(), selected);
+ }
+
+ Node loChild = build(rng, loSubset);
Node hiChild = build(rng, hiSubset);
-
+
+ // restore the selection state of the attributes
+ if (temp != null) {
+ selected = temp;
+ } else {
+ selected[best.getAttr()] = alreadySelected;
+ }
+
childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
} else { // CATEGORICAL attribute
selected[best.getAttr()] = true;
@@ -117,12 +138,10 @@ public class DefaultTreeBuilder implemen
Data subset = data.subset(Condition.equals(best.getAttr(), values[index]));
children[index] = build(rng, subset);
}
+
+ selected[best.getAttr()] = alreadySelected;
childNode = new CategoricalNode(best.getAttr(), values, children);
-
- if (!alreadySelected) {
- selected[best.getAttr()] = false;
- }
}
return childNode;
@@ -154,7 +173,25 @@ public class DefaultTreeBuilder implemen
return true;
}
-
+
+
+  /**
+   * Copies the attribute selection state, clearing the flag of every NUMERICAL attribute;
+   * every non-numerical attribute keeps its current selection state.
+   * @param dataset dataset description, used to tell which attributes are numerical
+   * @param selected selection state to clone; this array is not modified
+   * @return a new array holding the cloned selection state
+   */
+  protected static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+    boolean[] copy = new boolean[selected.length];
+    for (int attr = 0; attr < copy.length; attr++) {
+      // numerical attributes start unselected, the others are copied as is
+      copy[attr] = !dataset.isNumerical(attr) && selected[attr];
+    }
+
+    return copy;
+  }
+
/**
* Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
*
@@ -164,6 +201,7 @@ public class DefaultTreeBuilder implemen
* attributes' state (selected or not)
* @param m
* number of attributes to choose
+ * @return list of selected attributes' indices, or null if all attributes have already been selected
*/
protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
int nbNonSelected = 0; // number of non selected attributes
@@ -175,6 +213,7 @@ public class DefaultTreeBuilder implemen
if (nbNonSelected == 0) {
log.warn("All attributes are selected !");
+ return null;
}
int[] result;
Added: mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java?rev=1029738&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/builder/InfiniteRecursionTest.java Mon Nov 1 16:43:36 2010
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.builder;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.df.data.Data;
+import org.apache.mahout.df.data.DataLoader;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Utils;
+import org.junit.Test;
+
+import java.util.Random;
+
+public final class InfiniteRecursionTest extends MahoutTestCase {
+
+  private static final double[][] DATA = {
+    {0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 4},
+    {0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 3},
+    {0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 4},
+    {0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 3}
+  };
+
+  /**
+   * Make sure DefaultTreeBuilder.build() does not throw a StackOverflowError: rows 1/4 and 2/3
+   * share the same attribute values but have different labels, so no split can separate them.
+   */
+  @Test
+  public void testBuild() throws Exception {
+    Random rng = RandomUtils.getRandom();
+    DefaultTreeBuilder builder = new DefaultTreeBuilder();
+
+    String[] source = Utils.double2String(DATA);
+    String descriptor = "N N N N N N N N L";
+    Dataset dataset = DataLoader.generateDataset(descriptor, source);
+    Data data = DataLoader.loadData(dataset, source);
+
+    builder.build(rng, data);
+  }
+}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java?rev=1029738&r1=1029737&r2=1029738&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java Mon Nov 1 16:43:36 2010
@@ -196,7 +196,10 @@ public class BuildForest extends Configu
time = System.currentTimeMillis() - time;
log.info("Build Time: {}", DFUtils.elapsedTime(time));
-
+ log.info("Forest num Nodes: {}", forest.nbNodes());
+ log.info("Forest mean num Nodes: {}", forest.meanNbNodes());
+ log.info("Forest mean max Depth: {}", forest.meanMaxDepth());
+
if (isOob) {
Random rng;
if (seed != null) {