You are viewing a plain text version of this content. The hyperlink to the canonical version was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by gs...@apache.org on 2013/06/05 12:39:16 UTC

svn commit: r1489802 - in /mahout/trunk: CHANGELOG core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java

Author: gsingers
Date: Wed Jun  5 10:39:15 2013
New Revision: 1489802

URL: http://svn.apache.org/r1489802
Log:
MAHOUT-961: fix issues with visualizing DF trees

Modified:
    mahout/trunk/CHANGELOG
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java

Modified: mahout/trunk/CHANGELOG
URL: http://svn.apache.org/viewvc/mahout/trunk/CHANGELOG?rev=1489802&r1=1489801&r2=1489802&view=diff
==============================================================================
--- mahout/trunk/CHANGELOG (original)
+++ mahout/trunk/CHANGELOG Wed Jun  5 10:39:15 2013
@@ -84,3 +84,5 @@ __MAHOUT-1181: Adding StreamingKMeans Ma
   MAHOUT-1176: Introduce a changelog file to raise contributors attribution (ssc)
 
   MAHOUT-1108: Allows cluster-reuters.sh example to be executed on a cluster (elmer.garduno via gsingers) 
+
+  MAHOUT-961: Fix issue in decision forest tree visualizer to properly show stems of tree (Ikumasa Mukai via gsingers)
\ No newline at end of file

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java?rev=1489802&r1=1489801&r2=1489802&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java Wed Jun  5 10:39:15 2013
@@ -35,47 +35,42 @@ import org.apache.mahout.classifier.df.n
  * This tool is to visualize the Decision tree
  */
 public final class TreeVisualizer {
-
-  private TreeVisualizer() {
-  }
-
+  
+  private TreeVisualizer() {}
+  
   private static String doubleToString(double value) {
     DecimalFormat df = new DecimalFormat("0.##");
     return df.format(value);
   }
-
-  private static String toStringNode(Node node,
-                                     Dataset dataset,
-                                     String[] attrNames,
-                                     Map<String, Field> fields,
-                                     int layer) {
-
+  
+  private static String toStringNode(Node node, Dataset dataset,
+      String[] attrNames, Map<String,Field> fields, int layer) {
+    
     StringBuilder buff = new StringBuilder();
-
+    
     try {
-
       if (node instanceof CategoricalNode) {
         CategoricalNode cnode = (CategoricalNode) node;
         int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
-        double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
-        Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
-        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
-        for (int i = 0; i < childs.length; i++) {
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(
+            cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+            .get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+            dataset);
+        for (int i = 0; i < attrValues[attr].length; i++) {
+          int index = ArrayUtils.indexOf(values, i);
+          if (index < 0) {
+            continue;
+          }
           buff.append('\n');
           for (int j = 0; j < layer; j++) {
             buff.append("|   ");
           }
-          if (values[i] < attrValues[attr].length) {
-            if (attrNames == null) {
-              buff.append(attr);
-            } else {
-              buff.append(attrNames[attr]);
-            }
-            buff.append(" = ");
-            buff.append(attrValues[attr][(int) values[i]]);
-
-            buff.append(toStringNode(childs[i], dataset, attrNames, fields, layer + 1));
-          }
+          buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+              .append(attrValues[attr][i]);
+          buff.append(toStringNode(childs[index], dataset, attrNames, fields,
+              layer + 1));
         }
       } else if (node instanceof NumericalNode) {
         NumericalNode nnode = (NumericalNode) node;
@@ -87,13 +82,15 @@ public final class TreeVisualizer {
         for (int j = 0; j < layer; j++) {
           buff.append("|   ");
         }
-        buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ").append(doubleToString(split));
+        buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ")
+            .append(doubleToString(split));
         buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
         buff.append('\n');
         for (int j = 0; j < layer; j++) {
           buff.append("|   ");
         }
-        buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ").append(doubleToString(split));
+        buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ")
+            .append(doubleToString(split));
         buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
       } else if (node instanceof Leaf) {
         Leaf leaf = (Leaf) node;
@@ -101,20 +98,19 @@ public final class TreeVisualizer {
         if (dataset.isNumerical(dataset.getLabelId())) {
           buff.append(" : ").append(doubleToString(label));
         } else {
-          buff.append(" : ").append(dataset.getLabelString((int) label));
+          buff.append(" : ").append(dataset.getLabelString(label));
         }
       }
-
     } catch (IllegalAccessException iae) {
       throw new IllegalStateException(iae);
     }
-
+    
     return buff.toString();
   }
-
-  private static Map<String, Field> getReflectMap() {
-    Map<String, Field> fields = new HashMap<String, Field>();
-
+  
+  private static Map<String,Field> getReflectMap() {
+    Map<String,Field> fields = new HashMap<String,Field>();
+    
     try {
       Field m = CategoricalNode.class.getDeclaredField("attr");
       m.setAccessible(true);
@@ -152,6 +148,7 @@ public final class TreeVisualizer {
   
   /**
    * Decision tree to String
+   * 
    * @param tree
    *          Node of tree
    * @param attrNames
@@ -160,37 +157,41 @@ public final class TreeVisualizer {
   public static String toString(Node tree, Dataset dataset, String[] attrNames) {
     return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
   }
-
+  
   /**
    * Print Decision tree
-   * @param tree  Node of tree
-   * @param attrNames attribute names
+   * 
+   * @param tree
+   *          Node of tree
+   * @param attrNames
+   *          attribute names
    */
   public static void print(Node tree, Dataset dataset, String[] attrNames) {
     System.out.println(toString(tree, dataset, attrNames));
   }
-
-  private static String toStringPredict(Node node,
-                                        Instance instance,
-                                        Dataset dataset,
-                                        String[] attrNames,
-                                        Map<String, Field> fields) {
+  
+  private static String toStringPredict(Node node, Instance instance,
+      Dataset dataset, String[] attrNames, Map<String,Field> fields) {
     StringBuilder buff = new StringBuilder();
-
+    
     try {
       if (node instanceof CategoricalNode) {
         CategoricalNode cnode = (CategoricalNode) node;
         int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
-        double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
-        Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
-        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
-
+        double[] values = (double[]) fields.get("CategoricalNode.values").get(
+            cnode);
+        Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+            .get(cnode);
+        String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+            dataset);
+        
         int index = ArrayUtils.indexOf(values, instance.get(attr));
         if (index >= 0) {
           buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
               .append(attrValues[attr][(int) instance.get(attr)]);
           buff.append(" -> ");
-          buff.append(toStringPredict(childs[index], instance, dataset, attrNames, fields));
+          buff.append(toStringPredict(childs[index], instance, dataset,
+              attrNames, fields));
         }
       } else if (node instanceof NumericalNode) {
         NumericalNode nnode = (NumericalNode) node;
@@ -198,17 +199,21 @@ public final class TreeVisualizer {
         double split = (Double) fields.get("NumericalNode.split").get(nnode);
         Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
         Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
-
+        
         if (instance.get(attr) < split) {
-          buff.append('(').append(attrNames == null ? attr : attrNames[attr]).append(" = ")
-              .append(doubleToString(instance.get(attr))).append(") < ").append(doubleToString(split));
+          buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+              .append(" = ").append(doubleToString(instance.get(attr)))
+              .append(") < ").append(doubleToString(split));
           buff.append(" -> ");
-          buff.append(toStringPredict(loChild, instance, dataset, attrNames, fields));
+          buff.append(toStringPredict(loChild, instance, dataset, attrNames,
+              fields));
         } else {
-          buff.append('(').append(attrNames == null ? attr : attrNames[attr]).append(" = ")
-              .append(doubleToString(instance.get(attr))).append(") >= ").append(doubleToString(split));
+          buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+              .append(" = ").append(doubleToString(instance.get(attr)))
+              .append(") >= ").append(doubleToString(split));
           buff.append(" -> ");
-          buff.append(toStringPredict(hiChild, instance, dataset, attrNames, fields));
+          buff.append(toStringPredict(hiChild, instance, dataset, attrNames,
+              fields));
         }
       } else if (node instanceof Leaf) {
         Leaf leaf = (Leaf) node;
@@ -216,43 +221,47 @@ public final class TreeVisualizer {
         if (dataset.isNumerical(dataset.getLabelId())) {
           buff.append(doubleToString(label));
         } else {
-          buff.append(dataset.getLabelString((int) label));
+          buff.append(dataset.getLabelString(label));
         }
       }
     } catch (IllegalAccessException iae) {
       throw new IllegalStateException(iae);
     }
-
+    
     return buff.toString();
   }
-
+  
   /**
    * Predict trace to String
+   * 
    * @param tree
    *          Node of tree
    * @param attrNames
    *          attribute names
    */
   public static String[] predictTrace(Node tree, Data data, String[] attrNames) {
-    Map<String, Field> reflectMap = getReflectMap();
+    Map<String,Field> reflectMap = getReflectMap();
     String[] prediction = new String[data.size()];
     for (int i = 0; i < data.size(); i++) {
-      prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap);
+      prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(),
+          attrNames, reflectMap);
     }
     return prediction;
   }
-
+  
   /**
    * Print predict trace
+   * 
    * @param tree
    *          Node of tree
    * @param attrNames
    *          attribute names
    */
   public static void predictTracePrint(Node tree, Data data, String[] attrNames) {
-    Map<String, Field> reflectMap = getReflectMap();
+    Map<String,Field> reflectMap = getReflectMap();
     for (int i = 0; i < data.size(); i++) {
-      System.out.println(toStringPredict(tree, data.get(i), data.getDataset(), attrNames, reflectMap));
+      System.out.println(toStringPredict(tree, data.get(i), data.getDataset(),
+          attrNames, reflectMap));
     }
   }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java?rev=1489802&r1=1489801&r2=1489802&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java Wed Jun  5 10:39:15 2013
@@ -20,12 +20,12 @@ package org.apache.mahout.classifier.df.
 import java.util.List;
 import java.util.Random;
 
-import com.google.common.collect.Lists;
 import org.apache.mahout.classifier.df.DecisionForest;
 import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
 import org.apache.mahout.classifier.df.data.Data;
 import org.apache.mahout.classifier.df.data.DataLoader;
 import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
 import org.apache.mahout.classifier.df.node.CategoricalNode;
 import org.apache.mahout.classifier.df.node.Leaf;
 import org.apache.mahout.classifier.df.node.Node;
@@ -36,24 +36,28 @@ import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+import com.google.common.collect.Lists;
+
 public final class VisualizerTest extends MahoutTestCase {
   
   private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
-    "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
-    "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
-    "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
-    "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
-    "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
-    "rainy,71,91,TRUE,no"};
+      "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
+      "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
+      "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
+      "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
+      "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
+      "rainy,71,91,TRUE,no"};
   
   private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
-    "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
+      "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
   
   private static final String[] ATTR_NAMES = {"outlook", "temperature",
-    "humidity", "windy", "play"};
+      "humidity", "windy", "play"};
   
   private Random rng;
+  
   private Data data;
+  
   private Data testData;
   
   @Override
@@ -61,10 +65,11 @@ public final class VisualizerTest extend
   public void setUp() throws Exception {
     super.setUp();
     
-    rng = RandomUtils.getRandom();
+    rng = RandomUtils.getRandom(1);
     
     // Dataset
-    Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
+    Dataset dataset = DataLoader
+        .generateDataset("C N N C L", false, TRAIN_DATA);
     
     // Training data
     data = DataLoader.loadData(dataset, TRAIN_DATA);
@@ -80,10 +85,9 @@ public final class VisualizerTest extend
     builder.setM(data.getDataset().nbAttributes() - 1);
     Node tree = builder.build(rng, data);
     
-    assertEquals(TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES),
-      "\noutlook = rainy\n|   windy = FALSE : yes\n|   windy = TRUE : no\n"
-        + "outlook = overcast : yes\n"
-        + "outlook = sunny\n|   humidity < 85 : yes\n|   humidity >= 85 : no");
+    assertEquals("\noutlook = rainy\n|   windy = FALSE : yes\n|   windy = TRUE : no\n"
+            + "outlook = sunny\n|   humidity < 85 : yes\n|   humidity >= 85 : no\n"
+            + "outlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
   }
   
   @Test
@@ -93,8 +97,9 @@ public final class VisualizerTest extend
     builder.setM(data.getDataset().nbAttributes() - 1);
     Node tree = builder.build(rng, data);
     
-    String[] prediction = TreeVisualizer.predictTrace(tree, testData, ATTR_NAMES);
-    Assert.assertArrayEquals(new String[]{
+    String[] prediction = TreeVisualizer.predictTrace(tree, testData,
+        ATTR_NAMES);
+    Assert.assertArrayEquals(new String[] {
         "outlook = rainy -> windy = TRUE -> no", "outlook = overcast -> yes",
         "outlook = sunny -> (humidity = 90) >= 85 -> no"}, prediction);
   }
@@ -103,22 +108,51 @@ public final class VisualizerTest extend
   public void testForestVisualize() throws Exception {
     // Tree
     NumericalNode root = new NumericalNode(2, 90, new Leaf(0),
-      new CategoricalNode(0, new double[] {0, 1, 2}, new Node[] {
-        new NumericalNode(1, 71, new Leaf(0), new Leaf(1)), new Leaf(1),
-        new Leaf(0)}));
+        new CategoricalNode(0, new double[] {0, 1, 2}, new Node[] {
+            new NumericalNode(1, 71, new Leaf(0), new Leaf(1)), new Leaf(1),
+            new Leaf(0)}));
     List<Node> trees = Lists.newArrayList();
     trees.add(root);
     
     // Forest
     DecisionForest forest = new DecisionForest(trees);
-    assertEquals(ForestVisualizer.toString(forest, data.getDataset(), null),
-      "Tree[1]:\n2 < 90 : yes\n2 >= 90\n"
-        + "|   0 = rainy\n|   |   1 < 71 : yes\n|   |   1 >= 71 : no\n"
-        + "|   0 = sunny : no\n" + "|   0 = overcast : yes\n");
-    
-    assertEquals(ForestVisualizer.toString(forest, data.getDataset(), ATTR_NAMES),
-      "Tree[1]:\nhumidity < 90 : yes\nhumidity >= 90\n"
-        + "|   outlook = rainy\n|   |   temperature < 71 : yes\n|   |   temperature >= 71 : no\n"
-        + "|   outlook = sunny : no\n" + "|   outlook = overcast : yes\n");
+    assertEquals("Tree[1]:\n2 < 90 : yes\n2 >= 90\n"
+            + "|   0 = rainy\n|   |   1 < 71 : yes\n|   |   1 >= 71 : no\n"
+            + "|   0 = sunny : no\n" + "|   0 = overcast : yes\n", ForestVisualizer.toString(forest, data.getDataset(), null));
+
+    assertEquals("Tree[1]:\nhumidity < 90 : yes\nhumidity >= 90\n"
+            + "|   outlook = rainy\n|   |   temperature < 71 : yes\n|   |   temperature >= 71 : no\n"
+            + "|   outlook = sunny : no\n" + "|   outlook = overcast : yes\n", ForestVisualizer.toString(forest, data.getDataset(), ATTR_NAMES));
+  }
+  
+  @Test
+  public void testLeafless() throws Exception {
+    List<Instance> instances = Lists.newArrayList();
+    for (int i = 0; i < data.size(); i++) {
+      if (data.get(i).get(0) != 0.0d) {
+        instances.add(data.get(i));
+      }
+    }
+    Data lessData = new Data(data.getDataset(), instances);
+    
+    // build tree
+    DecisionTreeBuilder builder = new DecisionTreeBuilder();
+    builder.setM(data.getDataset().nbAttributes() - 1);
+    builder.setMinSplitNum(0);
+    builder.setComplemented(false);
+    Node tree = builder.build(rng, lessData);
+
+    assertEquals("\noutlook = sunny\n|   humidity < 85 : yes\n|   humidity >= 85 : no\noutlook = overcast : yes", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
+  }
+  
+  @Test
+  public void testEmpty() throws Exception {
+    Data emptyData = new Data(data.getDataset());
+    
+    // build tree
+    DecisionTreeBuilder builder = new DecisionTreeBuilder();
+    Node tree = builder.build(rng, emptyData);
+
+    assertEquals(" : unknown", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
   }
 }