You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2009/10/20 00:26:27 UTC

svn commit: r826841 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/ core/src/main/java/org/apache/mahout/classifier/bayes/datastore/ core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/ core...

Author: robinanil
Date: Mon Oct 19 22:26:27 2009
New Revision: 826841

URL: http://svn.apache.org/viewvc?rev=826841&view=rev
Log:
MAHOUT-188 Cleanup of Bayes/CBayes Classifier

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java Mon Oct 19 22:26:27 2009
@@ -33,8 +33,6 @@
 
 public class BayesAlgorithm implements Algorithm{
 
-  private static final double alpha_i = 1.0;
-
   @Override
   public ClassifierResult classifyDocument(String[] document,
       Datastore datastore, String defaultCategory)
@@ -90,7 +88,7 @@
     double result = datastore.getWeight("weight", feature, label);
     double vocabCount = datastore.getWeight("sumWeight", "vocabCount");
     double sumLabelWeight = datastore.getWeight("labelWeight", label);    
-    double numerator = result + alpha_i;
+    double numerator = result + datastore.getWeight("params", "alpha_i");
     double denominator = (sumLabelWeight + vocabCount);
     double weight = Math.log(numerator / denominator);
     result = -weight;

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java Mon Oct 19 22:26:27 2009
@@ -33,8 +33,6 @@
 
 public class CBayesAlgorithm implements Algorithm {
 
-  private static final double alpha_i = 1.0;
-
   @Override
   public ClassifierResult classifyDocument(String[] document,
       Datastore datastore, String defaultCategory)
@@ -96,7 +94,7 @@
 
     double thetaNormalizer = datastore.getWeight("thetaNormalizer", label);
 
-    double numerator = sigma_j - result + alpha_i;
+    double numerator = sigma_j - result + datastore.getWeight("params", "alpha_i");
     double denominator = (sigma_jSigma_k - sigma_k + vocabCount);
     
     double weight = Math.log(numerator / denominator);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java Mon Oct 19 22:26:27 2009
@@ -58,10 +58,12 @@
     this.hbaseTable = hbaseTable;
     this.parameters = params;
     this.tableCache = new HybridCache<String, Result>(50000, 100000);
+    alpha_i = Double.valueOf(parameters.get("alpha_i", "1.0"));
   }
 
   protected double thetaNormalizer = 1.0d;
 
+  protected double alpha_i = 1.0d;
   @Override
   public void initialize() throws InvalidDatastoreException {
     config = new HBaseConfiguration(new Configuration());
@@ -142,6 +144,9 @@
     } else if (vectorName.equals("thetaNormalizer")) {
       return getWeightFromHbase(BayesConstants.LABEL_THETA_NORMALIZER, index)
           / thetaNormalizer;
+    } else if (vectorName.equals("params")) {
+      if(index.equals("alpha_i")) return alpha_i;
+      else throw new InvalidDatastoreException();
     } else {
 
       throw new InvalidDatastoreException();

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java Mon Oct 19 22:26:27 2009
@@ -33,12 +33,17 @@
 public class InMemoryBayesDatastore implements Datastore {
 
   final Map<String, Map<String, Map<String, Double>>> matrices = new HashMap<String, Map<String, Map<String, Double>>>();
+
   final Map<String, Map<String, Double>> vectors = new HashMap<String, Map<String, Double>>();
+
   Parameters params = null;
+
   protected double thetaNormalizer = 1.0d;
 
+  protected double alpha_i = 1.0d;
+
   public InMemoryBayesDatastore(Parameters params) {
-    
+
     matrices.put("weight", new HashMap<String, Map<String, Double>>());
     vectors.put("sumWeight", new HashMap<String, Double>());
     matrices.put("weight", new HashMap<String, Map<String, Double>>());
@@ -52,7 +57,7 @@
         + "/trainer-weights/Sigma_kSigma_j/part-*");
     params.set("thetaNormalizer", basePath + "/trainer-thetaNormalizer/part-*");
     params.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*");
-
+    alpha_i = Double.valueOf(params.get("alpha_i", "1.0"));
   }
 
   @Override
@@ -64,7 +69,7 @@
           .toUri(), conf), params, conf);
     } catch (IOException e) {
       throw new InvalidDatastoreException(e.getMessage());
-    }    
+    }
     updateVocabCount();
     Collection<String> labels = getKeys("thetaNormalizer");
     for (String label : labels) {
@@ -72,9 +77,9 @@
           "thetaNormalizer", label)));
     }
     for (String label : labels) {
-      System.out.println( label + ' ' +vectorGetCell(
-          "thetaNormalizer", label) + ' ' +thetaNormalizer + ' ' + vectorGetCell(
-          "thetaNormalizer", label)/thetaNormalizer);
+      System.out.println(label + ' ' + vectorGetCell("thetaNormalizer", label)
+          + ' ' + thetaNormalizer + ' '
+          + vectorGetCell("thetaNormalizer", label) / thetaNormalizer);
     }
   }
 
