You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/03/13 11:54:15 UTC
svn commit: r922523 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/df/
core/src/main/java/org/apache/mahout/df/split/
core/src/main/java/org/apache/mahout/df/tools/
examples/src/main/java/org/apache/mahout/df/mapreduce/
Author: adeneche
Date: Sat Mar 13 10:54:14 2010
New Revision: 922523
URL: http://svn.apache.org/viewvc?rev=922523&view=rev
Log:
MAHOUT-323
Added:
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java Sat Mar 13 10:54:14 2010
@@ -23,6 +23,9 @@ import java.io.IOException;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Writable;
import org.apache.mahout.df.node.Node;
import org.apache.mahout.ga.watchmaker.OutputUtils;
@@ -162,5 +165,12 @@ public class DFUtils {
return hours + "h " + minutes + "m " + seconds + "s " + milli;
}
-
+
+  /**
+   * Serializes the given {@link Writable} into a file created at {@code path}. The output stream is
+   * closed even when serialization fails, so no file handle is leaked.
+   *
+   * @param conf configuration used to resolve the file system of {@code path}
+   * @param path destination file
+   * @param writable object to serialize
+   * @throws IOException if the file cannot be created or written
+   */
+  public static void storeWritable(Configuration conf, Path path, Writable writable) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+
+    FSDataOutputStream out = fs.create(path);
+    try {
+      writable.write(out);
+    } finally {
+      // always release the stream, even if write() throws
+      out.close();
+    }
+  }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java Sat Mar 13 10:54:14 2010
@@ -20,17 +20,21 @@ package org.apache.mahout.df;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.io.DataInput;
import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataUtils;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.node.Node;
+import org.apache.hadoop.io.Writable;
/**
* Represents a forest of decision trees.
*/
-public class DecisionForest {
+public class DecisionForest implements Writable {
private final List<Node> trees;
@@ -163,5 +167,31 @@ public class DecisionForest {
public int hashCode() {
return trees.hashCode();
}
-
+
+  /**
+   * Serializes the forest as the number of trees followed by each tree, in list order.
+   *
+   * @param dataOutput destination of the serialized forest
+   * @throws IOException if writing to {@code dataOutput} fails
+   */
+  public void write(DataOutput dataOutput) throws IOException {
+    final int treeCount = trees.size();
+    dataOutput.writeInt(treeCount);
+    for (Node node : trees) {
+      node.write(dataOutput);
+    }
+  }
+
+  /**
+   * Reads a tree count followed by that many trees and appends them to this forest. The trees
+   * already in the forest are NOT cleared, so calling this repeatedly merges several serialized
+   * forests into one instance.
+   *
+   * @param dataInput source positioned at data produced by {@link #write(DataOutput)}
+   * @throws IOException if reading from {@code dataInput} fails
+   */
+  public void readFields(DataInput dataInput) throws IOException {
+    int remaining = dataInput.readInt();
+    while (remaining > 0) {
+      trees.add(Node.read(dataInput));
+      remaining--;
+    }
+  }
+
+  /**
+   * Creates a forest with the no-argument constructor and populates it from {@code dataInput}.
+   *
+   * @param dataInput source positioned at data produced by {@link #write(DataOutput)}
+   * @return the deserialized forest
+   * @throws IOException if reading from {@code dataInput} fails
+   */
+  public static DecisionForest read(DataInput dataInput) throws IOException {
+    DecisionForest result = new DecisionForest();
+    result.readFields(dataInput);
+    return result;
+  }
+
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java Sat Mar 13 10:54:14 2010
@@ -40,4 +40,9 @@ public class Split {
public Split(int attr, double ig) {
this(attr, ig, Double.NaN);
}
+
+  /**
+   * Returns a human-readable description of this split. A fixed locale is used so that the decimal
+   * separator of {@code ig} and {@code split} does not depend on the JVM's default locale.
+   */
+  @Override
+  public String toString() {
+    return String.format(java.util.Locale.ENGLISH, "attr: %d, ig: %f, split: %f", attr, ig, split);
+  }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java Sat Mar 13 10:54:14 2010
@@ -30,15 +30,14 @@ import org.apache.commons.cli2.builder.D
import org.apache.commons.cli2.builder.GroupBuilder;
import org.apache.commons.cli2.commandline.Parser;
import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.Writable;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.data.DescriptorException;
import org.apache.mahout.df.data.DescriptorUtils;
+import org.apache.mahout.df.DFUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -94,7 +93,7 @@ public class Describe {
runTool(dataPath, descriptor, descPath);
} catch (OptionException e) {
- log.warn(e.toString(), e);
+ log.warn(e.toString());
CommandLineUtil.printHelp(group);
}
}
@@ -110,7 +109,7 @@ public class Describe {
Dataset dataset = generateDataset(descriptor, dataPath);
log.info("storing the dataset description");
- storeWritable(new Configuration(), fPath, dataset);
+ DFUtils.storeWritable(new Configuration(), fPath, dataset);
}
private static Dataset generateDataset(String descriptor, String dataPath) throws IOException,
@@ -138,12 +137,5 @@ public class Describe {
}
return list;
}
-
- private static void storeWritable(Configuration conf, Path path, Writable dataset) throws IOException {
- FileSystem fs = path.getFileSystem(conf);
-
- FSDataOutputStream out = fs.create(path);
- dataset.write(out);
- out.close();
- }
+
}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java Sat Mar 13 10:54:14 2010
@@ -51,16 +51,17 @@ import org.slf4j.LoggerFactory;
/**
* Tool to build a Random Forest using any given dataset (in UCI format). Can use either the in-mem mapred or
- * partial mapred implementations
+ * partial mapred implementations. Stores the forest in the given output directory
*/
public class BuildForest extends Configured implements Tool {
private static final Logger log = LoggerFactory.getLogger(BuildForest.class);
- private Path dataPath; // Data path
-
- private Path datasetPath; // Dataset path
+ private Path dataPath;
+ private Path datasetPath;
+
+ private Path outputPath;
private int m; // Number of variables to select at each tree-node
private int nbTrees; // Number of trees to grow
@@ -70,7 +71,7 @@ public class BuildForest extends Configu
private boolean isPartial; // use partial data implementation
private boolean isOob; // estimate oob error;
-
+
@Override
public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
@@ -103,12 +104,16 @@ public class BuildForest extends Configu
abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription(
"Number of trees to grow").create();
+ Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
+ withDescription("Output path, will contain the Decision Forest").create();
+
Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
.create();
Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(dataOpt).withOption(datasetOpt)
.withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
- .withOption(helpOpt).create();
+ .withOption(outputOpt).withOption(helpOpt).create();
try {
Parser parser = new Parser();
@@ -124,6 +129,7 @@ public class BuildForest extends Configu
isOob = cmdLine.hasOption(oobOpt);
String dataName = cmdLine.getValue(dataOpt).toString();
String datasetName = cmdLine.getValue(datasetOpt).toString();
+ String outputName = cmdLine.getValue(outputOpt).toString();
m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
@@ -133,6 +139,7 @@ public class BuildForest extends Configu
log.debug("data : {}", dataName);
log.debug("dataset : {}", datasetName);
+ log.debug("output : {}", outputName);
log.debug("m : {}", m);
log.debug("seed : {}", seed);
log.debug("nbtrees : {}", nbTrees);
@@ -141,6 +148,7 @@ public class BuildForest extends Configu
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
+ outputPath = new Path(outputName);
} catch (OptionException e) {
log.error("Exception", e);
@@ -153,7 +161,14 @@ public class BuildForest extends Configu
return 0;
}
- private DecisionForest buildForest() throws IOException, ClassNotFoundException, InterruptedException {
+ private void buildForest() throws IOException, ClassNotFoundException, InterruptedException {
+ // make sure the output path does not exist
+ FileSystem ofs = outputPath.getFileSystem(getConf());
+ if (ofs.exists(outputPath)) {
+ log.error("Output path already exists");
+ return;
+ }
+
DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
treeBuilder.setM(m);
@@ -171,6 +186,9 @@ public class BuildForest extends Configu
log.info("InMem Mapred implementation");
forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath, seed, getConf());
}
+
+ forestBuilder.setOutputDirName(outputPath.getName());
+
log.info("Building the forest...");
long time = System.currentTimeMillis();
@@ -193,8 +211,12 @@ public class BuildForest extends Configu
log.info("oob error estimate : "
+ ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
}
-
- return forest;
+
+ // store the decision forest in the output path
+ Path forestPath = new Path(outputPath, "forest.seq");
+ log.info("Storing the forest in: " + forestPath);
+ DFUtils.storeWritable(getConf(), forestPath, forest);
+
}
protected static Data loadData(Configuration conf, Path dataPath, Dataset dataset) throws IOException {
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=922523&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Mar 13 10:54:14 2010
@@ -0,0 +1,184 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce;
+
+import java.io.IOException;
+import java.util.Random;
+import java.util.Scanner;
+import java.util.Arrays;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.ErrorEstimate;
+import org.apache.mahout.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.df.callback.ForestPredictions;
+import org.apache.mahout.df.data.*;
+import org.apache.mahout.df.mapreduce.inmem.InMemBuilder;
+import org.apache.mahout.df.mapreduce.partial.PartialBuilder;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tool to classify a Dataset using a previously built Decision Forest
+ */
+public class TestForest extends Configured implements Tool {
+
+ private static final Logger log = LoggerFactory.getLogger(TestForest.class);
+
+ private Path dataPath; // test data path
+
+ private Path datasetPath;
+
+ private Path modelPath; // path where the forest is stored
+
+ @Override
+ public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option inputOpt = obuilder.withLongName("input").withShortName("i").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Test data path").create();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+ abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset path")
+ .create();
+
+ Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
+ withDescription("Path to the Decision Forest").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt)
+ .withOption(modelOpt).withOption(helpOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption("help")) {
+ CommandLineUtil.printHelp(group);
+ return -1;
+ }
+
+ String dataName = cmdLine.getValue(inputOpt).toString();
+ String datasetName = cmdLine.getValue(datasetOpt).toString();
+ String modelName = cmdLine.getValue(modelOpt).toString();
+
+ log.debug("inout : {}", dataName);
+ log.debug("dataset : {}", datasetName);
+ log.debug("model : {}", modelName);
+
+ dataPath = new Path(dataName);
+ datasetPath = new Path(datasetName);
+ modelPath = new Path(modelName);
+
+ } catch (OptionException e) {
+ System.err.println("Exception : " + e);
+ CommandLineUtil.printHelp(group);
+ return -1;
+ }
+
+ testForest();
+
+ return 0;
+ }
+
+ private void testForest() throws IOException, ClassNotFoundException, InterruptedException {
+ Dataset dataset = Dataset.load(getConf(), datasetPath);
+ DataConverter converter = new DataConverter(dataset);
+
+ log.info("Loading the forest...");
+ FileSystem fs = modelPath.getFileSystem(getConf());
+ Path[] modelfiles = DFUtils.listOutputFiles(fs, modelPath);
+ DecisionForest forest = null;
+ for (Path path : modelfiles) {
+ FSDataInputStream dataInput = new FSDataInputStream(fs.open(path));
+ if (forest == null) {
+ forest = DecisionForest.read(dataInput);
+ } else {
+ forest.readFields(dataInput);
+ }
+
+ dataInput.close();
+ }
+
+ if (forest == null) {
+ log.error("No Decision Forest found!");
+ return;
+ }
+
+ log.info("Sequential classification...");
+ long time = System.currentTimeMillis();
+
+ FileSystem tfs = dataPath.getFileSystem(getConf());
+ FSDataInputStream input = tfs.open(dataPath);
+ Scanner scanner = new Scanner(input);
+ Random rng = RandomUtils.getRandom();
+ ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+
+ while (scanner.hasNextLine()) {
+ String line = scanner.nextLine();
+ if (line.isEmpty()) {
+ continue; // skip empty lines
+ }
+
+ Instance instance = converter.convert(0, line);
+ int prediction = forest.classify(rng, instance);
+
+ analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0));
+ }
+
+ time = System.currentTimeMillis() - time;
+ log.info("Classification Time: {}", DFUtils.elapsedTime(time));
+
+ log.info(analyzer.summarize());
+ }
+
+ /**
+ * @param args
+ * @throws Exception
+ */
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new TestForest(), args);
+ }
+
+}
\ No newline at end of file