You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2009/10/20 00:26:27 UTC
svn commit: r826841 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/
core/src/main/java/org/apache/mahout/classifier/bayes/datastore/
core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/
core...
Author: robinanil
Date: Mon Oct 19 22:26:27 2009
New Revision: 826841
URL: http://svn.apache.org/viewvc?rev=826841&view=rev
Log:
MAHOUT-188 Cleanup of Bayes/CBayes Classifier
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/BayesAlgorithm.java Mon Oct 19 22:26:27 2009
@@ -33,8 +33,6 @@
public class BayesAlgorithm implements Algorithm{
- private static final double alpha_i = 1.0;
-
@Override
public ClassifierResult classifyDocument(String[] document,
Datastore datastore, String defaultCategory)
@@ -90,7 +88,7 @@
double result = datastore.getWeight("weight", feature, label);
double vocabCount = datastore.getWeight("sumWeight", "vocabCount");
double sumLabelWeight = datastore.getWeight("labelWeight", label);
- double numerator = result + alpha_i;
+ double numerator = result + datastore.getWeight("params", "alpha_i");
double denominator = (sumLabelWeight + vocabCount);
double weight = Math.log(numerator / denominator);
result = -weight;
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/algorithm/CBayesAlgorithm.java Mon Oct 19 22:26:27 2009
@@ -33,8 +33,6 @@
public class CBayesAlgorithm implements Algorithm {
- private static final double alpha_i = 1.0;
-
@Override
public ClassifierResult classifyDocument(String[] document,
Datastore datastore, String defaultCategory)
@@ -96,7 +94,7 @@
double thetaNormalizer = datastore.getWeight("thetaNormalizer", label);
- double numerator = sigma_j - result + alpha_i;
+ double numerator = sigma_j - result + datastore.getWeight("params", "alpha_i");
double denominator = (sigma_jSigma_k - sigma_k + vocabCount);
double weight = Math.log(numerator / denominator);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/HBaseBayesDatastore.java Mon Oct 19 22:26:27 2009
@@ -58,10 +58,12 @@
this.hbaseTable = hbaseTable;
this.parameters = params;
this.tableCache = new HybridCache<String, Result>(50000, 100000);
+ alpha_i = Double.valueOf(parameters.get("alpha_i", "1.0"));
}
protected double thetaNormalizer = 1.0d;
+ protected double alpha_i = 1.0d;
@Override
public void initialize() throws InvalidDatastoreException {
config = new HBaseConfiguration(new Configuration());
@@ -142,6 +144,9 @@
} else if (vectorName.equals("thetaNormalizer")) {
return getWeightFromHbase(BayesConstants.LABEL_THETA_NORMALIZER, index)
/ thetaNormalizer;
+ } else if (vectorName.equals("params")) {
+ if(index.equals("alpha_i")) return alpha_i;
+ else throw new InvalidDatastoreException();
} else {
throw new InvalidDatastoreException();
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/datastore/InMemoryBayesDatastore.java Mon Oct 19 22:26:27 2009
@@ -33,12 +33,17 @@
public class InMemoryBayesDatastore implements Datastore {
final Map<String, Map<String, Map<String, Double>>> matrices = new HashMap<String, Map<String, Map<String, Double>>>();
+
final Map<String, Map<String, Double>> vectors = new HashMap<String, Map<String, Double>>();
+
Parameters params = null;
+
protected double thetaNormalizer = 1.0d;
+ protected double alpha_i = 1.0d;
+
public InMemoryBayesDatastore(Parameters params) {
-
+
matrices.put("weight", new HashMap<String, Map<String, Double>>());
vectors.put("sumWeight", new HashMap<String, Double>());
matrices.put("weight", new HashMap<String, Map<String, Double>>());
@@ -52,7 +57,7 @@
+ "/trainer-weights/Sigma_kSigma_j/part-*");
params.set("thetaNormalizer", basePath + "/trainer-thetaNormalizer/part-*");
params.set("weight", basePath + "/trainer-tfIdf/trainer-tfIdf/part-*");
-
+ alpha_i = Double.valueOf(params.get("alpha_i", "1.0"));
}
@Override
@@ -64,7 +69,7 @@
.toUri(), conf), params, conf);
} catch (IOException e) {
throw new InvalidDatastoreException(e.getMessage());
- }
+ }
updateVocabCount();
Collection<String> labels = getKeys("thetaNormalizer");
for (String label : labels) {
@@ -72,9 +77,9 @@
"thetaNormalizer", label)));
}
for (String label : labels) {
- System.out.println( label + ' ' +vectorGetCell(
- "thetaNormalizer", label) + ' ' +thetaNormalizer + ' ' + vectorGetCell(
- "thetaNormalizer", label)/thetaNormalizer);
+ System.out.println(label + ' ' + vectorGetCell("thetaNormalizer", label)
+ + ' ' + thetaNormalizer + ' '
+ + vectorGetCell("thetaNormalizer", label) / thetaNormalizer);
}
}
@@ -93,8 +98,12 @@
@Override
public double getWeight(String vectorName, String index)
throws InvalidDatastoreException {
- if(vectorName.equals("thetaNormalizer"))
- return vectorGetCell(vectorName, index)/thetaNormalizer;
+ if (vectorName.equals("thetaNormalizer"))
+ return vectorGetCell(vectorName, index) / thetaNormalizer;
+ else if (vectorName.equals("params")) {
+ if(index.equals("alpha_i")) return alpha_i;
+ else throw new InvalidDatastoreException();
+ }
return vectorGetCell(vectorName, index);
}
@@ -173,7 +182,7 @@
public void setSigma_jSigma_k(double weight) {
vectorPutCell("sumWeight", "sigma_jSigma_k", weight);
}
-
+
public void updateVocabCount() {
vectorPutCell("sumWeight", "vocabCount", sizeOfMatrix("weight"));
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/BayesThetaNormalizerMapper.java Mon Oct 19 22:26:27 2009
@@ -26,6 +26,7 @@
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.util.GenericsUtil;
import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants;
+import org.apache.mahout.common.Parameters;
import org.apache.mahout.common.StringTuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -42,7 +43,7 @@
private Map<String, Double> labelWeightSum = null;
private double sigma_jSigma_k = 0.0;
private double vocabCount = 0.0;
- //private final double alpha_i = 1.0;
+ private double alpha_i = 1.0;
/**
* We need to calculate the thetaNormalization factor of each label
@@ -58,7 +59,7 @@
String label = key.stringAt(1);
reporter.setStatus("Bayes Theta Normalizer Mapper: " + label);
- double alpha_i = 1.0;
+
double weight = Math.log((value.get() + alpha_i) / (labelWeightSum.get(label) + vocabCount));
StringTuple thetaNormalizerTuple = new StringTuple(BayesConstants.LABEL_THETA_NORMALIZER);
thetaNormalizerTuple.add(label);
@@ -89,6 +90,9 @@
String vocabCountString = stringifier.toString(vocabCount);
vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
vocabCount = stringifier.fromString(vocabCountString);
+
+ Parameters params = Parameters.fromString(job.get("bayes.parameters", ""));
+ alpha_i = Double.valueOf(params.get("alpha_i", "1.0"));
}
} catch (IOException ex) {
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/cbayes/CBayesThetaNormalizerMapper.java Mon Oct 19 22:26:27 2009
@@ -26,6 +26,7 @@
import org.apache.hadoop.mapred.Reporter;
import org.apache.hadoop.util.GenericsUtil;
import org.apache.mahout.classifier.bayes.mapreduce.common.BayesConstants;
+import org.apache.mahout.common.Parameters;
import org.apache.mahout.common.StringTuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -42,7 +43,7 @@
private Map<String, Double> labelWeightSum = null;
private double sigma_jSigma_k = 0.0;
private double vocabCount = 0.0;
-
+ private double alpha_i = 1.0;
/**
* We need to calculate the idf of each feature in each label
*
@@ -55,7 +56,7 @@
if (key.stringAt(0).equals(BayesConstants.FEATURE_SUM)) { // if it is from the Sigma_j folder
- double alpha_i = 1.0;
+
for (Map.Entry<String, Double> stringDoubleEntry : labelWeightSum.entrySet()) {
String label = stringDoubleEntry.getKey();
double weight = Math.log((value.get() + alpha_i) / (sigma_jSigma_k - stringDoubleEntry.getValue() + vocabCount));
@@ -110,6 +111,9 @@
String vocabCountString = stringifier.toString(vocabCount);
vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
vocabCount = stringifier.fromString(vocabCountString);
+
+ Parameters params = Parameters.fromString(job.get("bayes.parameters", ""));
+ alpha_i = Double.valueOf(params.get("alpha_i", "1.0"));
}
} catch (IOException ex) {
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Mon Oct 19 22:26:27 2009
@@ -18,6 +18,7 @@
package org.apache.mahout.classifier.bayes;
import org.apache.mahout.classifier.ClassifierResult;
+
import org.apache.mahout.classifier.ResultAnalyzer;
import org.apache.mahout.classifier.bayes.algorithm.BayesAlgorithm;
import org.apache.mahout.classifier.bayes.algorithm.CBayesAlgorithm;
@@ -29,7 +30,9 @@
import org.apache.mahout.classifier.bayes.interfaces.Datastore;
import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesClassifierDriver;
import org.apache.mahout.classifier.bayes.model.ClassifierContext;
+import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.TimingStatistics;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.nlp.NGrams;
import org.apache.mahout.common.FileLineIterable;
import org.apache.commons.cli2.Option;
@@ -50,6 +53,12 @@
import java.util.Map;
import java.nio.charset.Charset;
+/**
+ * Test the Naive Bayes classifier with improved weighting
+ * <p/>
+ * To run the twenty newsgroups example: refer
+ * http://cwiki.apache.org/MAHOUT/twentynewsgroups.html
+ */
public class TestClassifier {
private static final Logger log = LoggerFactory
@@ -83,101 +92,124 @@
.withDescription("The directory where test documents resides in")
.withShortName("d").create();
+ Option helpOpt = DefaultOptionCreator.helpOption(obuilder);
+
Option encodingOpt = obuilder.withLongName("encoding").withArgument(
abuilder.withName("encoding").withMinimum(1).withMaximum(1).create())
.withDescription("The file encoding. Defaults to UTF-8")
.withShortName("e").create();
- Option analyzerOpt = obuilder.withLongName("analyzer").withArgument(
- abuilder.withName("analyzer").withDefault(
- "org.apache.lucene.analysis.standard.StandardAnalyzer")
- .withMinimum(1).withMaximum(1).create()).withDescription(
- "The Analyzer to use").withShortName("a").create();
-
Option defaultCatOpt = obuilder.withLongName("defaultCat").withArgument(
abuilder.withName("defaultCat").withMinimum(1).withMaximum(1).create())
- .withDescription("The default category").withShortName("default")
- .create();
+ .withDescription("The default category Default Value: unknown")
+ .withShortName("default").create();
Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true)
.withArgument(
abuilder.withName("gramSize").withMinimum(1).withMaximum(1)
- .create()).withDescription("Size of the n-gram").withShortName(
- "ng").create();
+ .create()).withDescription(
+ "Size of the n-gram. Default Value: 1").withShortName("ng")
+ .create();
+
+ Option alphaOpt = obuilder.withLongName("alpha").withRequired(false)
+ .withArgument(
+ abuilder.withName("a").withMinimum(1).withMaximum(1).create())
+ .withDescription("Smoothing parameter Default Value: 1.0")
+ .withShortName("a").create();
+
Option verboseOutputOpt = obuilder.withLongName("verbose").withRequired(
false).withDescription(
"Output which values were correctly and incorrectly classified")
.withShortName("v").create();
+
Option typeOpt = obuilder.withLongName("classifierType").withRequired(true)
.withArgument(
abuilder.withName("classifierType").withMinimum(1).withMaximum(1)
- .create()).withDescription("Type of classifier: bayes|cbayes")
+ .create()).withDescription(
+ "Type of classifier: bayes|cbayes. Default Value: bayes")
.withShortName("type").create();
Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(
true).withArgument(
abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create())
- .withDescription("Location of model: hdfs|hbase").withShortName(
- "source").create();
+ .withDescription("Location of model: hdfs|hbase Default Value: hdfs")
+ .withShortName("source").create();
- Option methodOpt = obuilder.withLongName("method").withRequired(true)
+ Option methodOpt = obuilder
+ .withLongName("method")
+ .withRequired(true)
.withArgument(
abuilder.withName("method").withMinimum(1).withMaximum(1).create())
- .withDescription("Method of Classification: sequential|mapreduce")
+ .withDescription(
+ "Method of Classification: sequential|mapreduce. Default Value: sequential")
.withShortName("method").create();
- Group group = gbuilder.withName("Options").withOption(analyzerOpt)
- .withOption(defaultCatOpt).withOption(dirOpt).withOption(encodingOpt)
- .withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt)
- .withOption(dataSourceOpt).withOption(methodOpt).withOption(
- verboseOutputOpt).create();
-
- Parser parser = new Parser();
- parser.setGroup(group);
- CommandLine cmdLine = parser.parse(args);
-
- int gramSize = 1;
- if (cmdLine.hasOption(gramSizeOpt)) {
- gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
+ Group group = gbuilder.withName("Options").withOption(defaultCatOpt)
+ .withOption(dirOpt).withOption(encodingOpt).withOption(gramSizeOpt)
+ .withOption(pathOpt).withOption(typeOpt).withOption(dataSourceOpt)
+ .withOption(helpOpt).withOption(methodOpt).withOption(verboseOutputOpt)
+ .withOption(alphaOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ System.exit(0);
+ }
- }
- BayesParameters params = new BayesParameters(gramSize);
+ int gramSize = 1;
+ if (cmdLine.hasOption(gramSizeOpt)) {
+ gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
- String modelBasePath = (String) cmdLine.getValue(pathOpt);
+ }
+ BayesParameters params = new BayesParameters(gramSize);
- String classifierType = (String) cmdLine.getValue(typeOpt);
- String dataSource = (String) cmdLine.getValue(dataSourceOpt);
+ String modelBasePath = (String) cmdLine.getValue(pathOpt);
- String defaultCat = "unknown";
- if (cmdLine.hasOption(defaultCatOpt)) {
- defaultCat = (String) cmdLine.getValue(defaultCatOpt);
- }
+ String classifierType = (String) cmdLine.getValue(typeOpt);
+ String dataSource = (String) cmdLine.getValue(dataSourceOpt);
- String encoding = "UTF-8";
- if (cmdLine.hasOption(encodingOpt)) {
- encoding = (String) cmdLine.getValue(encodingOpt);
- }
+ String defaultCat = "unknown";
+ if (cmdLine.hasOption(defaultCatOpt)) {
+ defaultCat = (String) cmdLine.getValue(defaultCatOpt);
+ }
- boolean verbose = cmdLine.hasOption(verboseOutputOpt);
+ String encoding = "UTF-8";
+ if (cmdLine.hasOption(encodingOpt)) {
+ encoding = (String) cmdLine.getValue(encodingOpt);
+ }
- String className = (String) cmdLine.getValue(analyzerOpt);
+ String alpha_i = "1.0";
+ if (cmdLine.hasOption(alphaOpt)) {
+ alpha_i = (String) cmdLine.getValue(alphaOpt);
+ }
- String testDirPath = (String) cmdLine.getValue(dirOpt);
+ boolean verbose = cmdLine.hasOption(verboseOutputOpt);
- String classificationMethod = (String) cmdLine.getValue(methodOpt);
+ String testDirPath = (String) cmdLine.getValue(dirOpt);
- params.set("verbose", Boolean.toString(verbose));
- params.set("basePath", modelBasePath);
- params.set("classifierType", classifierType);
- params.set("dataSource", dataSource);
- params.set("defaultCat", defaultCat);
- params.set("analyzer", className);
- params.set("encoding", encoding);
- params.set("testDirPath", testDirPath);
- if (classificationMethod.equalsIgnoreCase("sequential"))
- classifySequential(params);
- else if (classificationMethod.equalsIgnoreCase("mapreduce"))
- classifyParallel(params);
+ String classificationMethod = (String) cmdLine.getValue(methodOpt);
+
+ params.set("verbose", Boolean.toString(verbose));
+ params.set("basePath", modelBasePath);
+ params.set("classifierType", classifierType);
+ params.set("dataSource", dataSource);
+ params.set("defaultCat", defaultCat);
+ params.set("encoding", encoding);
+ params.set("alpha_i", alpha_i);
+ params.set("testDirPath", testDirPath);
+
+ if (classificationMethod.equalsIgnoreCase("sequential"))
+ classifySequential(params);
+ else if (classificationMethod.equalsIgnoreCase("mapreduce"))
+ classifyParallel(params);
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(group);
+ System.exit(0);
+ }
}
public static void classifySequential(BayesParameters params)
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java?rev=826841&r1=826840&r2=826841&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java Mon Oct 19 22:26:27 2009
@@ -28,95 +28,137 @@
import org.apache.mahout.classifier.bayes.common.BayesParameters;
import org.apache.mahout.classifier.bayes.mapreduce.bayes.BayesDriver;
import org.apache.mahout.classifier.bayes.mapreduce.cbayes.CBayesDriver;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
/**
- * Train the Naive Bayes Complement classifier with improved weighting on the Twenty Newsgroups data (http://people.csail.mit.edu/jrennie/20Newsgroups/20news-18828.tar.gz)
+ * Train the Naive Bayes classifier with improved weighting
* <p/>
- * To run:
- * Assume MAHOUT_HOME refers to the location where you checked out/installed Mahout
- * <ol>
- * <li>From the main dir: ant extract-20news-18828</li>
- * <li>ant job</li>
- * <li>Start up Hadoop and copy the files to the system. See http://hadoop.apache.org/core/docs/r0.16.2/quickstart.html</li>
- * <li>From the Hadoop dir (where Hadoop is installed):
- * <ol>
- * <li>emacs conf/hadoop-site.xml (add in local settings per quickstart)</li>
- * <li>bin/hadoop namenode -format //Format the HDFS</li>
- * <li>bin/start-all.sh //Start Hadoop</li>
- * <li>bin/hadoop dfs -put <MAHOUT_HOME>/work/20news-18828-collapse 20newsInput //Copies the extracted text to HDFS</li>
- * <li>bin/hadoop jar <MAHOUT_HOME>/build/apache-mahout-0.1-dev-ex.jar org.apache.mahout.classifier.bayes.TraingClassifier -t -i 20newsInput -o 20newsOutput</li>
- * </ol>
- * </li>
- * </ol>
+ * To run the twenty newsgroups example: refer
+ * http://cwiki.apache.org/MAHOUT/twentynewsgroups.html
*/
public class TrainClassifier {
- private static final Logger log = LoggerFactory.getLogger(TrainClassifier.class);
+ private static final Logger log = LoggerFactory
+ .getLogger(TrainClassifier.class);
private TrainClassifier() {
}
- public static void trainNaiveBayes(String dir, String outputDir, BayesParameters params) throws IOException, InterruptedException, ClassNotFoundException {
+ public static void trainNaiveBayes(String dir, String outputDir,
+ BayesParameters params) throws IOException, InterruptedException,
+ ClassNotFoundException {
BayesDriver driver = new BayesDriver();
driver.runJob(dir, outputDir, params);
}
-
- public static void trainCNaiveBayes(String dir, String outputDir, BayesParameters params) throws IOException, InterruptedException, ClassNotFoundException {
+
+ public static void trainCNaiveBayes(String dir, String outputDir,
+ BayesParameters params) throws IOException, InterruptedException,
+ ClassNotFoundException {
CBayesDriver driver = new CBayesDriver();
driver.runJob(dir, outputDir, params);
}
- public static void main(String[] args) throws IOException, OptionException, NumberFormatException, IllegalStateException, InterruptedException, ClassNotFoundException {
+ public static void main(String[] args) throws IOException, OptionException,
+ NumberFormatException, IllegalStateException, InterruptedException,
+ ClassNotFoundException {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
- Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument(
- abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
- withDescription("The Directory on HDFS containing the collapsed, properly formatted files").withShortName("i").create();
-
- Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
- abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
- withDescription("The location of the modelon the HDFS").withShortName("o").create();
-
- Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true).withArgument(
- abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).
- withDescription("Size of the n-gram").withShortName("ng").create();
-
- Option typeOpt = obuilder.withLongName("classifierType").withRequired(true).withArgument(
- abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).
- withDescription("Type of classifier: bayes or cbayes").withShortName("type").create();
+ Option helpOpt = DefaultOptionCreator.helpOption(obuilder);
+
+ Option inputDirOpt = obuilder
+ .withLongName("input")
+ .withRequired(true)
+ .withArgument(
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "The Directory on HDFS containing the collapsed, properly formatted files")
+ .withShortName("i").create();
+
+ Option outputOpt = obuilder.withLongName("output").withRequired(true)
+ .withArgument(
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create())
+ .withDescription("The location of the modelon the HDFS").withShortName(
+ "o").create();
+
+ Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true)
+ .withArgument(
+ abuilder.withName("gramSize").withMinimum(1).withMaximum(1)
+ .create()).withDescription(
+ "Size of the n-gram. Default Value: 1 ").withShortName("ng")
+ .create();
+
+ Option alphaOpt = obuilder.withLongName("alpha").withRequired(false)
+ .withArgument(
+ abuilder.withName("a").withMinimum(1).withMaximum(1).create())
+ .withDescription("Smoothing parameter Default Value: 1.0")
+ .withShortName("a").create();
+
+ Option typeOpt = obuilder.withLongName("classifierType").withRequired(true)
+ .withArgument(
+ abuilder.withName("classifierType").withMinimum(1).withMaximum(1)
+ .create()).withDescription(
+ "Type of classifier: bayes|cbayes. Default: bayes").withShortName(
+ "type").create();
Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(
true).withArgument(
abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create())
- .withDescription("Location of model: hdfs|hbase").withShortName(
- "source").create();
+ .withDescription("Location of model: hdfs|hbase. Default Value: hdfs")
+ .withShortName("source").create();
- Group group = gbuilder.withName("Options").withOption(gramSizeOpt).withOption(inputDirOpt).withOption(outputOpt).withOption(typeOpt).withOption(dataSourceOpt).create();
- Parser parser = new Parser();
- parser.setGroup(group);
- CommandLine cmdLine = parser.parse(args);
- String classifierType = (String) cmdLine.getValue(typeOpt);
- String dataSourceType = (String) cmdLine.getValue(dataSourceOpt);
- BayesParameters params = new BayesParameters(Integer.parseInt((String) cmdLine.getValue(gramSizeOpt)));
-
- if(dataSourceType.equals("hbase"))
- params.set("dataSource", "hbase");
- else
- params.set("dataSource", "hdfs");
-
- if (classifierType.equalsIgnoreCase("bayes")) {
- log.info("Training Bayes Classifier");
- trainNaiveBayes((String)cmdLine.getValue(inputDirOpt), (String)cmdLine.getValue(outputOpt), params);
-
- } else if (classifierType.equalsIgnoreCase("cbayes")) {
- log.info("Training Complementary Bayes Classifier");
- //setup the HDFS and copy the files there, then run the trainer
- trainCNaiveBayes((String) cmdLine.getValue(inputDirOpt), (String) cmdLine.getValue(outputOpt), params);
+ Group group = gbuilder.withName("Options").withOption(gramSizeOpt)
+ .withOption(helpOpt).withOption(inputDirOpt).withOption(outputOpt)
+ .withOption(typeOpt).withOption(dataSourceOpt).withOption(alphaOpt)
+ .create();
+ try {
+ Parser parser = new Parser();
+
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ System.exit(0);
+ }
+
+ String classifierType = (String) cmdLine.getValue(typeOpt);
+ String dataSourceType = (String) cmdLine.getValue(dataSourceOpt);
+
+ BayesParameters params = new BayesParameters(Integer
+ .parseInt((String) cmdLine.getValue(gramSizeOpt)));
+
+ String alpha_i = "1.0";
+ if (cmdLine.hasOption(alphaOpt)) {
+ alpha_i = (String) cmdLine.getValue(alphaOpt);
+ }
+
+ params.set("alpha_i", alpha_i);
+
+ if (dataSourceType.equals("hbase"))
+ params.set("dataSource", "hbase");
+ else
+ params.set("dataSource", "hdfs");
+
+ if (classifierType.equalsIgnoreCase("bayes")) {
+ log.info("Training Bayes Classifier");
+ trainNaiveBayes((String) cmdLine.getValue(inputDirOpt),
+ (String) cmdLine.getValue(outputOpt), params);
+
+ } else if (classifierType.equalsIgnoreCase("cbayes")) {
+ log.info("Training Complementary Bayes Classifier");
+ // setup the HDFS and copy the files there, then run the trainer
+ trainCNaiveBayes((String) cmdLine.getValue(inputDirOpt),
+ (String) cmdLine.getValue(outputOpt), params);
+ }
+ } catch (OptionException e) {
+ log.info("{}", e);
+ CommandLineUtil.printHelp(group);
+ System.exit(0);
}
}
}