@@ -93,8 +98,12 @@
   @Override
   public double getWeight(String vectorName, String index)
       throws InvalidDatastoreException {
-    if(vectorName.equals("thetaNormalizer"))
-      return  vectorGetCell(vectorName, index)/thetaNormalizer;
+    if (vectorName.equals("thetaNormalizer"))
+      return vectorGetCell(vectorName, index) / thetaNormalizer;
+    else if (vectorName.equals("params")) {
+      if(index.equals("alpha_i")) return alpha_i;
+      else throw new InvalidDatastoreException();
+    } 
     return vectorGetCell(vectorName, index);
   }
 
@@ -173,7 +182,7 @@
   public void setSigma_jSigma_k(double weight) {
     vectorPutCell("sumWeight", "sigma_jSigma_k", weight);
   }
-  
+
   public void updateVocabCount() {
     vectorPutCell("sumWeight", "vocabCount", sizeOfMatrix("weight"));
   }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java Mon Oct 19 22:26:27 2009
@@ -26,6 +26,7 @@
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.hadoop.util.GenericsUtil;
 import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants;
+import org.apache.mahout.common.Parameters;
 import org.apache.mahout.common.StringTuple;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -42,7 +43,7 @@
   private Map<String, Double> labelWeightSum = null;
   private double sigma_jSigma_k = 0.0;
   private double vocabCount = 0.0;
