You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by td...@apache.org on 2010/09/19 08:51:57 UTC
svn commit: r998601 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/df/data/Dataset.java
core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Author: tdunning
Date: Sun Sep 19 06:51:56 2010
New Revision: 998601
URL: http://svn.apache.org/viewvc?rev=998601&view=rev
Log:
MAHOUT-323 - Changes to allow a DF model to be saved and later used.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java?rev=998601&r1=998600&r2=998601&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/data/Dataset.java Sun Sep 19 06:51:56 2010
@@ -165,6 +165,7 @@ public class Dataset implements Writable
}
public String getLabel(int code) {
+ // TODO should handle the case (prediction == -1)
return labels[code];
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java?rev=998601&r1=998600&r2=998601&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java Sun Sep 19 06:51:56 2010
@@ -44,8 +44,11 @@ import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.SequenceFile;
import java.io.IOException;
+import java.util.Arrays;
import java.util.Random;
import java.net.URI;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
/**
* Mapreduce implementation that classifies the Input data using a previously built decision forest
@@ -62,15 +65,35 @@ public class Classifier {
private final Configuration conf;
+ /**
+ * If not null, the Job will build the confusionMatrix.
+ */
+ private final ResultAnalyzer analyzer;
+ private final Dataset dataset;
+
private final Path outputPath; // path that will contain the final output of the classifier
private final Path mappersOutputPath; // mappers will output here
- public Classifier(Path forestPath, Path inputPath, Path datasetPath, Path outputPath, Configuration conf) {
+ public ResultAnalyzer getAnalyzer() {
+ return analyzer;
+ }
+
+ public Classifier(Path forestPath, Path inputPath, Path datasetPath, Path outputPath, Configuration conf, boolean analyze) throws IOException {
this.forestPath = forestPath;
this.inputPath = inputPath;
this.datasetPath = datasetPath;
this.outputPath = outputPath;
this.conf = conf;
+
+ if (analyze) {
+ dataset = Dataset.load(conf, datasetPath);
+ analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+
+ } else {
+ dataset = null;
+ analyzer = null;
+ }
+
mappersOutputPath = new Path(outputPath, "mappers");
}
@@ -115,7 +138,6 @@ public class Classifier {
log.info("Running the job...");
if (!job.waitForCompletion(true)) {
log.error("Job failed!");
- log.error("Job failed!");
return;
}
@@ -125,8 +147,9 @@ public class Classifier {
}
/**
- * Extract the prediction for each mapper and write them in the corresponding output file. The name of the output file
- * is based on the name of the corresponding input file
+ * Extract the prediction for each mapper and write them in the corresponding output file.
+ * The name of the output file is based on the name of the corresponding input file.
+ * Will compute the ConfusionMatrix if necessary.
* @param job
*/
private void parseOutput(Job job) throws IOException {
@@ -148,9 +171,15 @@ public class Classifier {
// this is the first value, it contains the name of the input file
ofile = fs.create(new Path(outputPath, value.toString()).suffix(".out"));
} else {
- // the value contains a prediction
+ // The key contains the correct label of the data. The value contains a prediction
ofile.writeChars(value.toString()); // write the prediction
ofile.writeChar('\n');
+
+ if (analyzer != null) {
+ analyzer.addInstance(
+ dataset.getLabel((int)key.get()),
+ new ClassifierResult(dataset.getLabel(Integer.parseInt(value.toString())), 1.0));
+ }
}
}
} finally {
@@ -215,7 +244,7 @@ public class Classifier {
FileSplit split = (FileSplit) context.getInputSplit();
Path path = split.getPath(); // current split path
lvalue.set(path.getName());
- context.write(key, new Text(path.getName()));
+ context.write(key, lvalue);
first = false;
}
@@ -224,6 +253,7 @@ public class Classifier {
if (!line.isEmpty()) {
Instance instance = converter.convert(0, line);
int prediction = forest.classify(rng, instance);
+ key.set(instance.getLabel());
lvalue.set(Integer.toString(prediction));
context.write(key, lvalue);
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=998601&r1=998600&r2=998601&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sun Sep 19 06:51:56 2010
@@ -175,17 +175,17 @@ public class TestForest extends Configur
}
private void mapreduce() throws ClassNotFoundException, IOException, InterruptedException {
- if (analyze) {
- log.warn("IMPORTANT: The current mapreduce implementation of TestForest does not support result analysis");
- }
-
if (outputPath == null) {
throw new IllegalArgumentException("You must specify the ouputPath when using the mapreduce implementation");
}
-
- Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf());
+
+ Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf(), analyze);
classifier.run();
+
+ if (analyze) {
+ log.info(classifier.getAnalyzer().summarize());
+ }
}
private void sequential() throws IOException {
@@ -219,7 +219,7 @@ public class TestForest extends Configur
time = System.currentTimeMillis() - time;
log.info("Classification Time: {}", DFUtils.elapsedTime(time));
- if (analyze) {
+ if (analyzer != null) {
log.info(analyzer.summarize());
}
}
@@ -261,7 +261,7 @@ public class TestForest extends Configur
ofile.writeChar('\n');
}
- if (analyze) {
+ if (analyzer != null) {
analyzer.addInstance(dataset.getLabel(instance.getLabel()),
new ClassifierResult(dataset.getLabel(prediction), 1.0));
}