You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/03/13 17:25:13 UTC

svn commit: r922594 - /lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Author: adeneche
Date: Sat Mar 13 16:25:13 2010
New Revision: 922594

URL: http://svn.apache.org/viewvc?rev=922594&view=rev
Log:
MAHOUT-323: TestForest can store the predictions in a file

Modified:
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=922594&r1=922593&r2=922594&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Mar 13 16:25:13 2010
@@ -35,6 +35,7 @@ import org.apache.hadoop.conf.Configured
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
 import org.apache.hadoop.util.Tool;
 import org.apache.hadoop.util.ToolRunner;
 import org.apache.mahout.common.CommandLineUtil;
@@ -65,6 +66,10 @@ public class TestForest extends Configur
 
   private Path modelPath; // path where the forest is stored
 
+  private Path outputPath; // path to predictions file, if null do not output the predictions
+
+  private boolean analyze; // analyze the classification results ?
+
   @Override
   public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
 
@@ -83,11 +88,17 @@ public class TestForest extends Configur
         abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
         withDescription("Path to the Decision Forest").create();
 
+    Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(false).withArgument(
+      abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+      "Path to generated predictions file").create();
+
+    Option analyzeOpt = obuilder.withLongName("analyze").withShortName("a").withRequired(false).create();
+
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
         .create();
 
     Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt)
-        .withOption(modelOpt).withOption(helpOpt).create();
+        .withOption(modelOpt).withOption(outputOpt).withOption(analyzeOpt).withOption(helpOpt).create();
 
     try {
       Parser parser = new Parser();
@@ -102,15 +113,21 @@ public class TestForest extends Configur
       String dataName = cmdLine.getValue(inputOpt).toString();
       String datasetName = cmdLine.getValue(datasetOpt).toString();
       String modelName = cmdLine.getValue(modelOpt).toString();
+      String outputName = (cmdLine.hasOption(outputOpt)) ? cmdLine.getValue(outputOpt).toString() : null;
+      analyze = cmdLine.hasOption(analyzeOpt);
 
-      log.debug("inout : {}", dataName);
+      log.debug("inout   : {}", dataName);
       log.debug("dataset : {}", datasetName);
-      log.debug("model : {}", modelName);
+      log.debug("model   : {}", modelName);
+      log.debug("output  : {}", outputName);
+      log.debug("analyze : {}", analyze);
 
       dataPath = new Path(dataName);
       datasetPath = new Path(datasetName);
       modelPath = new Path(modelName);
-
+      if (outputName != null) {
+        outputPath = new Path(outputName);
+      }
     } catch (OptionException e) {
       System.err.println("Exception : " + e);
       CommandLineUtil.printHelp(group);
@@ -123,6 +140,17 @@ public class TestForest extends Configur
   }
 
   private void testForest() throws IOException, ClassNotFoundException, InterruptedException {
+
+    FileSystem ofs = null;
+
+    // make sure the output file does not exist
+    if (outputPath != null) {
+      ofs = outputPath.getFileSystem(getConf());
+      if (ofs.exists(outputPath)) {
+        throw new IllegalArgumentException("Output path already exists");
+      }
+    }
+
     Dataset dataset = Dataset.load(getConf(), datasetPath);
     DataConverter converter = new DataConverter(dataset);
 
@@ -146,6 +174,9 @@ public class TestForest extends Configur
       return;
     }
 
+    // create the predictions file
+    FSDataOutputStream ofile = (outputPath != null) ? ofs.create(outputPath) : null;
+
     log.info("Sequential classification...");
     long time = System.currentTimeMillis();
 
@@ -153,7 +184,7 @@ public class TestForest extends Configur
     FSDataInputStream input = tfs.open(dataPath);
     Scanner scanner = new Scanner(input);
     Random rng = RandomUtils.getRandom();
-    ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+    ResultAnalyzer analyzer = (analyze) ? new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown") : null;
 
     while (scanner.hasNextLine()) {
       String line = scanner.nextLine();
@@ -164,13 +195,22 @@ public class TestForest extends Configur
       Instance instance = converter.convert(0, line);
       int prediction = forest.classify(rng, instance);
 
-      analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0));
+      if (outputPath != null) {
+        ofile.writeChars(Integer.toString(prediction)); // write the prediction
+        ofile.writeChar('\n');
+      }
+      
+      if (analyze) {
+        analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0));
+      }
     }
 
     time = System.currentTimeMillis() - time;
     log.info("Classification Time: {}", DFUtils.elapsedTime(time));
 
-    log.info(analyzer.summarize());
+    if (analyze) {
+      log.info(analyzer.summarize());
+    }
   }
 
   /**