You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2012/06/22 12:42:58 UTC
svn commit: r1352836 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/
core/src/main/java/org/apache/mahout/classifier/df/
core/src/main/java/org/apache/mahout/classifier/df/builder/
core/src/main/java/org/apache/mahout/classifier/df...
Author: srowen
Date: Fri Jun 22 10:42:57 2012
New Revision: 1352836
URL: http://svn.apache.org/viewvc?rev=1352836&view=rev
Log:
MAHOUT-954 "Unpredictable" have to be represented by NaN on DF.
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java Fri Jun 22 10:42:57 2012
@@ -77,30 +77,28 @@ public class RegressionResultAnalyzer {
double sumActualSquared = 0.0;
double sumResult = 0.0;
double sumResultSquared = 0.0;
+ double sumActualResult = 0.0;
double sumAbsolute = 0.0;
double sumAbsoluteSquared = 0.0;
+ int predictable = 0;
+ int unpredictable = 0;
for (Result res : results) {
double actual = res.getActual();
double result = res.getResult();
- sumActual += actual;
- sumActualSquared += actual * actual;
- sumResult += result;
- sumResultSquared += result * result;
- double absolute = Math.abs(actual - result);
- sumAbsolute += absolute;
- sumAbsoluteSquared += absolute * absolute;
- }
-
- double varActual = sumActualSquared - sumActual * sumActual / results.size();
- double varResult = sumResultSquared - sumResult * sumResult / results.size();
- double varAbsolute = sumResultSquared - sumActual * sumResult / results.size();
-
- double correlation;
- if (varActual * varResult <= 0) {
- correlation = 0.0;
- } else {
- correlation = varAbsolute / Math.sqrt(varActual * varResult);
+ if (Double.isNaN(result)) {
+ unpredictable++;
+ } else {
+ sumActual += actual;
+ sumActualSquared += actual * actual;
+ sumResult += result;
+ sumResultSquared += result * result;
+ sumActualResult += actual * result;
+ double absolute = Math.abs(actual - result);
+ sumAbsolute += absolute;
+ sumAbsoluteSquared += absolute * absolute;
+ predictable++;
+ }
}
StringBuilder returnString = new StringBuilder();
@@ -108,16 +106,33 @@ public class RegressionResultAnalyzer {
returnString.append("=======================================================\n");
returnString.append("Summary\n");
returnString.append("-------------------------------------------------------\n");
-
- NumberFormat decimalFormatter = new DecimalFormat("0.####");
- returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
- StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
- returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
- StringUtils.leftPad(decimalFormatter.format(sumAbsolute / results.size()), 10)).append('\n');
- returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
- StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / results.size())),
- 10)).append('\n');
+ if (predictable > 0) {
+ double varActual = sumActualSquared - sumActual * sumActual / predictable;
+ double varResult = sumResultSquared - sumResult * sumResult / predictable;
+ double varCo = sumActualResult - sumActual * sumResult / predictable;
+
+ double correlation;
+ if (varActual * varResult <= 0) {
+ correlation = 0.0;
+ } else {
+ correlation = varCo / Math.sqrt(varActual * varResult);
+ }
+
+ NumberFormat decimalFormatter = new DecimalFormat("0.####");
+
+ returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(sumAbsolute / predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / predictable)),
+ 10)).append('\n');
+ }
+ returnString.append(StringUtils.rightPad("Predictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Unpredictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(unpredictable), 10)).append('\n');
returnString.append(StringUtils.rightPad("Total Regressed Instances", 40)).append(": ").append(
StringUtils.leftPad(Integer.toString(results.size()), 10)).append('\n');
returnString.append('\n');
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java Fri Jun 22 10:42:57 2012
@@ -61,17 +61,22 @@ public class DecisionForest implements W
/**
- * Classifies the data and calls callback for each classification
+ * Fills predictions[index][treeId] (NaN = no prediction), (re)allocating null or wrongly sized rows
*/
- public void classify(Data data, double[] predictions) {
+ public void classify(Data data, double[][] predictions) {
Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()");
if (data.isEmpty()) {
return; // nothing to classify
}
+ int treeId = 0;
for (Node tree : trees) {
for (int index = 0; index < data.size(); index++) {
- predictions[index] = tree.classify(data.get(index));
+ if (predictions[index] == null || predictions[index].length != trees.size()) {
+ predictions[index] = new double[trees.size()];
+ }
+ predictions[index][treeId] = tree.classify(data.get(index));
}
+ treeId++;
}
}
@@ -80,7 +85,7 @@ public class DecisionForest implements W
*
* @param rng
* Random number generator, used to break ties randomly
- * @return -1 if the label cannot be predicted
+ * @return NaN if the label cannot be predicted
*/
public double classify(Dataset dataset, Random rng, Instance instance) {
if (dataset.isNumerical(dataset.getLabelId())) {
@@ -88,25 +93,30 @@ public class DecisionForest implements W
int cnt = 0;
for (Node tree : trees) {
double prediction = tree.classify(instance);
- if (prediction != -1) {
+ if (!Double.isNaN(prediction)) {
sum += prediction;
cnt++;
}
}
- return sum / cnt;
+
+ if (cnt > 0) {
+ return sum / cnt;
+ } else {
+ return Double.NaN;
+ }
} else {
int[] predictions = new int[dataset.nblabels()];
for (Node tree : trees) {
double prediction = tree.classify(instance);
- if (prediction != -1) {
+ if (!Double.isNaN(prediction)) {
predictions[(int) prediction]++;
}
}
-
+
if (DataUtils.sum(predictions) == 0) {
- return -1; // no prediction available
+ return Double.NaN; // no prediction available
}
-
+
return DataUtils.maxindex(rng, predictions);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java Fri Jun 22 10:42:57 2012
@@ -120,7 +120,7 @@ public class DecisionTreeBuilder impleme
}
if (data.isEmpty()) {
- return new Leaf(-1);
+ return new Leaf(Double.NaN);
}
double sum = 0.0;
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java Fri Jun 22 10:42:57 2012
@@ -184,8 +184,8 @@ public class Dataset implements Writable
* @return label's value
*/
public String getLabelString(double code) {
- // handle the case (prediction == -1)
- if (code == -1) {
+ // handle the case (prediction is NaN)
+ if (Double.isNaN(code)) {
return "unknown";
}
return values[labelId][(int) code];
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java Fri Jun 22 10:42:57 2012
@@ -27,13 +27,13 @@ import java.io.IOException;
import java.util.Arrays;
public class CategoricalNode extends Node {
+
private int attr;
-
private double[] values;
-
private Node[] childs;
- public CategoricalNode() { }
+ public CategoricalNode() {
+ }
public CategoricalNode(int attr, double[] values, Node[] childs) {
this.attr = attr;
@@ -46,7 +46,7 @@ public class CategoricalNode extends Nod
int index = ArrayUtils.indexOf(values, instance.get(attr));
if (index == -1) {
// value not available, we cannot predict
- return -1;
+ return Double.NaN;
}
return childs[index].classify(instance);
}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java?rev=1352836&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java Fri Jun 22 10:42:57 2012
@@ -0,0 +1,133 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+/**
+ * Tests the summary produced by {@link RegressionResultAnalyzer#toString()}, including the
+ * reporting of unpredictable (NaN) results.
+ */
+public class RegressionResultAnalyzerTest extends MahoutTestCase {
+
+ private static final Pattern CORRELATION_LINE = Pattern.compile("Correlation coefficient *: *(.*)\n");
+ private static final Pattern MEAN_ABSOLUTE_ERROR_LINE = Pattern.compile("Mean absolute error *: *(.*)\n");
+ private static final Pattern RMSE_LINE = Pattern.compile("Root mean squared error *: *(.*)\n");
+ private static final Pattern PREDICTABLE_LINE = Pattern.compile("Predictable Instances *: *(.*)\n");
+ private static final Pattern UNPREDICTABLE_LINE = Pattern.compile("Unpredictable Instances *: *(.*)\n");
+ private static final Pattern TOTAL_LINE = Pattern.compile("Total Regressed Instances *: *(.*)\n");
+
+ /**
+ * Extracts the correlation coefficient, mean absolute error and root mean squared error from a summary.
+ *
+ * @return the three metrics in that order, or {@code null} if any of them is absent (the analyzer
+ * omits them when no instance is predictable)
+ */
+ private static double[] parseAnalysis(CharSequence analysis) {
+ Pattern[] patterns = {CORRELATION_LINE, MEAN_ABSOLUTE_ERROR_LINE, RMSE_LINE};
+ double[] results = new double[patterns.length];
+ for (int i = 0; i < patterns.length; i++) {
+ Matcher m = patterns[i].matcher(analysis);
+ if (!m.find()) {
+ return null;
+ }
+ results[i] = Double.parseDouble(m.group(1));
+ }
+ return results;
+ }
+
+ /**
+ * Extracts the predictable, unpredictable and total instance counts from a summary.
+ *
+ * @return the three counts in that order; a count absent from the summary is left at 0
+ */
+ private static int[] parseAnalysisCount(CharSequence analysis) {
+ Pattern[] patterns = {PREDICTABLE_LINE, UNPREDICTABLE_LINE, TOTAL_LINE};
+ int[] results = new int[patterns.length];
+ for (int i = 0; i < patterns.length; i++) {
+ Matcher m = patterns[i].matcher(analysis);
+ if (m.find()) {
+ results[i] = Integer.parseInt(m.group(1));
+ }
+ }
+ return results;
+ }
+
+ @Test
+ public void testAnalyze() {
+ double[][] results = new double[10][2];
+
+ // every result overshoots its actual value by exactly 1
+ for (int i = 0; i < results.length; i++) {
+ results[i][0] = i;
+ results[i][1] = i + 1;
+ }
+ RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ String analysis = analyzer.toString();
+ assertArrayEquals(new double[] {1.0, 1.0, 1.0}, parseAnalysis(analysis), 0);
+
+ // results become sqrt(actual)
+ for (int i = 0; i < results.length; i++) {
+ results[i][1] = Math.sqrt(i);
+ }
+ analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ analysis = analyzer.toString();
+ assertArrayEquals(new double[] {0.9573, 2.5694, 3.2848}, parseAnalysis(analysis), 0);
+
+ // reversing the actual values makes the correlation negative
+ for (int i = 0; i < results.length; i++) {
+ results[i][0] = results.length - i;
+ }
+ analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ analysis = analyzer.toString();
+ assertArrayEquals(new double[] {-0.9573, 4.1351, 5.1573}, parseAnalysis(analysis), 0);
+ }
+
+ @Test
+ public void testUnpredictable() {
+ double[][] results = new double[10][2];
+
+ // every result is NaN: no metric can be computed, only the counts are reported
+ for (int i = 0; i < results.length; i++) {
+ results[i][0] = i;
+ results[i][1] = Double.NaN;
+ }
+ RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ String analysis = analyzer.toString();
+ assertNull(parseAnalysis(analysis));
+ assertArrayEquals(new int[] {0, 10, 10}, parseAnalysisCount(analysis));
+
+ // the first 7 results become predictable, the last 3 stay NaN
+ for (int i = 0; i < results.length - 3; i++) {
+ results[i][1] = Math.sqrt(i);
+ }
+ analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ analysis = analyzer.toString();
+ assertArrayEquals(new double[] {0.9552, 1.4526, 1.9345}, parseAnalysis(analysis), 0);
+ assertArrayEquals(new int[] {7, 3, 10}, parseAnalysisCount(analysis));
+ }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java?rev=1352836&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java Fri Jun 22 10:42:57 2012
@@ -0,0 +1,218 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Tests {@link DecisionForest} classification and regression, in particular that instances a tree
+ * cannot predict are reported as NaN.
+ */
+public class DecisionForestTest extends MahoutTestCase {
+
+ private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
+ "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
+ "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
+ "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
+ "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
+ "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
+ "rainy,71,91,TRUE,no"};
+
+ private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
+ "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
+
+ private Random rng;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+ }
+
+ /**
+ * Builds the weather training set and splits it into three partitions: instances whose first
+ * attribute has code 0.0, all the others, and an empty partition (whose tree is expected to
+ * predict NaN for every instance).
+ */
+ private static Data[] generateTrainingDataA() throws DescriptorException {
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
+
+ // Training data
+ Data data = DataLoader.loadData(dataset, TRAIN_DATA);
+ @SuppressWarnings("unchecked")
+ List<Instance>[] instances = new List[3];
+ for (int i = 0; i < instances.length; i++) {
+ instances[i] = Lists.newArrayList();
+ }
+ for (int i = 0; i < data.size(); i++) {
+ if (data.get(i).get(0) == 0.0d) {
+ instances[0].add(data.get(i));
+ } else {
+ instances[1].add(data.get(i));
+ }
+ }
+ Data[] datas = new Data[instances.length];
+ for (int i = 0; i < datas.length; i++) {
+ datas[i] = new Data(dataset, instances[i]);
+ }
+
+ return datas;
+ }
+
+ /**
+ * Builds three regression training sets described by "C N L"; all of them share the Dataset
+ * generated from the first one.
+ */
+ private static Data[] generateTrainingDataB() throws DescriptorException {
+
+ // Training data
+ String[] trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 3 == 0) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ } else if (i % 3 == 1) {
+ trainData[i] = "B," + (i + 20) + ',' + (40 - i);
+ } else {
+ trainData[i] = "C," + (i + 20) + ',' + (i + 20);
+ }
+ }
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
+ Data[] datas = new Data[3];
+ datas[0] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 2 == 0) {
+ trainData[i] = "A," + (50 - i) + ',' + (i + 10);
+ } else {
+ trainData[i] = "B," + (i + 10) + ',' + (50 - i);
+ }
+ }
+ datas[1] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[10];
+ for (int i = 0; i < trainData.length; i++) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ }
+ datas[2] = DataLoader.loadData(dataset, trainData);
+
+ return datas;
+ }
+
+ /**
+ * Builds a forest containing one tree per training set, all grown with the same builder settings.
+ */
+ private DecisionForest buildForest(Data[] datas) {
+ List<Node> trees = Lists.newArrayList();
+ for (Data data : datas) {
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ builder.setM(data.getDataset().nbAttributes() - 1);
+ builder.setMinSplitNum(0);
+ builder.setComplemented(false);
+ trees.add(builder.build(rng, data));
+ }
+ return new DecisionForest(trees);
+ }
+
+ @Test
+ public void testClassify() throws DescriptorException {
+ // Training data
+ Data[] datas = generateTrainingDataA();
+ // Build Forest
+ DecisionForest forest = buildForest(datas);
+ // Test data
+ Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);
+
+ for (int i = 0; i < testData.size(); i++) {
+ assertEquals(1.0, forest.classify(testData.getDataset(), rng, testData.get(i)), 0);
+ }
+ }
+
+ @Test
+ public void testClassifyData() throws DescriptorException {
+ // Training data
+ Data[] datas = generateTrainingDataA();
+ // Build Forest
+ DecisionForest forest = buildForest(datas);
+ // Test data
+ Data testData = DataLoader.loadData(datas[0].getDataset(), TEST_DATA);
+
+ double[][] predictions = new double[testData.size()][];
+ forest.classify(testData, predictions);
+ // one column per tree; the tree built from the empty partition always yields NaN
+ assertArrayEquals(new double[][] {{1.0,Double.NaN,Double.NaN},
+ {1.0,0.0,Double.NaN},{1.0,1.0,Double.NaN}}, predictions);
+ }
+
+ @Test
+ public void testRegression() throws DescriptorException {
+ Data[] datas = generateTrainingDataB();
+ DecisionForest[] forests = new DecisionForest[datas.length];
+ // forests[i] is built from every training set except datas[i]
+ for (int i = 0; i < datas.length; i++) {
+ Data[] subDatas = new Data[datas.length - 1];
+ int k = 0;
+ for (int j = 0; j < datas.length; j++) {
+ if (j != i) {
+ subDatas[k] = datas[j];
+ k++;
+ }
+ }
+ forests[i] = buildForest(subDatas);
+ }
+
+ // NaN marks an instance that the corresponding tree cannot predict
+ double[][] predictions = new double[datas[0].size()][];
+ forests[0].classify(datas[0], predictions);
+ assertArrayEquals(new double[] {20.0, 20.0}, predictions[0], 0);
+ assertArrayEquals(new double[] {39.0, 29.0}, predictions[1], 0);
+ assertArrayEquals(new double[] {Double.NaN, 29.0}, predictions[2], 0);
+ assertArrayEquals(new double[] {Double.NaN, 23.0}, predictions[17], 0);
+
+ predictions = new double[datas[1].size()][];
+ forests[1].classify(datas[1], predictions);
+ assertArrayEquals(new double[] {30.0, 29.0}, predictions[19], 0);
+
+ predictions = new double[datas[2].size()][];
+ forests[2].classify(datas[2], predictions);
+ assertArrayEquals(new double[] {29.0, 28.0}, predictions[9], 0);
+
+ assertEquals(20.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(0)), 0);
+ assertEquals(34.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(1)), 0);
+ assertEquals(29.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(2)), 0);
+ }
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java Fri Jun 22 10:42:57 2012
@@ -57,7 +57,7 @@ public final class Step1MapperTest exten
assertTrue(expected.contains(data.get(index)));
}
- return new Leaf(-1);
+ return new Leaf(Double.NaN);
}
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java?rev=1352836&r1=1352835&r2=1352836&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java Fri Jun 22 10:42:57 2012
@@ -18,6 +18,7 @@
package org.apache.mahout.classifier.df;
import java.io.IOException;
+import java.util.Arrays;
import java.util.Random;
import org.apache.commons.cli2.CommandLine;
@@ -112,13 +113,26 @@ public class BreimanExample extends Conf
// compute the test set error (Selection Error), and mean tree error (One Tree Error),
double[] testLabels = test.extractLabels();
- double[] predictions = new double[test.size()];
+ double[][] predictions = new double[test.size()][];
forestM.classify(test, predictions);
- sumTestErrM += ErrorEstimate.errorRate(testLabels, predictions);
+ double[] sumPredictions = new double[test.size()];
+ Arrays.fill(sumPredictions, 0.0);
+ for (int i = 0; i < predictions.length; i++) {
+ for (int j = 0; j < predictions[i].length; j++) {
+ sumPredictions[i] += predictions[i][j];
+ }
+ }
+ sumTestErrM += ErrorEstimate.errorRate(testLabels, sumPredictions);
forestOne.classify(test, predictions);
- sumTestErrOne += ErrorEstimate.errorRate(testLabels, predictions);
+ Arrays.fill(sumPredictions, 0.0);
+ for (int i = 0; i < predictions.length; i++) {
+ for (int j = 0; j < predictions[i].length; j++) {
+ sumPredictions[i] += predictions[i][j];
+ }
+ }
+ sumTestErrOne += ErrorEstimate.errorRate(testLabels, sumPredictions);
}
public static void main(String[] args) throws Exception {