You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2012/06/22 12:42:58 UTC

svn commit: r1352836 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/ core/src/main/java/org/apache/mahout/classifier/df/ core/src/main/java/org/apache/mahout/classifier/df/builder/ core/src/main/java/org/apache/mahout/classifier/df...

Author: srowen
Date: Fri Jun 22 10:42:57 2012
New Revision: 1352836

URL: http://svn.apache.org/viewvc?rev=1352836&view=rev
Log:
MAHOUT-954 "Unpredictable" has to be represented by NaN in DF.

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java Fri Jun 22 10:42:57 2012
@@ -77,30 +77,28 @@ public class RegressionResultAnalyzer {
     double sumActualSquared = 0.0;
     double sumResult = 0.0;
     double sumResultSquared = 0.0;
+    double sumActualResult = 0.0;
     double sumAbsolute = 0.0;
     double sumAbsoluteSquared = 0.0;
+    int predictable = 0;
+    int unpredictable = 0;
 
     for (Result res : results) {
       double actual = res.getActual();
       double result = res.getResult();
-      sumActual += actual;
-      sumActualSquared += actual * actual;
-      sumResult += result;
-      sumResultSquared += result * result;
-      double absolute = Math.abs(actual - result);
-      sumAbsolute += absolute;
-      sumAbsoluteSquared += absolute * absolute;
-    }
-    
-    double varActual = sumActualSquared - sumActual * sumActual / results.size();
-    double varResult = sumResultSquared - sumResult * sumResult / results.size();
-    double varAbsolute = sumResultSquared - sumActual * sumResult /  results.size();
-
-    double correlation;
-    if (varActual * varResult <= 0) {
-      correlation = 0.0;
-    } else {
-      correlation = varAbsolute / Math.sqrt(varActual * varResult);
+      if (Double.isNaN(result)) {
+        unpredictable++;
+      } else {
+        sumActual += actual;
+        sumActualSquared += actual * actual;
+        sumResult += result;
+        sumResultSquared += result * result;
+        sumActualResult += actual * result;
+        double absolute = Math.abs(actual - result);
+        sumAbsolute += absolute;
+        sumAbsoluteSquared += absolute * absolute;
+        predictable++;
+      }
     }
 
     StringBuilder returnString = new StringBuilder();
@@ -108,16 +106,33 @@ public class RegressionResultAnalyzer {
     returnString.append("=======================================================\n");
     returnString.append("Summary\n");
     returnString.append("-------------------------------------------------------\n");
-
-    NumberFormat decimalFormatter = new DecimalFormat("0.####");
     
-    returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
-      StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
-    returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
-      StringUtils.leftPad(decimalFormatter.format(sumAbsolute / results.size()), 10)).append('\n');
-    returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
-      StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / results.size())),
-        10)).append('\n');
+    if (predictable > 0) {
+      double varActual = sumActualSquared - sumActual * sumActual / predictable;
+      double varResult = sumResultSquared - sumResult * sumResult / predictable;
+      double varCo = sumActualResult - sumActual * sumResult /  predictable;
+  
+      double correlation;
+      if (varActual * varResult <= 0) {
+        correlation = 0.0;
+      } else {
+        correlation = varCo / Math.sqrt(varActual * varResult);
+      }
+  
+      NumberFormat decimalFormatter = new DecimalFormat("0.####");
+      
+      returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
+        StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
+      returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
+        StringUtils.leftPad(decimalFormatter.format(sumAbsolute / predictable), 10)).append('\n');
+      returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
+        StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / predictable)),
+          10)).append('\n');
+    }
+    returnString.append(StringUtils.rightPad("Predictable Instances", 40)).append(": ").append(
+      StringUtils.leftPad(Integer.toString(predictable), 10)).append('\n');
+    returnString.append(StringUtils.rightPad("Unpredictable Instances", 40)).append(": ").append(
+      StringUtils.leftPad(Integer.toString(unpredictable), 10)).append('\n');
     returnString.append(StringUtils.rightPad("Total Regressed Instances", 40)).append(": ").append(
       StringUtils.leftPad(Integer.toString(results.size()), 10)).append('\n');
     returnString.append('\n');

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java Fri Jun 22 10:42:57 2012
@@ -61,17 +61,22 @@ public class DecisionForest implements W
   /**
    * Classifies the data and calls callback for each classification
    */