-  //private final double alpha_i = 1.0;
+  private double alpha_i = 1.0;
 
   /**
    * We need to calculate the thetaNormalization factor of each label
@@ -58,7 +59,7 @@
     String label = key.stringAt(1);
 
     reporter.setStatus("Bayes Theta Normalizer Mapper: " + label);
-    double alpha_i = 1.0;
+    
     double weight = Math.log((value.get() + alpha_i) / (labelWeightSum.get(label) + vocabCount));
     StringTuple thetaNormalizerTuple = new StringTuple(BayesConstants.LABEL_THETA_NORMALIZER);
     thetaNormalizerTuple.add(label);
@@ -89,6 +90,9 @@
         String vocabCountString = stringifier.toString(vocabCount);
         vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
         vocabCount = stringifier.fromString(vocabCountString);
+        
+        Parameters params = Parameters.fromString(job.get("bayes.parameters", ""));
+        alpha_i = Double.valueOf(params.get("alpha_i", "1.0"));
 
       }
     } catch (IOException ex) {

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java Mon Oct 19 22:26:27 2009
@@ -26,6 +26,7 @@
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.hadoop.util.GenericsUtil;
 import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants;
+import org.apache.mahout.common.Parameters;
 import org.apache.mahout.common.StringTuple;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -42,7 +43,7 @@
   private Map<String, Double> labelWeightSum = null;
   private double sigma_jSigma_k = 0.0;
   private double vocabCount = 0.0;
-
+  private double alpha_i = 1.0;
   /**
    * We need to calculate the idf of each feature in each label
    *
@@ -55,7 +56,7 @@
 
     if (key.stringAt(0).equals(BayesConstants.FEATURE_SUM)) { // if it is from the Sigma_j folder
 
-      double alpha_i = 1.0;
+      
       for (Map.Entry<String, Double> stringDoubleEntry : labelWeightSum.entrySet()) {
         String label = stringDoubleEntry.getKey();
         double weight = Math.log((value.get() + alpha_i) / (sigma_jSigma_k - stringDoubleEntry.getValue() + vocabCount));
@@ -110,6 +111,9 @@
         String vocabCountString = stringifier.toString(vocabCount);
         vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
         vocabCount = stringifier.fromString(vocabCountString);
+        
+        Parameters params = Parameters.fromString(job.get("bayes.parameters", ""));
+        alpha_i = Double.valueOf(params.get("alpha_i", "1.0"));
 
       }
     } catch (IOException ex) {

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Mon Oct 19 22:26:27 2009
@@ -18,6 +18,7 @@
 package org.apache.mahout.classifier.bayes;
 
 import org.apache.mahout.classifier.ClassifierResult;
+
 import org.apache.mahout.classifier.ResultAnalyzer;
 import org.apache.mahout.classifier.bayes.algorithm.BayesAlgorithm;
 import org.apache.mahout.classifier.bayes.algorithm.CBayesAlgorithm;
@@ -29,7 +30,9 @@
 import org.apache.mahout.classifier.bayes.interfaces.Datastore;
 import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
 import org.apache.mahout.classifier.bayes.model.ClassifierContext;
+import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.TimingStatistics;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.common.nlp.NGrams;
 import org.apache.mahout.common.FileLineIterable;
 import org.apache.commons.cli2.Option;
@@ -50,6 +53,12 @@
 import java.util.Map;
 import java.nio.charset.Charset;
 
+/**
+ * Test the Naive Bayes classifier with improved weighting
+ * <p/>
+ * To run the twenty newsgroups example, refer to
+ * http://cwiki.apache.org/MAHOUT/twentynewsgroups.html
+ */
 public class TestClassifier {
 
   private static final Logger log = LoggerFactory
@@ -83,101 +92,124 @@
         .withDescription("The directory where test documents resides in")
         .withShortName("d").create();
 
+    Option helpOpt = DefaultOptionCreator.helpOption(obuilder);
+
     Option encodingOpt = obuilder.withLongName("encoding").withArgument(
         abuilder.withName("encoding").withMinimum(1).withMaximum(1).create())
         .withDescription("The file encoding.  Defaults to UTF-8")
         .withShortName("e").create();
 
-    Option analyzerOpt = obuilder.withLongName("analyzer").withArgument(
-        abuilder.withName("analyzer").withDefault(
-            "org.apache.lucene.analysis.standard.StandardAnalyzer")
-            .withMinimum(1).withMaximum(1).create()).withDescription(
-        "The Analyzer to use").withShortName("a").create();
-
     Option defaultCatOpt = obuilder.withLongName("defaultCat").withArgument(
         abuilder.withName("defaultCat").withMinimum(1).withMaximum(1).create())
-        .withDescription("The default category").withShortName("default")
-        .create();
+        .withDescription("The default category Default Value: unknown")
+        .withShortName("default").create();
 
     Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true)
         .withArgument(
             abuilder.withName("gramSize").withMinimum(1).withMaximum(1)
-                .create()).withDescription("Size of the n-gram").withShortName(
-            "ng").create();
+                .create()).withDescription(
+            "Size of the n-gram. Default Value: 1").withShortName("ng")
+        .create();
+
+    Option alphaOpt = obuilder.withLongName("alpha").withRequired(false)
+        .withArgument(
+            abuilder.withName("a").withMinimum(1).withMaximum(1).create())
+        .withDescription("Smoothing parameter Default Value: 1.0")
+        .withShortName("a").create();
+
     Option verboseOutputOpt = obuilder.withLongName("verbose").withRequired(
         false).withDescription(
         "Output which values were correctly and incorrectly classified")
         .withShortName("v").create();
+
     Option typeOpt = obuilder.withLongName("classifierType").withRequired(true)
         .withArgument(
             abuilder.withName("classifierType").withMinimum(1).withMaximum(1)
-                .create()).withDescription("Type of classifier: bayes|cbayes")
+                .create()).withDescription(
+            "Type of classifier: bayes|cbayes. Default Value: bayes")
         .withShortName("type").create();
 
     Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(
         true).withArgument(
         abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create())
-        .withDescription("Location of model: hdfs|hbase").withShortName(
-            "source").create();
+        .withDescription("Location of model: hdfs|hbase Default Value: hdfs")
+        .withShortName("source").create();
 
-    Option methodOpt = obuilder.withLongName("method").withRequired(true)
+    Option methodOpt = obuilder
+        .withLongName("method")
+        .withRequired(true)
         .withArgument(
             abuilder.withName("method").withMinimum(1).withMaximum(1).create())
-        .withDescription("Method of Classification: sequential|mapreduce")
+        .withDescription(
+            "Method of Classification: sequential|mapreduce. Default Value: sequential")
         .withShortName("method").create();
 
