You are viewing a plain text version of this content. The canonical version is available from the original mailing list archive (the hyperlink is not preserved in this plain text copy).
Posted to commits@mahout.apache.org by gs...@apache.org on 2013/06/05 12:39:16 UTC
svn commit: r1489802 - in /mahout/trunk: CHANGELOG
core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
Author: gsingers
Date: Wed Jun 5 10:39:15 2013
New Revision: 1489802
URL: http://svn.apache.org/r1489802
Log:
MAHOUT-961: fix issues with visualizing DF trees
Modified:
mahout/trunk/CHANGELOG
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1489802&r1=1489801&r2=1489802&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Wed Jun 5 10:39:15 2013
@@ -84,3 +84,5 @@ __MAHOUT-1181: Adding StreamingKMeans Ma
MAHOUT-1176: Introduce a changelog file to raise contributors attribution (ssc)
MAHOUT-1108: Allows cluster-reuters.sh example to be executed on a cluster (elmer.garduno via gsingers)
+
+ MAHOUT-961: Fix issue in decision forest tree visualizer to properly show stems of tree (Ikumasa Mukai via gsingers)
\ No newline at end of file
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java?rev=1489802&r1=1489801&r2=1489802&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java Wed Jun 5 10:39:15 2013
@@ -35,47 +35,42 @@ import org.apache.mahout.classifier.df.n
* This tool is to visualize the Decision tree
*/
public final class TreeVisualizer {
-
- private TreeVisualizer() {
- }
-
+
+ private TreeVisualizer() {}
+
private static String doubleToString(double value) {
DecimalFormat df = new DecimalFormat("0.##");
return df.format(value);
}
-
- private static String toStringNode(Node node,
- Dataset dataset,
- String[] attrNames,
- Map<String, Field> fields,
- int layer) {
-
+
+ private static String toStringNode(Node node, Dataset dataset,
+ String[] attrNames, Map<String,Field> fields, int layer) {
+
StringBuilder buff = new StringBuilder();
-
+
try {
-
if (node instanceof CategoricalNode) {
CategoricalNode cnode = (CategoricalNode) node;
int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
- double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
- Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
- String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
- for (int i = 0; i < childs.length; i++) {
+ double[] values = (double[]) fields.get("CategoricalNode.values").get(
+ cnode);
+ Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+ .get(cnode);
+ String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+ dataset);
+ for (int i = 0; i < attrValues[attr].length; i++) {
+ int index = ArrayUtils.indexOf(values, i);
+ if (index < 0) {
+ continue;
+ }
buff.append('\n');
for (int j = 0; j < layer; j++) {
buff.append("| ");
}
- if (values[i] < attrValues[attr].length) {
- if (attrNames == null) {
- buff.append(attr);
- } else {
- buff.append(attrNames[attr]);
- }
- buff.append(" = ");
- buff.append(attrValues[attr][(int) values[i]]);
-
- buff.append(toStringNode(childs[i], dataset, attrNames, fields, layer + 1));
- }
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+ .append(attrValues[attr][i]);
+ buff.append(toStringNode(childs[index], dataset, attrNames, fields,
+ layer + 1));
}
} else if (node instanceof NumericalNode) {
NumericalNode nnode = (NumericalNode) node;
@@ -87,13 +82,15 @@ public final class TreeVisualizer {
for (int j = 0; j < layer; j++) {
buff.append("| ");
}
- buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ").append(doubleToString(split));
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ")
+ .append(doubleToString(split));
buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
buff.append('\n');
for (int j = 0; j < layer; j++) {
buff.append("| ");
}
- buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ").append(doubleToString(split));
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ")
+ .append(doubleToString(split));
buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
} else if (node instanceof Leaf) {
Leaf leaf = (Leaf) node;
@@ -101,20 +98,19 @@ public final class TreeVisualizer {
if (dataset.isNumerical(dataset.getLabelId())) {
buff.append(" : ").append(doubleToString(label));
} else {
- buff.append(" : ").append(dataset.getLabelString((int) label));
+ buff.append(" : ").append(dataset.getLabelString(label));
}
}
-
} catch (IllegalAccessException iae) {
throw new IllegalStateException(iae);
}
-
+
return buff.toString();
}
-
- private static Map<String, Field> getReflectMap() {
- Map<String, Field> fields = new HashMap<String, Field>();
-
+
+ private static Map<String,Field> getReflectMap() {
+ Map<String,Field> fields = new HashMap<String,Field>();
+
try {
Field m = CategoricalNode.class.getDeclaredField("attr");
m.setAccessible(true);
@@ -152,6 +148,7 @@ public final class TreeVisualizer {
/**
* Decision tree to String
+ *
* @param tree
* Node of tree
* @param attrNames
@@ -160,37 +157,41 @@ public final class TreeVisualizer {
public static String toString(Node tree, Dataset dataset, String[] attrNames) {
return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
}
-
+
/**
* Print Decision tree
- * @param tree Node of tree
- * @param attrNames attribute names
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
*/
public static void print(Node tree, Dataset dataset, String[] attrNames) {
System.out.println(toString(tree, dataset, attrNames));
}
-
- private static String toStringPredict(Node node,
- Instance instance,
- Dataset dataset,
- String[] attrNames,
- Map<String, Field> fields) {
+
+ private static String toStringPredict(Node node, Instance instance,
+ Dataset dataset, String[] attrNames, Map<String,Field> fields) {
StringBuilder buff = new StringBuilder();
-
+
try {
if (node instanceof CategoricalNode) {
CategoricalNode cnode = (CategoricalNode) node;
int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
- double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
- Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
- String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
-
+ double[] values = (double[]) fields.get("CategoricalNode.values").get(
+ cnode);
+ Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+ .get(cnode);
+ String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+ dataset);
+
int index = ArrayUtils.indexOf(values, instance.get(attr));
if (index >= 0) {
buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
.append(attrValues[attr][(int) instance.get(attr)]);
buff.append(" -> ");
- buff.append(toStringPredict(childs[index], instance, dataset, attrNames, fields));
+ buff.append(toStringPredict(childs[index], instance, dataset,
+ attrNames, fields));
}
} else if (node instanceof NumericalNode) {
NumericalNode nnode = (NumericalNode) node;
@@ -198,17 +199,21 @@ public final class TreeVisualizer {
double split = (Double) fields.get("NumericalNode.split").get(nnode);
Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
-
+
if (instance.get(attr) < split) {
- buff.append('(').append(attrNames == null ? attr : attrNames[attr]).append(" = ")
- .append(doubleToString(instance.get(attr))).append(") < ").append(doubleToString(split));
+ buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+ .append(" = ").append(doubleToString(instance.get(attr)))
+ .append(") < ").append(doubleToString(split));
buff.append(" -> ");
- buff.append(toStringPredict(loChild, instance, dataset, attrNames, fields));
+ buff.append(toStringPredict(loChild, instance, dataset, attrNames,
+ fields));
} else {
- buff.append('(').append(attrNames == null ? attr : attrNames[attr]).append(" = ")
- .append(doubleToString(instance.get(attr))).append(") >= ").append(doubleToString(split));
+ buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+ .append(" = ").append(doubleToString(instance.get(attr)))
+ .append(") >= ").append(doubleToString(split));
buff.append(" -> ");
- buff.append(toStringPredict(hiChild, instance, dataset, attrNames, fields));
+ buff.append(toStringPredict(hiChild, instance, dataset, attrNames,
+ fields));
}
} else if (node instanceof Leaf) {
Leaf leaf = (Leaf) node;
@@ -216,43 +221,47 @@ public final class TreeVisualizer {
if (dataset.isNumerical(dataset.getLabelId())) {
buff.append(doubleToString(label));
} else {
- buff.append(dataset.getLabelString((int) label));
+ buff.append(dataset.getLabelString(label));
}
}
} catch (IllegalAccessException iae) {
throw new IllegalStateException(iae);
}
-
+
return buff.toString();
}
-
+
/**
* Predict trace to String
+ *
* @param tree
* Node of tree
* @param attrNames
* attribute names
*/
public static String[] predictTrace(Node tree, Data data, String[] attrNames) {
- Map<String, Field> reflectMap = getReflectMap();
+ Map<String,Field> reflectMap = getReflectMap();
String[] prediction = new String[data.size()];
for (int i = 0; i < data.size(); i++) {
- prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap);
+ prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(),
+ attrNames, reflectMap);
}
return prediction;
}
-
+
/**
* Print predict trace
+ *
* @param tree
* Node of tree
* @param attrNames
* attribute names
*/
public static void predictTracePrint(Node tree, Data data, String[] attrNames) {
- Map<String, Field> reflectMap = getReflectMap();
+ Map<String,Field> reflectMap = getReflectMap();
for (int i = 0; i < data.size(); i++) {
- System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap));
+ System.out.println(toStringPredict(tree, data.get(i), data.getDataset(),
+ attrNames, reflectMap));
}
}
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java?rev=1489802&r1=1489801&r2=1489802&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java Wed Jun 5 10:39:15 2013
@@ -20,12 +20,12 @@ package org.apache.mahout.classifier.df.
import java.util.List;
import java.util.Random;
-import com.google.common.collect.Lists;
import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
@@ -36,24 +36,28 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import com.google.common.collect.Lists;
+
public final class VisualizerTest extends MahoutTestCase {
private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
- "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
- "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
- "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
- "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
- "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
- "rainy,71,91,TRUE,no"};
+ "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
+ "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
+ "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
+ "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
+ "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
+ "rainy,71,91,TRUE,no"};
private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
- "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
+ "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
private static final String[] ATTR_NAMES = {"outlook", "temperature",
- "humidity", "windy", "play"};
+ "humidity", "windy", "play"};
private Random rng;
+
private Data data;
+
private Data testData;
@Override
@@ -61,10 +65,11 @@ public final class VisualizerTest extend
public void setUp() throws Exception {
super.setUp();
- rng = RandomUtils.getRandom();
+ rng = RandomUtils.getRandom(1);
// Dataset
- Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
+ Dataset dataset = DataLoader
+ .generateDataset("C N N C L", false, TRAIN_DATA);
// Training data
data = DataLoader.loadData(dataset, TRAIN_DATA);
@@ -80,10 +85,9 @@ public final class VisualizerTest extend
builder.setM(data.getDataset().nbAttributes() - 1);
Node tree = builder.build(rng, data);
- assertEquals(TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES),
- "\noutlook = rainy\n| windy = FALSE : yes\n| windy = TRUE : no\n"
- + "outlook = overcast : yes\n"
- + "outlook = sunny\n| humidity < 85 : yes\n| humidity >= 85 : no");
+ assertEquals("\noutlook = rainy\n| windy = FALSE : yes\n| windy = TRUE : no\n"
+ + "outlook = sunny\n| humidity < 85 : yes\n| humidity >= 85 : no\n"
+ + "outlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
}
@Test
@@ -93,8 +97,9 @@ public final class VisualizerTest extend
builder.setM(data.getDataset().nbAttributes() - 1);
Node tree = builder.build(rng, data);
- String[] prediction = TreeVisualizer.predictTrace(tree, testData, ATTR_NAMES);
- Assert.assertArrayEquals(new String[]{
+ String[] prediction = TreeVisualizer.predictTrace(tree, testData,
+ ATTR_NAMES);
+ Assert.assertArrayEquals(new String[] {
"outlook = rainy -> windy = TRUE -> no", "outlook = overcast -> yes",
"outlook = sunny -> (humidity = 90) >= 85 -> no"}, prediction);
}
@@ -103,22 +108,51 @@ public final class VisualizerTest extend
public void testForestVisualize() throws Exception {
// Tree
NumericalNode root = new NumericalNode(2, 90, new Leaf(0),
- new CategoricalNode(0, new double[] {0, 1, 2}, new Node[] {
- new NumericalNode(1, 71, new Leaf(0), new Leaf(1)), new Leaf(1),
- new Leaf(0)}));
+ new CategoricalNode(0, new double[] {0, 1, 2}, new Node[] {
+ new NumericalNode(1, 71, new Leaf(0), new Leaf(1)), new Leaf(1),
+ new Leaf(0)}));
List<Node> trees = Lists.newArrayList();
trees.add(root);
// Forest
DecisionForest forest = new DecisionForest(trees);
- assertEquals(ForestVisualizer.toString(forest, data.getDataset(), null),
- "Tree[1]:\n2 < 90 : yes\n2 >= 90\n"
- + "| 0 = rainy\n| | 1 < 71 : yes\n| | 1 >= 71 : no\n"
- + "| 0 = sunny : no\n" + "| 0 = overcast : yes\n");
-
- assertEquals(ForestVisualizer.toString(forest, data.getDataset(), ATTR_NAMES),
- "Tree[1]:\nhumidity < 90 : yes\nhumidity >= 90\n"
- + "| outlook = rainy\n| | temperature < 71 : yes\n| | temperature >= 71 : no\n"
- + "| outlook = sunny : no\n" + "| outlook = overcast : yes\n");
+ assertEquals("Tree[1]:\n2 < 90 : yes\n2 >= 90\n"
+ + "| 0 = rainy\n| | 1 < 71 : yes\n| | 1 >= 71 : no\n"
+ + "| 0 = sunny : no\n" + "| 0 = overcast : yes\n", ForestVisualizer.toString(forest, data.getDataset(), null));
+
+ assertEquals("Tree[1]:\nhumidity < 90 : yes\nhumidity >= 90\n"
+ + "| outlook = rainy\n| | temperature < 71 : yes\n| | temperature >= 71 : no\n"
+ + "| outlook = sunny : no\n" + "| outlook = overcast : yes\n", ForestVisualizer.toString(forest, data.getDataset(), ATTR_NAMES));
+ }
+
+ @Test
+ public void testLeafless() throws Exception {
+ List<Instance> instances = Lists.newArrayList();
+ for (int i = 0; i < data.size(); i++) {
+ if (data.get(i).get(0) != 0.0d) {
+ instances.add(data.get(i));
+ }
+ }
+ Data lessData = new Data(data.getDataset(), instances);
+
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ builder.setM(data.getDataset().nbAttributes() - 1);
+ builder.setMinSplitNum(0);
+ builder.setComplemented(false);
+ Node tree = builder.build(rng, lessData);
+
+ assertEquals("\noutlook = sunny\n| humidity < 85 : yes\n| humidity >= 85 : no\noutlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
+ }
+
+ @Test
+ public void testEmpty() throws Exception {
+ Data emptyData = new Data(data.getDataset());
+
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ Node tree = builder.build(rng, emptyData);
+
+ assertEquals(" : unknown", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
}
}