You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/04/03 21:09:28 UTC
svn commit: r930563 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/df/DecisionForest.java
core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Author: adeneche
Date: Sat Apr 3 19:09:28 2010
New Revision: 930563
URL: http://svn.apache.org/viewvc?rev=930563&view=rev
Log:
MAHOUT-323 Added a Basic Mapreduce version of TestForest
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java?rev=930563&r1=930562&r2=930563&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java Sat Apr 3 19:09:28 2010
@@ -30,6 +30,10 @@ import org.apache.mahout.df.data.DataUti
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.node.Node;
import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.conf.Configuration;
/**
* Represents a forest of decision trees.
@@ -195,5 +199,36 @@ public class DecisionForest implements W
return forest;
}
+  /**
+   * Loads the forest from a single file or from the output files of a directory.
+   *
+   * @param conf configuration used to resolve the FileSystem of {@code forestPath}
+   * @param forestPath a forest file, or a directory whose output files are read one after the other
+   * @return the loaded forest, or {@code null} if no file was read
+   * @throws IOException if a file cannot be opened or read
+   */
+  public static DecisionForest load(Configuration conf, Path forestPath) throws IOException {
+    FileSystem fs = forestPath.getFileSystem(conf);
+    Path[] files = fs.getFileStatus(forestPath).isDir()
+        ? DFUtils.listOutputFiles(fs, forestPath)
+        : new Path[] {forestPath};
+
+    // the first file creates the forest, the following ones are read into it with readFields()
+    DecisionForest forest = null;
+    for (Path path : files) {
+      FSDataInputStream dataInput = fs.open(path);
+      try {
+        if (forest == null) {
+          forest = DecisionForest.read(dataInput);
+        } else {
+          forest.readFields(dataInput);
+        }
+      } finally {
+        dataInput.close(); // also closed when reading fails
+      }
+    }
+
+    return forest;
+  }
}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java?rev=930563&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java Sat Apr 3 19:09:28 2010
@@ -0,0 +1,273 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.data.DataConverter;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Instance;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+
+import java.io.IOException;
+import java.util.Random;
+import java.net.URI;
+
+/**
+ * Mapreduce implementation that classifies the input data using a previously built decision forest.
+ * <p>
+ * Each input file is classified by a single mapper (see {@link CTextInputFormat}); the predictions for an input file
+ * named {@code X} are written, one per line, to {@code outputPath/X.out}.
+ */
+public class Classifier {
+
+  private static final Logger log = LoggerFactory.getLogger(Classifier.class);
+
+  private final Path forestPath;
+
+  private final Path inputPath;
+
+  private final Path datasetPath;
+
+  private final Configuration conf;
+
+  /** path that will contain the final output of the classifier */
+  private final Path outputPath;
+
+  /** mappers will output here; deleted by {@link #run()} once the .out files have been written */
+  private final Path mappersOutputPath;
+
+  /**
+   * Creates the classifier; nothing is read or written until {@link #run()} is called.
+   *
+   * @param forestPath single file or directory containing the decision forest
+   * @param inputPath input data to classify
+   * @param datasetPath dataset description, used to convert each input line into an {@link Instance}
+   * @param outputPath directory that will receive one {@code .out} file per input file; must not exist yet
+   * @param conf Hadoop configuration; {@link #run()} registers the dataset and the forest in its DistributedCache
+   */
+  public Classifier(Path forestPath, Path inputPath, Path datasetPath, Path outputPath, Configuration conf) {
+    this.forestPath = forestPath;
+    this.inputPath = inputPath;
+    this.datasetPath = datasetPath;
+    this.outputPath = outputPath;
+    this.conf = conf;
+
+    mappersOutputPath = new Path(outputPath, "mappers");
+  }
+
+  /**
+   * Configures a map-only job that reads every input file as a single split and stores the mappers' output as
+   * SequenceFiles of (LongWritable, Text) pairs in {@link #mappersOutputPath}.
+   */
+  private void configureJob(Job job) throws IOException {
+    job.setJarByClass(Classifier.class);
+
+    FileInputFormat.setInputPaths(job, inputPath);
+    FileOutputFormat.setOutputPath(job, mappersOutputPath);
+
+    job.setOutputKeyClass(LongWritable.class);
+    job.setOutputValueClass(Text.class);
+
+    job.setMapperClass(CMapper.class);
+    job.setNumReduceTasks(0); // no reducers
+
+    job.setInputFormatClass(CTextInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+  }
+
+  /**
+   * Runs the classification job, writes one {@code .out} file per input file into the output path, then deletes the
+   * intermediate mapper output. If the job fails, an error is logged and the method returns without writing any .out file.
+   *
+   * @throws IOException if the output path already exists, or on any filesystem error
+   */
+  public void run() throws IOException, ClassNotFoundException, InterruptedException {
+    // resolve the FileSystem from the output path itself, it is not necessarily the default FileSystem
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    // check the output
+    if (fs.exists(outputPath)) {
+      throw new IOException("Output path already exists : " + outputPath);
+    }
+
+    // CMapper.setup() expects the dataset to be the first cache file and the forest the second one
+    log.info("Adding the dataset to the DistributedCache");
+    DistributedCache.addCacheFile(datasetPath.toUri(), conf);
+
+    log.info("Adding the decision forest to the DistributedCache");
+    DistributedCache.addCacheFile(forestPath.toUri(), conf);
+
+    Job job = new Job(conf, "decision forest classifier");
+
+    log.info("Configuring the job...");
+    configureJob(job);
+
+    log.info("Running the job...");
+    if (!job.waitForCompletion(true)) {
+      log.error("Job failed!");
+      return;
+    }
+
+    parseOutput(job);
+
+    // delete the intermediate mapper output, only the .out files are kept
+    fs.delete(mappersOutputPath, true);
+  }
+
+  /**
+   * Extracts the predictions of each mapper and writes them to the corresponding output file. The first record of a
+   * mapper output holds the name of the input file it classified; the output file is named after it, with a
+   * {@code .out} suffix. A mapper output without any record (empty input file) produces no output file.
+   *
+   * @param job the completed classification job
+   */
+  private void parseOutput(Job job) throws IOException {
+    Configuration conf = job.getConfiguration();
+    FileSystem fs = mappersOutputPath.getFileSystem(conf);
+
+    Path[] outfiles = DFUtils.listOutputFiles(fs, mappersOutputPath);
+
+    // read all the output
+    LongWritable key = new LongWritable();
+    Text value = new Text();
+    for (Path path : outfiles) {
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+      FSDataOutputStream ofile = null;
+      try {
+        while (reader.next(key, value)) {
+          if (ofile == null) {
+            // this is the first value, it contains the name of the input file
+            ofile = fs.create(new Path(outputPath, value.toString()).suffix(".out"));
+          } else {
+            // the value contains a prediction: write its UTF-8 bytes (writeChars() would emit 2 bytes per char)
+            ofile.write(value.getBytes(), 0, value.getLength());
+            ofile.write('\n');
+          }
+        }
+      } finally {
+        try {
+          reader.close();
+        } finally {
+          // ofile is still null when the mapper output had no record
+          if (ofile != null) {
+            ofile.close();
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * TextInputFormat that does not split the input files. This ensures that each input file is processed by one single
+   * mapper.
+   */
+  public static class CTextInputFormat extends TextInputFormat {
+
+    @Override
+    protected boolean isSplitable(JobContext jobContext, Path path) {
+      return false;
+    }
+  }
+
+  /**
+   * Classifies each non-empty line of its input file. Before the first prediction it emits the name of the input
+   * file, which the driver uses to name the corresponding {@code .out} file.
+   */
+  public static class CMapper extends Mapper<LongWritable, Text, LongWritable, Text> {
+
+    /** used to convert input values to data instances */
+    private DataConverter converter;
+
+    private DecisionForest forest;
+
+    private final Random rng = RandomUtils.getRandom();
+
+    /** true until the name of the input file has been emitted */
+    private boolean first = true;
+
+    /** output value, reused for every record */
+    private final Text lvalue = new Text();
+
+    @Override
+    protected void setup(Context context) throws IOException, InterruptedException {
+      super.setup(context);
+
+      Configuration conf = context.getConfiguration();
+
+      // Classifier.run() registers the dataset first and the forest second.
+      // NOTE(review): only the path component of each cached URI is used (not a localized copy), so the files are
+      // read through whatever FileSystem that scheme-less path resolves to -- confirm this is intended.
+      URI[] files = DistributedCache.getCacheFiles(conf);
+
+      if ((files == null) || (files.length < 2)) {
+        throw new IOException("not enough paths in the DistributedCache");
+      }
+
+      Dataset dataset = Dataset.load(conf, new Path(files[0].getPath()));
+
+      converter = new DataConverter(dataset);
+
+      forest = DecisionForest.load(conf, new Path(files[1].getPath()));
+      if (forest == null) {
+        // a missing forest is an I/O problem, not a thread interruption
+        throw new IOException("DecisionForest not found!");
+      }
+    }
+
+    @Override
+    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+      if (first) {
+        // emit the name of the current input file; CTextInputFormat puts the whole file in this split
+        FileSplit split = (FileSplit) context.getInputSplit();
+        lvalue.set(split.getPath().getName());
+        context.write(key, lvalue);
+
+        first = false;
+      }
+
+      String line = value.toString();
+      if (line.isEmpty()) {
+        return;
+      }
+
+      Instance instance = converter.convert(0, line);
+      int prediction = forest.classify(rng, instance);
+      lvalue.set(Integer.toString(prediction));
+      context.write(key, lvalue);
+    }
+  }
+}
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=930563&r1=930562&r2=930563&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Apr 3 19:09:28 2010
@@ -69,6 +69,8 @@ public class TestForest extends Configur
private boolean analyze; // analyze the classification results ?
+ private boolean useMapreduce; // use the mapreduce classifier ?
+
@Override
public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
@@ -93,11 +95,13 @@ public class TestForest extends Configur
Option analyzeOpt = obuilder.withLongName("analyze").withShortName("a").withRequired(false).create();
+ Option mrOpt = obuilder.withLongName("mapreduce").withShortName("mr").withRequired(false).create();
+
Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
.create();
- Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt)
- .withOption(modelOpt).withOption(outputOpt).withOption(analyzeOpt).withOption(helpOpt).create();
+ Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt).withOption(modelOpt)
+ .withOption(outputOpt).withOption(analyzeOpt).withOption(mrOpt).withOption(helpOpt).create();
try {
Parser parser = new Parser();
@@ -114,12 +118,14 @@ public class TestForest extends Configur
String modelName = cmdLine.getValue(modelOpt).toString();
String outputName = (cmdLine.hasOption(outputOpt)) ? cmdLine.getValue(outputOpt).toString() : null;
analyze = cmdLine.hasOption(analyzeOpt);
+ useMapreduce = cmdLine.hasOption(mrOpt);
- log.debug("inout : {}", dataName);
- log.debug("dataset : {}", datasetName);
- log.debug("model : {}", modelName);
- log.debug("output : {}", outputName);
- log.debug("analyze : {}", analyze);
+ log.debug("inout : {}", dataName);
+ log.debug("dataset : {}", datasetName);
+ log.debug("model : {}", modelName);
+ log.debug("output : {}", outputName);
+ log.debug("analyze : {}", analyze);
+ log.debug("mapreduce : {}", useMapreduce);
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
@@ -160,29 +166,42 @@ public class TestForest extends Configur
throw new IllegalArgumentException("The Test data path does not exist");
}
- // load the dataset
- Dataset dataset = Dataset.load(getConf(), datasetPath);
- DataConverter converter = new DataConverter(dataset);
+ if (useMapreduce) {
+ mapreduce();
+ } else {
+ sequential();
+ }
- log.info("Loading the forest...");
- Path[] modelfiles = DFUtils.listOutputFiles(mfs, modelPath);
- DecisionForest forest = null;
- for (Path path : modelfiles) {
- FSDataInputStream dataInput = new FSDataInputStream(mfs.open(path));
- if (forest == null) {
- forest = DecisionForest.read(dataInput);
- } else {
- forest.readFields(dataInput);
- }
+ }
+
+  private void mapreduce() throws ClassNotFoundException, IOException, InterruptedException {
+    // classifies the test data with the mapreduce Classifier, which writes its predictions under outputPath
+    if (analyze) {
+      log.warn("IMPORTANT: The current mapreduce implementation of TestForest does not support result analysis");
+    }
+    // the Classifier needs an output path to write its .out files to
+    if (outputPath == null) {
+      throw new IllegalArgumentException("You must specify the outputPath when using the mapreduce implementation");
+    }
+    Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf());
- dataInput.close();
- }
+    classifier.run();
+  }
+
+ private void sequential() throws IOException {
+
+ log.info("Loading the forest...");
+ DecisionForest forest = DecisionForest.load(getConf(), modelPath);
if (forest == null) {
log.error("No Decision Forest found!");
return;
}
+ // load the dataset
+ Dataset dataset = Dataset.load(getConf(), datasetPath);
+ DataConverter converter = new DataConverter(dataset);
+
log.info("Sequential classification...");
long time = System.currentTimeMillis();