-    Group group = gbuilder.withName("Options").withOption(analyzerOpt)
-        .withOption(defaultCatOpt).withOption(dirOpt).withOption(encodingOpt)
-        .withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt)
-        .withOption(dataSourceOpt).withOption(methodOpt).withOption(
-            verboseOutputOpt).create();
-
-    Parser parser = new Parser();
-    parser.setGroup(group);
-    CommandLine cmdLine = parser.parse(args);
-
-    int gramSize = 1;
-    if (cmdLine.hasOption(gramSizeOpt)) {
-      gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
+    Group group = gbuilder.withName("Options").withOption(defaultCatOpt)
+        .withOption(dirOpt).withOption(encodingOpt).withOption(gramSizeOpt)
+        .withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt)
+        .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt)
+        .withOption(alphaOpt).create();
+
+    try {
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        System.exit(0);
+      }
 
-    }
-    BayesParameters params = new BayesParameters(gramSize);
+      int gramSize = 1;
+      if (cmdLine.hasOption(gramSizeOpt)) {
+        gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
 
-    String modelBasePath = (String) cmdLine.getValue(pathOpt);
+      }
+      BayesParameters params = new BayesParameters(gramSize);
 
-    String classifierType = (String) cmdLine.getValue(typeOpt);
-    String dataSource = (String) cmdLine.getValue(dataSourceOpt);
+      String modelBasePath = (String) cmdLine.getValue(pathOpt);
 
-    String defaultCat = "unknown";
-    if (cmdLine.hasOption(defaultCatOpt)) {
-      defaultCat = (String) cmdLine.getValue(defaultCatOpt);
-    }
+      String classifierType = (String) cmdLine.getValue(typeOpt);
+      String dataSource = (String) cmdLine.getValue(dataSourceOpt);
 
-    String encoding = "UTF-8";
-    if (cmdLine.hasOption(encodingOpt)) {
-      encoding = (String) cmdLine.getValue(encodingOpt);
-    }
+      String defaultCat = "unknown";
+      if (cmdLine.hasOption(defaultCatOpt)) {
+        defaultCat = (String) cmdLine.getValue(defaultCatOpt);
+      }
 
-    boolean verbose = cmdLine.hasOption(verboseOutputOpt);
+      String encoding = "UTF-8";
+      if (cmdLine.hasOption(encodingOpt)) {
+        encoding = (String) cmdLine.getValue(encodingOpt);
+      }
 
-    String className = (String) cmdLine.getValue(analyzerOpt);
+      String alpha_i = "1.0";
+      if (cmdLine.hasOption(alphaOpt)) {
+        alpha_i = (String) cmdLine.getValue(alphaOpt);
+      }
 
-    String testDirPath = (String) cmdLine.getValue(dirOpt);
+      boolean verbose = cmdLine.hasOption(verboseOutputOpt);
 
-    String classificationMethod = (String) cmdLine.getValue(methodOpt);
+      String testDirPath = (String) cmdLine.getValue(dirOpt);
 