-  public void classify(Data data, double[] predictions) {
+  public void classify(Data data, double[][] predictions) {
     Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()");
 
     if (data.isEmpty()) {
       return; // nothing to classify
     }
 
+    int treeId = 0;
     for (Node tree : trees) {
       for (int index = 0; index < data.size(); index++) {
-        predictions[index] = tree.classify(data.get(index));
+        if (predictions[index] == null) {
+          predictions[index] = new double[trees.size()];
+        }
+        predictions[index][treeId] = tree.classify(data.get(index));
       }
+      treeId++;
     }
   }
   
@@ -80,7 +85,7 @@ public class DecisionForest implements W
    * 
    * @param rng
    *          Random number generator, used to break ties randomly
-   * @return -1 if the label cannot be predicted
+   * @return NaN if the label cannot be predicted
    */
   public double classify(Dataset dataset, Random rng, Instance instance) {
     if (dataset.isNumerical(dataset.getLabelId())) {
@@ -88,25 +93,30 @@ public class DecisionForest implements W
       int cnt = 0;
       for (Node tree : trees) {
         double prediction = tree.classify(instance);
-        if (prediction != -1) {
+        if (!Double.isNaN(prediction)) {
           sum += prediction;
           cnt++;
         }
       }
-      return sum / cnt;
+
+      if (cnt > 0) {
+        return sum / cnt;
+      } else {
+        return Double.NaN;
+      }
     } else {
       int[] predictions = new int[dataset.nblabels()];
       for (Node tree : trees) {
         double prediction = tree.classify(instance);
-        if (prediction != -1) {
+        if (!Double.isNaN(prediction)) {
           predictions[(int) prediction]++;
         }
       }
-      
+
       if (DataUtils.sum(predictions) == 0) {
-        return -1; // no prediction available
+        return Double.NaN; // no prediction available
       }
-      
+
       return DataUtils.maxindex(rng, predictions);
     }
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java Fri Jun 22 10:42:57 2012
@@ -120,7 +120,7 @@ public class DecisionTreeBuilder impleme
     }
 
     if (data.isEmpty()) {
-      return new Leaf(-1);
+      return new Leaf(Double.NaN);
     }
 
     double sum = 0.0;

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java Fri Jun 22 10:42:57 2012
@@ -184,8 +184,8 @@ public class Dataset implements Writable
    * @return label's value
    */
   public String getLabelString(double code) {
-    // handle the case (prediction == -1)
-    if (code == -1) {
+    // handle the case (prediction is NaN)
+    if (Double.isNaN(code)) {
       return "unknown";
     }
     return values[labelId][(int) code];

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java Fri Jun 22 10:42:57 2012
@@ -27,13 +27,13 @@ import java.io.IOException;
 import java.util.Arrays;
 
 public class CategoricalNode extends Node {
+
   private int attr;
-  
   private double[] values;
-  
   private Node[] childs;
   
-  public CategoricalNode() { }
+  public CategoricalNode() {
+  }
   
   public CategoricalNode(int attr, double[] values, Node[] childs) {
     this.attr = attr;
@@ -46,7 +46,7 @@ public class CategoricalNode extends Nod
     int index = ArrayUtils.indexOf(values, instance.get(attr));
     if (index == -1) {
       // value not available, we cannot predict
-      return -1;
+      return Double.NaN;
     }
     return childs[index].classify(instance);
   }

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java?rev=1352836&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java Fri Jun 22 10:42:57 2012
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public class RegressionResultAnalyzerTest extends MahoutTestCase {
+
+  private static final Pattern p1 = Pattern.compile("Correlation coefficient *: *(.*)\n");
+  private static final Pattern p2 = Pattern.compile("Mean absolute error *: *(.*)\n");
+  private static final Pattern p3 = Pattern.compile("Root mean squared error *: *(.*)\n");
+  private static final Pattern p4 = Pattern.compile("Predictable Instances *: *(.*)\n");
+  private static final Pattern p5 = Pattern.compile("Unpredictable Instances *: *(.*)\n");
+  private static final Pattern p6 = Pattern.compile("Total Regressed Instances *: *(.*)\n");
+  
+  private static double[] parseAnalysis(CharSequence analysis) {
+    double[] results = new double[3];
+    Matcher m = p1.matcher(analysis);
+    if (m.find()) {
+      results[0] = Double.parseDouble(m.group(1));
+    } else {
+      return null;
+    }
+    m = p2.matcher(analysis);
+    if (m.find()) {
+      results[1] = Double.parseDouble(m.group(1));
+    } else {
+      return null;
+    }
+    m = p3.matcher(analysis);
+    if (m.find()) {
+      results[2] = Double.parseDouble(m.group(1));
+    } else {
+      return null;
+    }
+    return results;
+  }
+
+  private static int[] parseAnalysisCount(CharSequence analysis) {
+    int[] results = new int[3];
+    Matcher m = p4.matcher(analysis);
+    if (m.find()) {
+      results[0] = Integer.parseInt(m.group(1));
+    }
+    m = p5.matcher(analysis);
+    if (m.find()) {
+      results[1] = Integer.parseInt(m.group(1));
+    }
+    m = p6.matcher(analysis);
+    if (m.find()) {
+      results[2] = Integer.parseInt(m.group(1));
+    }
+    return results;
+  }
+  
+  @Test
+  public void testAnalyze() {
+    double results[][] = new double[10][2];
+
+    for (int i = 0; i < results.length; i++) {
+      results[i][0] = i;
+      results[i][1] = i + 1;
+    }
+    RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
+    analyzer.setInstances(results);
+    String analysis = analyzer.toString();
+    assertArrayEquals(parseAnalysis(analysis), new double[] {1.0, 1.0, 1.0}, 0);
+
+    for (int i = 0; i < results.length; i++) {
+      results[i][1] = Math.sqrt(i);
+    }
+    analyzer = new RegressionResultAnalyzer();
+    analyzer.setInstances(results);
+    analysis = analyzer.toString();
+    assertArrayEquals(parseAnalysis(analysis), new double[] {0.9573, 2.5694, 3.2848}, 0);
+
+    for (int i = 0; i < results.length; i++) {
+      results[i][0] = results.length - i;
+    }
+    analyzer = new RegressionResultAnalyzer();
+    analyzer.setInstances(results);
+    analysis = analyzer.toString();
+    assertArrayEquals(parseAnalysis(analysis), new double[] {-0.9573, 4.1351, 5.1573}, 0);
+  }
+
+  @Test
+  public void testUnpredictable() {
+    double[][] results = new double[10][2];
+
+    for (int i = 0; i < results.length; i++) {
+      results[i][0] = i;
+      results[i][1] = Double.NaN;
+    }
+    RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
+    analyzer.setInstances(results);
+    String analysis = analyzer.toString();
+    assertNull(parseAnalysis(analysis));
+    assertArrayEquals(parseAnalysisCount(analysis), new int[] {0, 10, 10});
+
+    for (int i = 0; i < results.length - 3; i++) {
+      results[i][1] = Math.sqrt(i);
+    }
+    analyzer = new RegressionResultAnalyzer();
+    analyzer.setInstances(results);
+    analysis = analyzer.toString();
+    assertArrayEquals(parseAnalysis(analysis), new double[] {0.9552, 1.4526, 1.9345}, 0);
+    assertArrayEquals(parseAnalysisCount(analysis), new int[] {7, 3, 10});
+  }
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java?rev=1352836&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java Fri Jun 22 10:42:57 2012
@@ -0,0 +1,199 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public class DecisionForestTest extends MahoutTestCase {
+
+  private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
+    "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
+    "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
+    "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
+    "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
+    "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
+    "rainy,71,91,TRUE,no"};
+  
+  private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
+    "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
+
+  private Random rng;
+
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    rng = RandomUtils.getRandom();
+  }
+
+  private static Data[] generateTrainingDataA() throws DescriptorException {
+    // Dataset
+    Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
+    
+    // Training data
+    Data data = DataLoader.loadData(dataset, TRAIN_DATA);
+    @SuppressWarnings("unchecked")
+    List<Instance>[] instances = new List[3];
+    for (int i = 0; i < instances.length; i++) {
+      instances[i] = Lists.newArrayList();
+    }
+    for (int i = 0; i < data.size(); i++) {
+      if (data.get(i).get(0) == 0.0d) {
+        instances[0].add(data.get(i));
+      } else {
+        instances[1].add(data.get(i));
+      }
+    }
+    Data[] datas = new Data[instances.length];
+    for (int i = 0; i < datas.length; i++) {
+      datas[i] = new Data(dataset, instances[i]);
+    }
+
+    return datas;
+  }
+
+  private static Data[] generateTrainingDataB() throws DescriptorException {
+
+    // Training data
+    String[] trainData = new String[20];
+    for (int i = 0; i < trainData.length; i++) {
+      if (i % 3 == 0) {
+        trainData[i] = "A," + (40 - i) + ',' +  (i + 20);
+      } else if (i % 3 == 1) {
+        trainData[i] = "B," + (i + 20) + ',' +  (40 - i);
+      } else {
+        trainData[i] = "C," + (i + 20) + ',' +  (i + 20);
+      }
+    }
+    // Dataset
+    Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
+    Data[] datas = new Data[3];
+    datas[0] = DataLoader.loadData(dataset, trainData);
+
+    // Training data
+    trainData = new String[20];
+    for (int i = 0; i < trainData.length; i++) {
+      if (i % 2 == 0) {
+        trainData[i] = "A," + (50 - i) + ',' +  (i + 10);
+      } else {
+        trainData[i] = "B," + (i + 10) + ',' +  (50 - i);
+      }
+    }
+    datas[1] = DataLoader.loadData(dataset, trainData);
+
+    // Training data
+    trainData = new String[10];
+    for (int i = 0; i < trainData.length; i++) {
+      trainData[i] = "A," + (40 - i) + ',' +  (i + 20);
+    }
+    datas[2] = DataLoader.loadData(dataset, trainData);
+
+    return datas;
+  }
+  
+  private DecisionForest buildForest(Data[] datas) {
+    List<Node> trees = Lists.newArrayList();
+    for (Data data : datas) {
+      // build tree
+      DecisionTreeBuilder builder = new DecisionTreeBuilder();
+      builder.setM(data.getDataset().nbAttributes() - 1);
+      builder.setMinSplitNum(0);
+      builder.setComplemented(false);
+      trees.add(builder.build(rng, data));
+    }
+    return new DecisionForest(trees);
+  }
+  
+  @Test
+  public void testClassify() throws DescriptorException {
+    // Training data
+    Data[] datas = generateTrainingDataA();
+    // Build Forest
+    DecisionForest forest = buildForest(datas);
+    // Test data
+    Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);
+
+    for (int i = 0; i < testData.size(); i++) {
+      assertEquals(1.0, forest.classify(testData.getDataset(), rng, testData.get(i)), 0);
+    }
+  }
+
+  @Test
+  public void testClassifyData() throws DescriptorException {
+    // Training data
+    Data[] datas = generateTrainingDataA();
+    // Build Forest
+    DecisionForest forest = buildForest(datas);
+    // Test data
+    Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);
+
+    double[][] predictions = new double[testData.size()][];
+    forest.classify(testData, predictions);
+    assertArrayEquals(predictions, new double[][] {{1.0,Double.NaN,Double.NaN},
+      {1.0,0.0,Double.NaN},{1.0,1.0,Double.NaN}});
+  }
+
+  @Test
+  public void testRegression() throws DescriptorException {
+    Data[] datas = generateTrainingDataB();
+    DecisionForest[] forests = new DecisionForest[datas.length];
+    for (int i = 0; i < datas.length; i++) {
+      Data[] subDatas = new Data[datas.length - 1];
+      int k = 0;
+      for (int j = 0; j < datas.length; j++) {
+        if (j != i) {
+          subDatas[k] = datas[j];
+          k++;
+        }
+      }
+      forests[i] = buildForest(subDatas);
+    }
+    
+    double[][] predictions = new double[datas[0].size()][];
+    forests[0].classify(datas[0], predictions);
+    assertArrayEquals(predictions[0], new double[] {20.0, 20.0}, 0);
+    assertArrayEquals(predictions[1], new double[] {39.0, 29.0}, 0);
+    assertArrayEquals(predictions[2], new double[] {Double.NaN, 29.0}, 0);
+    assertArrayEquals(predictions[17], new double[] {Double.NaN, 23.0}, 0);
+
+    predictions = new double[datas[1].size()][];
+    forests[1].classify(datas[1], predictions);
+    assertArrayEquals(predictions[19], new double[] {30.0, 29.0}, 0);
+
+    predictions = new double[datas[2].size()][];
+    forests[2].classify(datas[2], predictions);
+    assertArrayEquals(predictions[9], new double[] {29.0, 28.0}, 0);
+
+    assertEquals(20.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(0)), 0);
+    assertEquals(34.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(1)), 0);
+    assertEquals(29.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(2)), 0);
+  }
+}

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java Fri Jun 22 10:42:57 2012
@@ -57,7 +57,7 @@ public final class Step1MapperTest exten
         assertTrue(expected.contains(data.get(index)));
       }
 
