You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ad...@apache.org on 2010/03/13 11:54:15 UTC

svn commit: r922523 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/df/ core/src/main/java/org/apache/mahout/df/split/ core/src/main/java/org/apache/mahout/df/tools/ examples/src/main/java/org/apache/mahout/df/mapreduce/

Author: adeneche
Date: Sat Mar 13 10:54:14 2010
New Revision: 922523

URL: http://svn.apache.org/viewvc?rev=922523&view=rev
Log:
MAHOUT-323

Added:
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DFUtils.java Sat Mar 13 10:54:14 2010
@@ -23,6 +23,9 @@ import java.io.IOException;
 
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Writable;
 import org.apache.mahout.df.node.Node;
 import org.apache.mahout.ga.watchmaker.OutputUtils;
 
@@ -162,5 +165,24 @@ public class DFUtils {
     
     return hours + "h " + minutes + "m " + seconds + "s " + milli;
   }
-  
+
+  /**
+   * Serializes a {@link Writable} into a file created at the given path.
+   * The output stream is closed even if the serialization fails.
+   *
+   * @param conf configuration used to resolve the path's {@link FileSystem}
+   * @param path path of the file to create
+   * @param writable object to serialize
+   * @throws IOException if the file cannot be created or written
+   */
+  public static void storeWritable(Configuration conf, Path path, Writable writable) throws IOException {
+    FileSystem fs = path.getFileSystem(conf);
+
+    FSDataOutputStream out = fs.create(path);
+    try {
+      writable.write(out);
+    } finally {
+      out.close();
+    }
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java Sat Mar 13 10:54:14 2010
@@ -20,17 +20,21 @@ package org.apache.mahout.df;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Random;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.io.DataInput;
 
 import org.apache.mahout.df.callback.PredictionCallback;
 import org.apache.mahout.df.data.Data;
 import org.apache.mahout.df.data.DataUtils;
 import org.apache.mahout.df.data.Instance;
 import org.apache.mahout.df.node.Node;
+import org.apache.hadoop.io.Writable;
 
 /**
  * Represents a forest of decision trees.
  */
-public class DecisionForest {
+public class DecisionForest implements Writable {
   
   private final List<Node> trees;
   
@@ -163,5 +167,31 @@ public class DecisionForest {
   public int hashCode() {
     return trees.hashCode();
   }
-  
+
+  public void write(DataOutput dataOutput) throws IOException { // tree count, then each tree in order
+    dataOutput.writeInt(trees.size());
+    for (Node tree:trees) {
+      tree.write(dataOutput);
+    }
+  }
+
+  /**
+   * Reads the trees from the input and appends them to this forest's existing trees (they are not cleared)
+   * @param dataInput a tree count followed by that many serialized trees, as produced by {@link #write}
+   * @throws IOException if reading from the input fails
+   */
+  public void readFields(DataInput dataInput) throws IOException {
+    int size = dataInput.readInt();
+    for (int i = 0; i < size; i++) {
+      trees.add(Node.read(dataInput));
+    }
+  }
+
+  public static DecisionForest read(DataInput dataInput) throws IOException { // reads a new forest
+    DecisionForest forest = new DecisionForest();
+    forest.readFields(dataInput);
+    return forest;
+  }
+
+
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/split/Split.java Sat Mar 13 10:54:14 2010
@@ -40,4 +40,9 @@ public class Split {
   public Split(int attr, double ig) {
     this(attr, ig, Double.NaN);
   }
+
+  @Override
+  public String toString() { // NOTE(review): String.format uses the default locale for %f
+    return String.format("attr: %d, ig: %f, split: %f", attr, ig, split); // split is NaN for Split(attr, ig)
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/df/tools/Describe.java Sat Mar 13 10:54:14 2010
@@ -30,15 +30,14 @@ import org.apache.commons.cli2.builder.D
 import org.apache.commons.cli2.builder.GroupBuilder;
 import org.apache.commons.cli2.commandline.Parser;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FSDataOutputStream;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.Writable;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.df.data.DataLoader;
 import org.apache.mahout.df.data.Dataset;
 import org.apache.mahout.df.data.DescriptorException;
 import org.apache.mahout.df.data.DescriptorUtils;
+import org.apache.mahout.df.DFUtils;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -94,7 +93,7 @@ public class Describe {
       
       runTool(dataPath, descriptor, descPath);
     } catch (OptionException e) {
-      log.warn(e.toString(), e);
+      log.warn(e.toString()); // message only, no stack trace; the usage help is printed right after
       CommandLineUtil.printHelp(group);
     }
   }
@@ -110,7 +109,7 @@ public class Describe {
     Dataset dataset = generateDataset(descriptor, dataPath);
     
     log.info("storing the dataset description");
-    storeWritable(new Configuration(), fPath, dataset);
+    DFUtils.storeWritable(new Configuration(), fPath, dataset);
   }
   
   private static Dataset generateDataset(String descriptor, String dataPath) throws IOException,
@@ -138,12 +137,5 @@ public class Describe {
     }
     return list;
   }
-  
-  private static void storeWritable(Configuration conf, Path path, Writable dataset) throws IOException {
-    FileSystem fs = path.getFileSystem(conf);
-    
-    FSDataOutputStream out = fs.create(path);
-    dataset.write(out);
-    out.close();
-  }
+
 }

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java?rev=922523&r1=922522&r2=922523&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java Sat Mar 13 10:54:14 2010
@@ -51,16 +51,17 @@ import org.slf4j.LoggerFactory;
 
 /**
- * Tool to builds a Random Forest using any given dataset (in UCI format). Can use either the in-mem mapred or
- * partial mapred implementations
+ * Tool to build a Random Forest using any given dataset (in UCI format). Can use either the in-mem mapred or
+ * partial mapred implementations. Stores the forest in the given output directory.
  */
 public class BuildForest extends Configured implements Tool {
   
   private static final Logger log = LoggerFactory.getLogger(BuildForest.class);
   
-  private Path dataPath; // Data path
-  
-  private Path datasetPath; // Dataset path
+  private Path dataPath;
   
+  private Path datasetPath;
+
+  private Path outputPath; // output directory, will contain the Decision Forest
   private int m; // Number of variables to select at each tree-node
   
   private int nbTrees; // Number of trees to grow
@@ -70,7 +71,7 @@ public class BuildForest extends Configu
   private boolean isPartial; // use partial data implementation
   
   private boolean isOob; // estimate oob error;
-  
+
   @Override
   public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
     
@@ -103,12 +104,16 @@ public class BuildForest extends Configu
       abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription(
       "Number of trees to grow").create();
     
+    Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
+        withDescription("Output path, will contain the Decision Forest").create();
+
     Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
         .create();
     
     Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(dataOpt).withOption(datasetOpt)
         .withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
-        .withOption(helpOpt).create();
+        .withOption(outputOpt).withOption(helpOpt).create();
     
     try {
       Parser parser = new Parser();
@@ -124,6 +129,7 @@ public class BuildForest extends Configu
       isOob = cmdLine.hasOption(oobOpt);
       String dataName = cmdLine.getValue(dataOpt).toString();
       String datasetName = cmdLine.getValue(datasetOpt).toString();
+      String outputName = cmdLine.getValue(outputOpt).toString();
       m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
       nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
       
@@ -133,6 +139,7 @@ public class BuildForest extends Configu
       
       log.debug("data : {}", dataName);
       log.debug("dataset : {}", datasetName);
+      log.debug("output : {}", outputName);
       log.debug("m : {}", m);
       log.debug("seed : {}", seed);
       log.debug("nbtrees : {}", nbTrees);
@@ -141,6 +148,7 @@ public class BuildForest extends Configu
       
       dataPath = new Path(dataName);
       datasetPath = new Path(datasetName);
+      outputPath = new Path(outputName);
       
     } catch (OptionException e) {
       log.error("Exception", e);
@@ -153,7 +161,14 @@ public class BuildForest extends Configu
     return 0;
   }
   
-  private DecisionForest buildForest() throws IOException, ClassNotFoundException, InterruptedException {
+  private void buildForest() throws IOException, ClassNotFoundException, InterruptedException {
+    // make sure the output path does not exist; fail loudly rather than return silently,
+    // so that the caller cannot mistake a skipped build for a successful one
+    FileSystem ofs = outputPath.getFileSystem(getConf());
+    if (ofs.exists(outputPath)) {
+      throw new IOException("Output path already exists: " + outputPath);
+    }
+
     DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
     treeBuilder.setM(m);
     
@@ -171,6 +186,9 @@ public class BuildForest extends Configu
       log.info("InMem Mapred implementation");
       forestBuilder = new InMemBuilder(treeBuilder, dataPath, datasetPath, seed, getConf());
     }
+
+    forestBuilder.setOutputDirName(outputPath.getName());
+
     log.info("Building the forest...");
     long time = System.currentTimeMillis();
     
@@ -193,8 +211,11 @@ public class BuildForest extends Configu
       log.info("oob error estimate : "
                            + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
     }
-    
-    return forest;
+
+    // store the decision forest in the output path
+    Path forestPath = new Path(outputPath, "forest.seq");
+    log.info("Storing the forest in: {}", forestPath);
+    DFUtils.storeWritable(getConf(), forestPath, forest);
   }
   
   protected static Data loadData(Configuration conf, Path dataPath, Dataset dataset) throws IOException {

Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=922523&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Mar 13 10:54:14 2010
@@ -0,0 +1,214 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.df.mapreduce;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Random;
+import java.util.Scanner;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.df.DFUtils;
+import org.apache.mahout.df.DecisionForest;
+import org.apache.mahout.df.data.DataConverter;
+import org.apache.mahout.df.data.Dataset;
+import org.apache.mahout.df.data.Instance;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tool to classify a Dataset using a previously built Decision Forest.
+ * <p>
+ * The trees of every file returned by {@link DFUtils#listOutputFiles} for the model path are merged
+ * into a single {@link DecisionForest}. Each non-empty line of the test data is then converted to an
+ * {@link Instance} and classified sequentially, and a summary of the results is logged.
+ */
+public class TestForest extends Configured implements Tool {
+
+  private static final Logger log = LoggerFactory.getLogger(TestForest.class);
+
+  private Path dataPath; // test data path
+
+  private Path datasetPath; // dataset description path
+
+  private Path modelPath; // path where the forest is stored
+
+  @Override
+  public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+
+    Option inputOpt = obuilder.withLongName("input").withShortName("i").withRequired(true).withArgument(
+      abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Test data path").create();
+
+    Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+      abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset path")
+        .create();
+
+    Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true).withArgument(
+        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
+        withDescription("Path to the Decision Forest").create();
+
+    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+        .create();
+
+    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt)
+        .withOption(modelOpt).withOption(helpOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return -1;
+      }
+
+      String dataName = cmdLine.getValue(inputOpt).toString();
+      String datasetName = cmdLine.getValue(datasetOpt).toString();
+      String modelName = cmdLine.getValue(modelOpt).toString();
+
+      log.debug("input : {}", dataName);
+      log.debug("dataset : {}", datasetName);
+      log.debug("model : {}", modelName);
+
+      dataPath = new Path(dataName);
+      datasetPath = new Path(datasetName);
+      modelPath = new Path(modelName);
+
+    } catch (OptionException e) {
+      log.error("Exception", e);
+      CommandLineUtil.printHelp(group);
+      return -1;
+    }
+
+    testForest();
+
+    return 0;
+  }
+
+  /**
+   * Loads the dataset description and the forest, classifies the test data sequentially and logs
+   * the classification summary. Logs an error and returns if no forest could be loaded.
+   */
+  private void testForest() throws IOException {
+    Dataset dataset = Dataset.load(getConf(), datasetPath);
+    DataConverter converter = new DataConverter(dataset);
+
+    log.info("Loading the forest...");
+    DecisionForest forest = loadForest();
+
+    if (forest == null) {
+      log.error("No Decision Forest found!");
+      return;
+    }
+
+    log.info("Sequential classification...");
+    long time = System.currentTimeMillis();
+
+    Random rng = RandomUtils.getRandom();
+    ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown");
+
+    FileSystem tfs = dataPath.getFileSystem(getConf());
+    FSDataInputStream input = tfs.open(dataPath);
+    try {
+      Scanner scanner = new Scanner(input);
+      while (scanner.hasNextLine()) {
+        String line = scanner.nextLine();
+        if (line.isEmpty()) {
+          continue; // skip empty lines
+        }
+
+        Instance instance = converter.convert(0, line);
+        int prediction = forest.classify(rng, instance);
+
+        analyzer.addInstance(dataset.getLabel(instance.label),
+            new ClassifierResult(dataset.getLabel(prediction), 1.0));
+      }
+
+      // Scanner swallows read errors: rethrow them rather than report results for partial data
+      if (scanner.ioException() != null) {
+        throw scanner.ioException();
+      }
+    } finally {
+      input.close();
+    }
+
+    time = System.currentTimeMillis() - time;
+    log.info("Classification Time: {}", DFUtils.elapsedTime(time));
+
+    log.info(analyzer.summarize());
+  }
+
+  /**
+   * Deserializes every model file found under the model path and merges their trees into one forest.
+   *
+   * @return the merged forest, or {@code null} if no model file was found
+   * @throws IOException if a model file cannot be listed, opened or read
+   */
+  private DecisionForest loadForest() throws IOException {
+    FileSystem fs = modelPath.getFileSystem(getConf());
+    Path[] modelfiles = DFUtils.listOutputFiles(fs, modelPath);
+    DecisionForest forest = null;
+    for (Path path : modelfiles) {
+      FSDataInputStream dataInput = fs.open(path);
+      try {
+        if (forest == null) {
+          forest = DecisionForest.read(dataInput);
+        } else {
+          forest.readFields(dataInput); // appends the trees to the existing forest
+        }
+      } finally {
+        dataInput.close();
+      }
+    }
+    return forest;
+  }
+
+  /**
+   * Command-line entry point: runs this tool through {@link ToolRunner}.
+   *
+   * @param args command-line arguments, see {@link #run(String[])}
+   * @throws Exception if the tool fails
+   */
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new TestForest(), args);
+  }
+
+}