You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/03/27 07:34:42 UTC

svn commit: r928162 - /lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Author: adeneche
Date: Sat Mar 27 06:34:42 2010
New Revision: 928162

URL: http://svn.apache.org/viewvc?rev=928162&view=rev
Log:
MAHOUT-323 added the ability to classify a directory of files

Modified:
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=928162&r1=928161&r2=928162&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Mar 27 06:34:42 2010
@@ -60,12 +60,14 @@ public class TestForest extends Configur
 
   private static final Logger log = LoggerFactory.getLogger(TestForest.class);
 
+  private FileSystem dataFS;
   private Path dataPath; // test data path
 
   private Path datasetPath;
 
   private Path modelPath; // path where the forest is stored
 
+  private FileSystem outFS;
   private Path outputPath; // path to predictions file, if null do not output the predictions
 
   private boolean analyze; // analyze the classification results ?
@@ -141,12 +143,10 @@ public class TestForest extends Configur
 
   private void testForest() throws IOException, ClassNotFoundException, InterruptedException {
 
-    FileSystem ofs = null;
-
     // make sure the output file does not exist
     if (outputPath != null) {
-      ofs = outputPath.getFileSystem(getConf());
-      if (ofs.exists(outputPath)) {
+      outFS = outputPath.getFileSystem(getConf());
+      if (outFS.exists(outputPath)) {
         throw new IllegalArgumentException("Output path already exists");
       }
     }
@@ -157,6 +157,12 @@ public class TestForest extends Configur
       throw new IllegalArgumentException("The forest path does not exist");
     }
 
+    // make sure the test data exists
+    dataFS = dataPath.getFileSystem(getConf());
+    if (!dataFS.exists(dataPath)) {
+      throw new IllegalArgumentException("The Test data path does not exist");
+    }
+
     // load the dataset
     Dataset dataset = Dataset.load(getConf(), datasetPath);
     DataConverter converter = new DataConverter(dataset);
@@ -180,18 +186,51 @@ public class TestForest extends Configur
       return;
     }
 
-    // create the predictions file
-    FSDataOutputStream ofile = (outputPath != null) ? ofs.create(outputPath) : null;
-
     log.info("Sequential classification...");
     long time = System.currentTimeMillis();
 
-    FileSystem tfs = dataPath.getFileSystem(getConf());
-    FSDataInputStream input = tfs.open(dataPath);
-    Scanner scanner = new Scanner(input);
     Random rng = RandomUtils.getRandom();
     ResultAnalyzer analyzer = (analyze) ? new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown") : null;
 
+    if (dataFS.getFileStatus(dataPath).isDir()) {
+      //the input is a directory of files
+      testDirectory(dataPath, outputPath, converter, forest, dataset, analyzer, rng);
+    }  else {
+      // the input is one single file
+      testFile(dataPath, outputPath, converter, forest, dataset, analyzer, rng);
+    }
+
+    time = System.currentTimeMillis() - time;
+    log.info("Classification Time: {}", DFUtils.elapsedTime(time));
+
+    if (analyze) {
+      log.info(analyzer.summarize());
+    }
+  }
+
+  /** Runs testFile on each path DFUtils.listOutputFiles returns for inPath; if outPath != null, predictions go to outPath/NAME.out. */
+  private void testDirectory(Path inPath, Path outPath, DataConverter converter, DecisionForest forest, Dataset dataset,
+                        ResultAnalyzer analyzer, Random rng) throws IOException {
+    // list the directory that was passed in (inPath) rather than the dataPath field, so the parameter is honoured
+    for (Path path : DFUtils.listOutputFiles(dataFS, inPath)) {
+      log.info("Classifying : {}", path);
+      Path outfile = (outPath != null) ? new Path(outPath, path.getName()).suffix(".out") : null;
+      testFile(path, outfile, converter, forest, dataset, analyzer, rng);
+    }
+  }
+
+  private void testFile(Path inPath, Path outPath, DataConverter converter, DecisionForest forest, Dataset dataset,
+                        ResultAnalyzer analyzer, Random rng) throws IOException {
+    // create the predictions file
+    FSDataOutputStream ofile = null;
+
+    if (outPath != null) {
+      ofile = outFS.create(outPath);
+    }
+
+    FSDataInputStream input = dataFS.open(inPath);
+    Scanner scanner = new Scanner(input);
+
     while (scanner.hasNextLine()) {
       String line = scanner.nextLine();
       if (line.isEmpty()) {
@@ -205,18 +244,14 @@ public class TestForest extends Configur
         ofile.writeChars(Integer.toString(prediction)); // write the prediction
         ofile.writeChar('\n');
       }
-      
+
       if (analyze) {
         analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0));
       }
     }
 
-    time = System.currentTimeMillis() - time;
-    log.info("Classification Time: {}", DFUtils.elapsedTime(time));
-
-    if (analyze) {
-      log.info(analyzer.summarize());
-    }
+    scanner.close();
+    input.close();
   }
 
   /**