You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/09/26 12:18:20 UTC
svn commit: r1001402 - in /mahout/trunk: conf/driver.classes.props
core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
Author: srowen
Date: Sun Sep 26 10:18:19 2010
New Revision: 1001402
URL: http://svn.apache.org/viewvc?rev=1001402&view=rev
Log:
MAHOUT-509
Modified:
mahout/trunk/conf/driver.classes.props
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
Modified: mahout/trunk/conf/driver.classes.props
URL: http://svn.apache.org/viewvc/mahout/trunk/conf/driver.classes.props?rev=1001402&r1=1001401&r2=1001402&view=diff
==============================================================================
--- mahout/trunk/conf/driver.classes.props (original)
+++ mahout/trunk/conf/driver.classes.props Sun Sep 26 10:18:19 2010
@@ -27,3 +27,5 @@ org.apache.mahout.cf.taste.hadoop.simila
org.apache.mahout.classifier.sgd.TrainLogistic = trainlogistic : Train a logistic regression using stochastic gradient descent
org.apache.mahout.classifier.sgd.RunLogistic = runlogistic : Run a logistic regression model against CSV data
org.apache.mahout.classifier.sgd.PrintResourceOrFile = cat : Print a file or resource as the logistic regression models would see it
+org.apache.mahout.classifier.bayes.WikipediaXmlSplitter = wikipediaXMLSplitter : Reads wikipedia data and creates ch
+org.apache.mahout.classifier.bayes.WikipediaDatasetCreatorDriver = wikipediaDataSetCreator : Splits data set of wikipedia wrt feature like country
\ No newline at end of file
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1001402&r1=1001401&r2=1001402&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Sun Sep 26 10:18:19 2010
@@ -127,29 +127,39 @@ public final class TestClassifier {
return;
}
+ BayesParameters params = new BayesParameters();
+ // Setting all default values
int gramSize = 1;
+ String classifierType = "bayes";
+ String dataSource = "hdfs";
+ String defaultCat = "unknown";
+ String encoding = "UTF-8";
+ String alphaI = "1.0";
+ String classificationMethod = "sequential";
+
+ String modelBasePath = (String) cmdLine.getValue(pathOpt);
+
if (cmdLine.hasOption(gramSizeOpt)) {
gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
}
- BayesParameters params = new BayesParameters(gramSize);
- String modelBasePath = (String) cmdLine.getValue(pathOpt);
+ if (cmdLine.hasOption(typeOpt)) { // query the Option itself; classifierType only holds the default value "bayes"
+ classifierType = (String) cmdLine.getValue(typeOpt);
+ }
- String classifierType = (String) cmdLine.getValue(typeOpt);
- String dataSource = (String) cmdLine.getValue(dataSourceOpt);
+ if (cmdLine.hasOption(dataSourceOpt)) { // query the Option itself; dataSource only holds the default value "hdfs"
+ dataSource = (String) cmdLine.getValue(dataSourceOpt);
+ }
- String defaultCat = "unknown";
if (cmdLine.hasOption(defaultCatOpt)) {
defaultCat = (String) cmdLine.getValue(defaultCatOpt);
}
- String encoding = "UTF-8";
if (cmdLine.hasOption(encodingOpt)) {
encoding = (String) cmdLine.getValue(encodingOpt);
}
- String alphaI = "1.0";
if (cmdLine.hasOption(alphaOpt)) {
alphaI = (String) cmdLine.getValue(alphaOpt);
}
@@ -158,11 +168,11 @@ public final class TestClassifier {
String testDirPath = (String) cmdLine.getValue(dirOpt);
- String classificationMethod = "sequential";
if (cmdLine.hasOption(methodOpt)) {
classificationMethod = (String) cmdLine.getValue(methodOpt);
}
+ params.setGramSize(gramSize);
params.set("verbose", Boolean.toString(verbose));
params.set("basePath", modelBasePath);
params.set("classifierType", classifierType);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java?rev=1001402&r1=1001401&r2=1001402&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java Sun Sep 26 10:18:19 2010
@@ -76,7 +76,7 @@ public final class TrainClassifier {
abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
"The location of the model on the HDFS").withShortName("o").create();
- Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true).withArgument(
+ Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false).withArgument(
abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription(
"Size of the n-gram. Default Value: 1 ").withShortName("ng").create();
@@ -92,11 +92,11 @@ public final class TrainClassifier {
abuilder.withName("a").withMinimum(1).withMaximum(1).create()).withDescription(
"Smoothing parameter Default Value: 1.0").withShortName("a").create();
- Option typeOpt = obuilder.withLongName("classifierType").withRequired(true).withArgument(
+ Option typeOpt = obuilder.withLongName("classifierType").withRequired(false).withArgument(
abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).withDescription(
"Type of classifier: bayes|cbayes. Default: bayes").withShortName("type").create();
- Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(true).withArgument(
+ Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(false).withArgument(
abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create()).withDescription(
"Location of model: hdfs|hbase. Default Value: hdfs").withShortName("source").create();
@@ -121,6 +121,12 @@ public final class TrainClassifier {
String dataSourceType = (String) cmdLine.getValue(dataSourceOpt);
BayesParameters params = new BayesParameters();
+ // Setting all the default parameter values
+ params.setGramSize(1);
+ params.setMinDF(1);
+ params.set("alpha_i","1.0");
+ params.set("dataSource", "hdfs");
+
if (cmdLine.hasOption(gramSizeOpt)) {
params.setGramSize(Integer.parseInt((String) cmdLine.getValue(gramSizeOpt)));
}
@@ -137,29 +143,23 @@ public final class TrainClassifier {
params.setSkipCleanup(true);
}
- String alphaI = "1.0";
if (cmdLine.hasOption(alphaOpt)) {
- alphaI = (String) cmdLine.getValue(alphaOpt);
+ params.set("alpha_i",(String) cmdLine.getValue(alphaOpt));
}
- params.set("alpha_i", alphaI);
-
- if ("hbase".equals(dataSourceType)) {
- params.set("dataSource", "hbase");
- } else {
- params.set("dataSource", "hdfs");
- }
+ if (cmdLine.hasOption(dataSourceOpt)){
+ params.set("dataSource", dataSourceType);
+ }
Path inputPath = new Path((String) cmdLine.getValue(inputDirOpt));
Path outputPath = new Path((String) cmdLine.getValue(outputOpt));
- if ("bayes".equalsIgnoreCase(classifierType)) {
- log.info("Training Bayes Classifier");
- trainNaiveBayes(inputPath, outputPath, params);
-
- } else if ("cbayes".equalsIgnoreCase(classifierType)) {
+ if ("cbayes".equalsIgnoreCase(classifierType)) {
log.info("Training Complementary Bayes Classifier");
- // setup the HDFS and copy the files there, then run the trainer
trainCNaiveBayes(inputPath, outputPath, params);
+ } else {
+ log.info("Training Bayes Classifier");
+ // setup the HDFS and copy the files there, then run the trainer
+ trainNaiveBayes(inputPath, outputPath, params);
}
} catch (OptionException e) {
log.error("Error while parsing options", e);