-      return new Leaf(-1);
+      return new Leaf(Double.NaN);
     }
   }
 

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java Fri Jun 22 10:42:57 2012
@@ -18,6 +18,7 @@
 package org.apache.mahout.classifier.df;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Random;
 
 import org.apache.commons.cli2.CommandLine;
@@ -112,13 +113,26 @@ public class BreimanExample extends Conf
     
     // compute the test set error (Selection Error), and mean tree error (One Tree Error),
     double[] testLabels = test.extractLabels();
-    double[] predictions = new double[test.size()];
+    double[][] predictions = new double[test.size()][];
     
     forestM.classify(test, predictions);
-    sumTestErrM += ErrorEstimate.errorRate(testLabels, predictions);
+    double[] sumPredictions = new double[test.size()];
+    Arrays.fill(sumPredictions, 0.0);
+    for (int i = 0; i < predictions.length; i++) {
+      for (int j = 0; j < predictions[i].length; j++) {
+        sumPredictions[i] += predictions[i][j];
+      }
+    }
+    sumTestErrM += ErrorEstimate.errorRate(testLabels, sumPredictions);
     
     forestOne.classify(test, predictions);
-    sumTestErrOne += ErrorEstimate.errorRate(testLabels, predictions);
+    Arrays.fill(sumPredictions, 0.0);
+    for (int i = 0; i < predictions.length; i++) {
+      for (int j = 0; j < predictions[i].length; j++) {
+        sumPredictions[i] += predictions[i][j];
+      }
+    }
+    sumTestErrOne += ErrorEstimate.errorRate(testLabels, sumPredictions);
   }
   
   public static void main(String[] args) throws Exception {