You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/03/13 17:25:13 UTC
svn commit: r922594 -
/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Author: adeneche
Date: Sat Mar 13 16:25:13 2010
New Revision: 922594
URL: http://svn.apache.org/viewvc?rev=922594&view=rev
Log:
MAHOUT-323: TestForest can store the predictions in a file
Modified:
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=922594&r1=922593&r2=922594&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Mar 13 16:25:13 2010
@@ -35,6 +35,7 @@ import org.apache.hadoop.conf.Configured
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.CommandLineUtil;
@@ -65,6 +66,10 @@ public class TestForest extends Configur
private Path modelPath; // path where the forest is stored
+ private Path outputPath; // path to the predictions file; if null, the predictions are not written
+
+ private boolean analyze; // whether to analyze the classification results
+
@Override
public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
@@ -83,11 +88,17 @@ public class TestForest extends Configur
abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
withDescription("Path to the Decision Forest").create();
+ Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(false).withArgument(
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Path to generated predictions file").create();
+
+ Option analyzeOpt = obuilder.withLongName("analyze").withShortName("a").withRequired(false).create();
+
Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
.create();
Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt)
- .withOption(modelOpt).withOption(helpOpt).create();
+ .withOption(modelOpt).withOption(outputOpt).withOption(analyzeOpt).withOption(helpOpt).create();
try {
Parser parser = new Parser();
@@ -102,15 +113,21 @@ public class TestForest extends Configur
String dataName = cmdLine.getValue(inputOpt).toString();
String datasetName = cmdLine.getValue(datasetOpt).toString();
String modelName = cmdLine.getValue(modelOpt).toString();
+ String outputName = (cmdLine.hasOption(outputOpt)) ? cmdLine.getValue(outputOpt).toString() : null;
+ analyze = cmdLine.hasOption(analyzeOpt);
- log.debug("inout : {}", dataName);
+ log.debug("inout : {}", dataName);
log.debug("dataset : {}", datasetName);
- log.debug("model : {}", modelName);
+ log.debug("model : {}", modelName);
+ log.debug("output : {}", outputName);
+ log.debug("analyze : {}", analyze);
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
modelPath = new Path(modelName);
-
+ if (outputName != null) {
+ outputPath = new Path(outputName);
+ }
} catch (OptionException e) {
System.err.println("Exception : " + e);
CommandLineUtil.printHelp(group);
@@ -123,6 +140,17 @@ public class TestForest extends Configur
}
private void testForest() throws IOException, ClassNotFoundException, InterruptedException {
+
+ FileSystem ofs = null;
+
+ // make sure the output file does not exist
+ if (outputPath != null) {
+ ofs = outputPath.getFileSystem(getConf());
+ if (ofs.exists(outputPath)) {
+ throw new IllegalArgumentException("Output path already exists");
+ }
+ }
+
Dataset dataset = Dataset.load(getConf(), datasetPath);
DataConverter converter = new DataConverter(dataset);
@@ -146,6 +174,9 @@ public class TestForest extends Configur
return;
}
+ // create the predictions file
+ FSDataOutputStream ofile = (outputPath != null) ? ofs.create(outputPath) : null;
+
log.info("Sequential classification...");
long time = System.currentTimeMillis();
@@ -153,7 +184,7 @@ public class TestForest extends Configur
FSDataInputStream input = tfs.open(dataPath);
Scanner scanner = new Scanner(input);
Random rng = RandomUtils.getRandom();
- ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+ ResultAnalyzer analyzer = (analyze) ? new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown") : null;
while (scanner.hasNextLine()) {
String line = scanner.nextLine();
@@ -164,13 +195,22 @@ public class TestForest extends Configur
Instance instance = converter.convert(0, line);
int prediction = forest.classify(rng, instance);
- analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0));
+ if (outputPath != null) {
+ ofile.writeChars(Integer.toString(prediction)); // write the prediction
+ ofile.writeChar('\n');
+ }
+
+ if (analyze) {
+ analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0));
+ }
}
time = System.currentTimeMillis() - time;
log.info("Classification Time: {}", DFUtils.elapsedTime(time));
- log.info(analyzer.summarize());
+ if (analyze) {
+ log.info(analyzer.summarize());
+ }
}
/**