You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/04/03 21:09:28 UTC

svn commit: r930563 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/df/DecisionForest.java core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Author: adeneche
Date: Sat Apr  3 19:09:28 2010
New Revision: 930563

URL: http://svn.apache.org/viewvc?rev=930563&view=rev
Log:
MAHOUT-323 Added a Basic Mapreduce version of TestForest

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java?rev=930563&r1=930562&r2=930563&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java Sat Apr  3 19:09:28 2010
@@ -30,6 +30,10 @@ import org.apache.mahout.df.data.DataUti
 import org.apache.mahout.df.data.Instance;
 import org.apache.mahout.df.node.Node;
 import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.conf.Configuration;
 
 /**
  * Represents a forest of decision trees.
@@ -195,5 +199,36 @@ public class DecisionForest implements W
     return forest;
   }
 
+  /**
+   * Load the forest from a single file or a directrory of files
+   * @param conf
+   * @param forestPath
+   * @return
+   * @throws IOException
+   */
+  public static DecisionForest load(Configuration conf, Path forestPath) throws IOException {
+    FileSystem fs = forestPath.getFileSystem(conf);
+    Path[] files = null;
+
+    if (fs.getFileStatus(forestPath).isDir())
+      files = DFUtils.listOutputFiles(fs, forestPath);
+    else
+      files = new Path[] {forestPath};
+
+    DecisionForest forest = null;
+    for (Path path : files) {
+      FSDataInputStream dataInput = new FSDataInputStream(fs.open(path));
+      if (forest == null) {
+        forest = DecisionForest.read(dataInput);
+      } else {
+        forest.readFields(dataInput);
+      }
+
+      dataInput.close();
+    }
+
+    return forest;
+    
+  }
 
 }

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java?rev=930563&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Classifier.java Sat Apr  3 19:09:28 2010
@@ -0,0 +1,239 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.data.DataConverter;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Instance;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+
+import java.io.IOException;
+import java.util.Random;
+import java.net.URI;
+
+/**
+ * Mapreduce implementation that classifies the input data using a previously built decision forest.
+ * <p>
+ * Each input file is processed by a single mapper (see {@link CTextInputFormat}). The predictions for an input file
+ * named {@code NAME} are written, one per line, to {@code outputPath/NAME.out}.
+ */
+public class Classifier {
+
+  private static final Logger log = LoggerFactory.getLogger(Classifier.class);
+
+  private final Path forestPath;
+
+  private final Path inputPath;
+
+  private final Path datasetPath;
+
+  private final Configuration conf;
+
+  /** path that will contain the final output of the classifier */
+  private final Path outputPath;
+
+  /** temporary directory where the mappers write their output; deleted after a successful run */
+  private final Path mappersOutputPath;
+
+  /**
+   * Configures a map-only job that reads {@link #inputPath} (one mapper per file) and writes the mappers' output as
+   * SequenceFiles of (LongWritable, Text) records into {@link #mappersOutputPath}.
+   */
+  private void configureJob(Job job) throws IOException {
+    job.setJarByClass(Classifier.class);
+
+    FileInputFormat.setInputPaths(job, inputPath);
+    FileOutputFormat.setOutputPath(job, mappersOutputPath);
+
+    job.setOutputKeyClass(LongWritable.class);
+    job.setOutputValueClass(Text.class);
+
+    job.setMapperClass(CMapper.class);
+    job.setNumReduceTasks(0); // no reducers
+
+    job.setInputFormatClass(CTextInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+  }
+
+  /**
+   * @param forestPath file or directory containing the decision forest
+   * @param inputPath data to classify
+   * @param datasetPath serialized {@link Dataset} describing the input data
+   * @param outputPath directory that will receive one {@code .out} file per input file; must not exist yet
+   * @param conf job configuration; {@link #run()} adds the dataset and the forest to its DistributedCache files
+   */
+  public Classifier(Path forestPath, Path inputPath, Path datasetPath, Path outputPath, Configuration conf) {
+    this.forestPath = forestPath;
+    this.inputPath = inputPath;
+    this.datasetPath = datasetPath;
+    this.outputPath = outputPath;
+    this.conf = conf;
+
+    mappersOutputPath = new Path(outputPath, "mappers");
+  }
+
+  /**
+   * Runs the classification job, then writes one {@code .out} file per input file into the output path. If the job
+   * fails, the failure is logged and the method returns without producing the final output.
+   *
+   * @throws IOException if the output path already exists, or on any file system error
+   */
+  public void run() throws IOException, ClassNotFoundException, InterruptedException {
+    // resolve the file system of the output path itself, which is not necessarily the default one
+    FileSystem fs = outputPath.getFileSystem(conf);
+
+    // check the output
+    if (fs.exists(outputPath)) {
+      throw new IOException("Output path already exists : " + outputPath);
+    }
+
+    // CMapper.setup() expects the dataset as the first cache file and the forest as the second one
+    log.info("Adding the dataset to the DistributedCache");
+    DistributedCache.addCacheFile(datasetPath.toUri(), conf);
+
+    log.info("Adding the decision forest to the DistributedCache");
+    DistributedCache.addCacheFile(forestPath.toUri(), conf);
+
+    Job job = new Job(conf, "decision forest classifier");
+
+    log.info("Configuring the job...");
+    configureJob(job);
+
+    log.info("Running the job...");
+    if (!job.waitForCompletion(true)) {
+      log.error("Job failed!");
+      return;
+    }
+
+    parseOutput(job);
+
+    // delete the temporary mappers output
+    fs.delete(mappersOutputPath, true);
+  }
+
+  /**
+   * Extracts the predictions of each mapper and writes them in the corresponding output file.
+   * <p>
+   * The first record of a mapper's output holds the name of the input file it processed; the output file is named
+   * after it, with a {@code .out} suffix. Every following record holds one prediction, written as a UTF-8 line.
+   * A mapper output that contains no record produces no output file.
+   *
+   * @param job the completed classification job
+   */
+  private void parseOutput(Job job) throws IOException {
+    Configuration conf = job.getConfiguration();
+    FileSystem fs = mappersOutputPath.getFileSystem(conf);
+
+    Path[] outfiles = DFUtils.listOutputFiles(fs, mappersOutputPath);
+
+    // read all the output
+    LongWritable key = new LongWritable();
+    Text value = new Text();
+    for (Path path : outfiles) {
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf);
+      FSDataOutputStream ofile = null;
+      try {
+        while (reader.next(key, value)) {
+          if (ofile == null) {
+            // this is the first value, it contains the name of the input file
+            ofile = fs.create(new Path(outputPath, value.toString()).suffix(".out"));
+          } else {
+            // the value contains a prediction; Text stores it as UTF-8, so write those bytes directly
+            // (writeChars() would emit two bytes per char and produce an unreadable text file)
+            ofile.write(value.getBytes(), 0, value.getLength());
+            ofile.write('\n');
+          }
+        }
+      } finally {
+        try {
+          reader.close();
+        } finally {
+          // ofile is still null when this mapper produced no record
+          if (ofile != null) {
+            ofile.close();
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * TextInputFormat that does not split the input files. This ensures that each input file is processed by one single
+   * mapper.
+   */
+  public static class CTextInputFormat extends TextInputFormat {
+
+    public CTextInputFormat() {
+      super();
+    }
+
+    @Override
+    protected boolean isSplitable(JobContext jobContext, Path path) {
+      return false;
+    }
+  }
+
+  /**
+   * Mapper that classifies every non-empty line of its input file. Its first output record holds the name of the
+   * input file; every following record holds the predicted label (as a decimal string), keyed by the input key.
+   */
+  public static class CMapper extends Mapper<LongWritable, Text, LongWritable, Text> {
+
+    /** used to convert input values to data instances */
+    private DataConverter converter;
+
+    private DecisionForest forest;
+
+    private final Random rng = RandomUtils.getRandom();
+
+    /** true until the name of the input file has been written */
+    private boolean first = true;
+
+    /** reusable output value */
+    private final Text lvalue = new Text();
+
+    @Override
+    protected void setup(Context context) throws IOException, InterruptedException {
+      super.setup(context);
+
+      Configuration conf = context.getConfiguration();
+
+      URI[] files = DistributedCache.getCacheFiles(conf);
+
+      if ((files == null) || (files.length < 2)) {
+        throw new IOException("not enough paths in the DistributedCache");
+      }
+
+      // the cache file order is set by Classifier.run(): the dataset first, then the forest
+      // NOTE(review): this breaks if the job configuration already carried other cache files
+      Dataset dataset = Dataset.load(conf, new Path(files[0].getPath()));
+
+      converter = new DataConverter(dataset);
+
+      forest = DecisionForest.load(conf, new Path(files[1].getPath()));
+      if (forest == null) {
+        throw new IOException("DecisionForest not found!");
+      }
+    }
+
+    @Override
+    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+      if (first) {
+        // CTextInputFormat never splits files, so this split is the whole input file
+        FileSplit split = (FileSplit) context.getInputSplit();
+        lvalue.set(split.getPath().getName());
+        context.write(key, lvalue);
+
+        first = false;
+      }
+
+      String line = value.toString();
+      if (line.isEmpty()) {
+        return;
+      }
+
+      Instance instance = converter.convert(0, line);
+      int prediction = forest.classify(rng, instance);
+      lvalue.set(Integer.toString(prediction));
+      context.write(key, lvalue);
+    }
+  }
+}

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=930563&r1=930562&r2=930563&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Apr  3 19:09:28 2010
@@ -69,6 +69,8 @@ public class TestForest extends Configur
 
   private boolean analyze; // analyze the classification results ?
 
+  private boolean useMapreduce; // use the mapreduce classifier ?
+
   @Override
   public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
 
@@ -93,11 +95,13 @@ public class TestForest extends Configur
 
     Option analyzeOpt = obuilder.withLongName("analyze").withShortName("a").withRequired(false).create();
 
+    Option mrOpt = obuilder.withLongName("mapreduce").withShortName("mr").withRequired(false).create();
+
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
         .create();
 
-    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt)
-        .withOption(modelOpt).withOption(outputOpt).withOption(analyzeOpt).withOption(helpOpt).create();
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt).withOption(modelOpt)
+        .withOption(outputOpt).withOption(analyzeOpt).withOption(mrOpt).withOption(helpOpt).create();
 
     try {
       Parser parser = new Parser();
@@ -114,12 +118,14 @@ public class TestForest extends Configur
       String modelName = cmdLine.getValue(modelOpt).toString();
       String outputName = (cmdLine.hasOption(outputOpt)) ? cmdLine.getValue(outputOpt).toString() : null;
       analyze = cmdLine.hasOption(analyzeOpt);
+      useMapreduce = cmdLine.hasOption(mrOpt);
 
-      log.debug("inout   : {}", dataName);
-      log.debug("dataset : {}", datasetName);
-      log.debug("model   : {}", modelName);
-      log.debug("output  : {}", outputName);
-      log.debug("analyze : {}", analyze);
+      log.debug("inout     : {}", dataName);
+      log.debug("dataset   : {}", datasetName);
+      log.debug("model     : {}", modelName);
+      log.debug("output    : {}", outputName);
+      log.debug("analyze   : {}", analyze);
+      log.debug("mapreduce : {}", useMapreduce);
 
       dataPath = new Path(dataName);
       datasetPath = new Path(datasetName);
@@ -160,29 +166,42 @@ public class TestForest extends Configur
       throw new IllegalArgumentException("The Test data path does not exist");
     }
 
-    // load the dataset
-    Dataset dataset = Dataset.load(getConf(), datasetPath);
-    DataConverter converter = new DataConverter(dataset);
+    if (useMapreduce) {
+      mapreduce();
+    } else {
+      sequential();
+    }
 
-    log.info("Loading the forest...");
-    Path[] modelfiles = DFUtils.listOutputFiles(mfs, modelPath);
-    DecisionForest forest = null;
-    for (Path path : modelfiles) {
-      FSDataInputStream dataInput = new FSDataInputStream(mfs.open(path));
-      if (forest == null) {
-        forest = DecisionForest.read(dataInput);
-      } else {
-        forest.readFields(dataInput);
-      }
+  }
+
+  /**
+   * Classifies the test data with the mapreduce {@link Classifier}. An output path is mandatory; result analysis is
+   * not supported and only triggers a warning.
+   *
+   * @throws IllegalArgumentException if no output path was given
+   */
+  private void mapreduce() throws ClassNotFoundException, IOException, InterruptedException {
+    if (analyze) {
+      log.warn("IMPORTANT: The current mapreduce implementation of TestForest does not support result analysis");
+    }
+
+    if (outputPath == null) {
+      throw new IllegalArgumentException("You must specify the output path when using the mapreduce implementation");
+    }
+
+    Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf());
 
-      dataInput.close();
-    }    
+    classifier.run();
+  }
+
+  private void sequential() throws IOException {
+
+    log.info("Loading the forest...");
+    DecisionForest forest = DecisionForest.load(getConf(), modelPath);
 
     if (forest == null) {
       log.error("No Decision Forest found!");
       return;
     }
 
+    // load the dataset
+    Dataset dataset = Dataset.load(getConf(), datasetPath);
+    DataConverter converter = new DataConverter(dataset);
+
     log.info("Sequential classification...");
     long time = System.currentTimeMillis();