-    params.set("verbose", Boolean.toString(verbose));
-    params.set("basePath", modelBasePath);
-    params.set("classifierType", classifierType);
-    params.set("dataSource", dataSource);
-    params.set("defaultCat", defaultCat);
-    params.set("analyzer", className);
-    params.set("encoding", encoding);
-    params.set("testDirPath", testDirPath);
-    if (classificationMethod.equalsIgnoreCase("sequential"))
-      classifySequential(params);
-    else if (classificationMethod.equalsIgnoreCase("mapreduce"))
-      classifyParallel(params);
+      String classificationMethod = (String) cmdLine.getValue(methodOpt);
+
+      params.set("verbose", Boolean.toString(verbose));
+      params.set("basePath", modelBasePath);
+      params.set("classifierType", classifierType);
+      params.set("dataSource", dataSource);
+      params.set("defaultCat", defaultCat);
+      params.set("encoding", encoding);
+      params.set("alpha_i", alpha_i);
+      params.set("testDirPath", testDirPath);
+
+      if (classificationMethod.equalsIgnoreCase("sequential"))
+        classifySequential(params);
+      else if (classificationMethod.equalsIgnoreCase("mapreduce"))
+        classifyParallel(params);
+    } catch (OptionException e) {
+      CommandLineUtil.printHelp(group);
+      System.exit(0);
+    }
   }
 
   public static void classifySequential(BayesParameters params)

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java Mon Oct 19 22:26:27 2009
@@ -28,95 +28,137 @@
 import org.apache.mahout.classifier.bayes.common.BayesParameters;
 import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesDriver;
 import org.apache.mahout.classifier.bayes.mapreduce.cbayes.CBayesDriver;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 
 /**
- * Train the Naive Bayes Complement classifier with improved weighting on the Twenty Newsgroups data (http://people.csail.mit.edu/jrennie/20Newsgroups/20news-18828.tar.gz)
+ * Train the Naive Bayes classifier with improved weighting
  * <p/>
- * To run:
- * Assume MAHOUT_HOME refers to the location where you checked out/installed Mahout
- * <ol>
- * <li>From the main dir: ant extract-20news-18828</li>
- * <li>ant job</li>
- * <li>Start up Hadoop and copy the files to the system. See http://hadoop.apache.org/core/docs/r0.16.2/quickstart.html</li>
- * <li>From the Hadoop dir (where Hadoop is installed):
- * <ol>
- * <li>emacs conf/hadoop-site.xml (add in local settings per quickstart)</li>
- * <li>bin/hadoop namenode -format  //Format the HDFS</li>
- * <li>bin/start-all.sh  //Start Hadoop</li>
- * <li>bin/hadoop dfs -put &lt;MAHOUT_HOME&gt;/work/20news-18828-collapse 20newsInput  //Copies the extracted text to HDFS</li>
- * <li>bin/hadoop jar &lt;MAHOUT_HOME&gt;/build/apache-mahout-0.1-dev-ex.jar org.apache.mahout.classifier.bayes.TraingClassifier -t -i 20newsInput -o 20newsOutput</li>
- * </ol>
- * </li>
- * </ol>
+ * To run the twenty newsgroups example, refer to
+ * http://cwiki.apache.org/MAHOUT/twentynewsgroups.html
  */
 public class TrainClassifier {
 
-  private static final Logger log = LoggerFactory.getLogger(TrainClassifier.class);
+  private static final Logger log = LoggerFactory
+      .getLogger(TrainClassifier.class);
 
   private TrainClassifier() {
   }
 
-  public static void trainNaiveBayes(String dir, String outputDir, BayesParameters params) throws IOException, InterruptedException, ClassNotFoundException {
+  public static void trainNaiveBayes(String dir, String outputDir,
+      BayesParameters params) throws IOException, InterruptedException,
+      ClassNotFoundException {
     BayesDriver driver = new BayesDriver();
     driver.runJob(dir, outputDir, params);
   }
-  
-  public static void trainCNaiveBayes(String dir, String outputDir, BayesParameters params) throws IOException, InterruptedException, ClassNotFoundException {
+
+  public static void trainCNaiveBayes(String dir, String outputDir,
+      BayesParameters params) throws IOException, InterruptedException,
+      ClassNotFoundException {
     CBayesDriver driver = new CBayesDriver();
     driver.runJob(dir, outputDir, params);
   }
 
-  public static void main(String[] args) throws IOException, OptionException, NumberFormatException, IllegalStateException, InterruptedException, ClassNotFoundException {
+  public static void main(String[] args) throws IOException, OptionException,
+      NumberFormatException, IllegalStateException, InterruptedException,
+      ClassNotFoundException {
     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
     ArgumentBuilder abuilder = new ArgumentBuilder();
     GroupBuilder gbuilder = new GroupBuilder();
 
-    Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument(
-            abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
-            withDescription("The Directory on HDFS containing the collapsed, properly formatted files").withShortName("i").create();
-
-    Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
-            abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
-            withDescription("The location of the modelon the HDFS").withShortName("o").create();
-
-    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true).withArgument(
-            abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).
-            withDescription("Size of the n-gram").withShortName("ng").create();
-
-    Option typeOpt = obuilder.withLongName("classifierType").withRequired(true).withArgument(
-            abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).
-            withDescription("Type of classifier: bayes or cbayes").withShortName("type").create();
+    Option helpOpt = DefaultOptionCreator.helpOption(obuilder);
+
+    Option inputDirOpt = obuilder
+        .withLongName("input")
+        .withRequired(true)
+        .withArgument(
+            abuilder.withName("input").withMinimum(1).withMaximum(1).create())
+        .withDescription(
+            "The Directory on HDFS containing the collapsed, properly formatted files")
+        .withShortName("i").create();
+
+    Option outputOpt = obuilder.withLongName("output").withRequired(true)
+        .withArgument(
+            abuilder.withName("output").withMinimum(1).withMaximum(1).create())
+        .withDescription("The location of the modelon the HDFS").withShortName(
+            "o").create();
+
+    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true)
+        .withArgument(
+            abuilder.withName("gramSize").withMinimum(1).withMaximum(1)
+                .create()).withDescription(
+            "Size of the n-gram. Default Value: 1 ").withShortName("ng")
+        .create();
+
+    Option alphaOpt = obuilder.withLongName("alpha").withRequired(false)
+        .withArgument(
+            abuilder.withName("a").withMinimum(1).withMaximum(1).create())
+        .withDescription("Smoothing parameter Default Value: 1.0")
+        .withShortName("a").create();
+
+    Option typeOpt = obuilder.withLongName("classifierType").withRequired(true)
+        .withArgument(
+            abuilder.withName("classifierType").withMinimum(1).withMaximum(1)
+                .create()).withDescription(
+            "Type of classifier: bayes|cbayes. Default: bayes").withShortName(
+            "type").create();
     Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(
         true).withArgument(
         abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create())
-        .withDescription("Location of model: hdfs|hbase").withShortName(
-            "source").create();
+        .withDescription("Location of model: hdfs|hbase. Default Value: hdfs")
+        .withShortName("source").create();
 
-    Group group = gbuilder.withName("Options").withOption(gramSizeOpt).withOption(inputDirOpt).withOption(outputOpt).withOption(typeOpt).withOption(dataSourceOpt).create();
-    Parser parser = new Parser();
-    parser.setGroup(group);
-    CommandLine cmdLine = parser.parse(args);
-    String classifierType = (String) cmdLine.getValue(typeOpt);
-    String dataSourceType = (String) cmdLine.getValue(dataSourceOpt);
-    BayesParameters params = new BayesParameters(Integer.parseInt((String) cmdLine.getValue(gramSizeOpt)));
-    
-    if(dataSourceType.equals("hbase"))
-      params.set("dataSource", "hbase");
-    else
-      params.set("dataSource", "hdfs");
-    
-    if (classifierType.equalsIgnoreCase("bayes")) {
-      log.info("Training Bayes Classifier");
-      trainNaiveBayes((String)cmdLine.getValue(inputDirOpt), (String)cmdLine.getValue(outputOpt), params);
-
-    } else if (classifierType.equalsIgnoreCase("cbayes")) {
-      log.info("Training Complementary Bayes Classifier");
-      //setup the HDFS and copy the files there, then run the trainer
-      trainCNaiveBayes((String) cmdLine.getValue(inputDirOpt), (String) cmdLine.getValue(outputOpt), params);
+    Group group = gbuilder.withName("Options").withOption(gramSizeOpt)
+        .withOption(helpOpt).withOption(inputDirOpt).withOption(outputOpt)
+        .withOption(typeOpt).withOption(dataSourceOpt).withOption(alphaOpt)
+        .create();
+    try {
+      Parser parser = new Parser();
+
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        System.exit(0);
+      }
+
+      String classifierType = (String) cmdLine.getValue(typeOpt);
+      String dataSourceType = (String) cmdLine.getValue(dataSourceOpt);
+
+      BayesParameters params = new BayesParameters(Integer
+          .parseInt((String) cmdLine.getValue(gramSizeOpt)));
+
+      String alpha_i = "1.0";
+      if (cmdLine.hasOption(alphaOpt)) {
+        alpha_i = (String) cmdLine.getValue(alphaOpt);
+      }
+
+      params.set("alpha_i", alpha_i);
+
+      if (dataSourceType.equals("hbase"))
+        params.set("dataSource", "hbase");
+      else
+        params.set("dataSource", "hdfs");
+
+      if (classifierType.equalsIgnoreCase("bayes")) {
+        log.info("Training Bayes Classifier");
+        trainNaiveBayes((String) cmdLine.getValue(inputDirOpt),
+            (String) cmdLine.getValue(outputOpt), params);
+
+      } else if (classifierType.equalsIgnoreCase("cbayes")) {
+        log.info("Training Complementary Bayes Classifier");
+        // setup the HDFS and copy the files there, then run the trainer
+        trainCNaiveBayes((String) cmdLine.getValue(inputDirOpt),
+            (String) cmdLine.getValue(outputOpt), params);
+      }
+    } catch (OptionException e) {
+      log.info("{}", e);
+      CommandLineUtil.printHelp(group);
+      System.exit(0);
     }
   }
 }