You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/19 08:51:57 UTC

svn commit: r998601 - in /mahout/trunk: core/src/main/java/org/apache/mahout/df/data/Dataset.java core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Author: tdunning
Date: Sun Sep 19 06:51:56 2010
New Revision: 998601

URL: http://svn.apache.org/viewvc?rev=998601&view=rev
Log:
MAHOUT-323 - Changes to allow a DF model to be saved and later used.

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
    mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java?rev=998601&r1=998600&r2=998601&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java Sun Sep 19 06:51:56 2010
@@ -165,6 +165,7 @@ public class Dataset implements Writable
   }
   
   public String getLabel(int code) {
+    // TODO: handle code == -1 (i.e. prediction == -1); labels[-1] throws ArrayIndexOutOfBoundsException
     return labels[code];
   }
   

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java?rev=998601&r1=998600&r2=998601&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java Sun Sep 19 06:51:56 2010
@@ -44,8 +44,11 @@ import org.apache.hadoop.io.LongWritable
 import org.apache.hadoop.io.SequenceFile;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Random;
 import java.net.URI;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
 
 /**
  * Mapreduce implementation that classifies the Input data using a previousely built decision forest
@@ -62,15 +65,35 @@ public class Classifier {
 
   private final Configuration conf;
 
+  /**
+   * If not null, the Job will build the confusionMatrix.
+   */
+  private final ResultAnalyzer analyzer;
+  private final Dataset dataset;
+
   private final Path outputPath; // path that will containt the final output of the classifier
   private final Path mappersOutputPath; // mappers will output here
 
-  public Classifier(Path forestPath, Path inputPath, Path datasetPath, Path outputPath, Configuration conf) {
+  public ResultAnalyzer getAnalyzer() {
+    return analyzer;
+  }
+
+  public Classifier(Path forestPath, Path inputPath, Path datasetPath, Path outputPath, Configuration conf, boolean analyze) throws IOException {
     this.forestPath = forestPath;
     this.inputPath = inputPath;
     this.datasetPath = datasetPath;
     this.outputPath = outputPath;
     this.conf = conf;
+
+    if (analyze) {
+      dataset = Dataset.load(conf, datasetPath);
+      analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+
+    } else {
+      dataset = null;
+      analyzer = null;
+    }
+
     mappersOutputPath = new Path(outputPath, "mappers");
   }
 
@@ -115,7 +138,6 @@ public class Classifier {
     log.info("Running the job...");
     if (!job.waitForCompletion(true)) {
       log.error("Job failed!");
-      log.error("Job failed!");
       return;
     }
 
@@ -125,8 +147,9 @@ public class Classifier {
   }
 
   /**
-   * Extract the prediction for each mapper and write them in the corresponding output file. The name of the output file
-   * is based on the name of the corresponding input file
+   * Extract the predictions of each mapper and write them to the corresponding output file.
+   * The name of the output file is based on the name of the corresponding input file.
+   * Also updates the analyzer (confusion matrix) with each prediction when the analyzer is not null.
    * @param job
    */
   private void parseOutput(Job job) throws IOException {
@@ -148,9 +171,15 @@ public class Classifier {
             // this is the first value, it contains the name of the input file
             ofile = fs.create(new Path(outputPath, value.toString()).suffix(".out"));
           } else {
-            // the value contains a prediction
+            // The key contains the correct label of the data. The value contains a prediction
             ofile.writeChars(value.toString()); // write the prediction
             ofile.writeChar('\n');
+
+            if (analyzer != null) {
+                analyzer.addInstance(
+                        dataset.getLabel((int)key.get()),
+                        new ClassifierResult(dataset.getLabel(Integer.parseInt(value.toString())), 1.0));
+            }
           }
         }
       } finally {
@@ -215,7 +244,7 @@ public class Classifier {
         FileSplit split = (FileSplit) context.getInputSplit();
         Path path = split.getPath(); // current split path
         lvalue.set(path.getName());
-        context.write(key, new Text(path.getName()));
+        context.write(key, lvalue);
 
         first = false;
       }
@@ -224,6 +253,7 @@ public class Classifier {
       if (!line.isEmpty()) {
         Instance instance = converter.convert(0, line);
         int prediction = forest.classify(rng, instance);
+        key.set(instance.getLabel());
         lvalue.set(Integer.toString(prediction));
         context.write(key, lvalue);
       }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=998601&r1=998600&r2=998601&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sun Sep 19 06:51:56 2010
@@ -175,17 +175,17 @@ public class TestForest extends Configur
   }
 
   private void mapreduce() throws ClassNotFoundException, IOException, InterruptedException {
-    if (analyze) {
-      log.warn("IMPORTANT: The current mapreduce implementation of TestForest does not support result analysis");
-    }
-
     if (outputPath == null) {
       throw new IllegalArgumentException("You must specify the ouputPath when using the mapreduce implementation");
     }
-    
-    Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf());
+
+    Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf(), analyze);
 
     classifier.run();
+
+    if (analyze) {
+      log.info(classifier.getAnalyzer().summarize());
+    }
   }
 
   private void sequential() throws IOException {
@@ -219,7 +219,7 @@ public class TestForest extends Configur
     time = System.currentTimeMillis() - time;
     log.info("Classification Time: {}", DFUtils.elapsedTime(time));
 
-    if (analyze) {
+    if (analyzer != null) {
       log.info(analyzer.summarize());
     }
   }
@@ -261,7 +261,7 @@ public class TestForest extends Configur
         ofile.writeChar('\n');
       }
 
-      if (analyze) {
+      if (analyzer != null) {
         analyzer.addInstance(dataset.getLabel(instance.getLabel()),
                              new ClassifierResult(dataset.getLabel(prediction), 1.0));
       }