You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/09/26 12:18:20 UTC

svn commit: r1001402 - in /mahout/trunk: conf/driver.classes.props core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java

Author: srowen
Date: Sun Sep 26 10:18:19 2010
New Revision: 1001402

URL: http://svn.apache.org/viewvc?rev=1001402&view=rev
Log:
MAHOUT-509

Modified:
    mahout/trunk/conf/driver.classes.props
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java

Modified: mahout/trunk/conf/driver.classes.props
URL: http://svn.apache.org/viewvc/mahout/trunk/conf/driver.classes.props?rev=1001402&r1=1001401&r2=1001402&view=diff
==============================================================================
--- mahout/trunk/conf/driver.classes.props (original)
+++ mahout/trunk/conf/driver.classes.props Sun Sep 26 10:18:19 2010
@@ -27,3 +27,5 @@ org.apache.mahout.cf.taste.hadoop.simila
 org.apache.mahout.classifier.sgd.TrainLogistic = trainlogistic : Train a logistic regression using stochastic gradient descent
 org.apache.mahout.classifier.sgd.RunLogistic = runlogistic : Run a logistic regression model against CSV data
 org.apache.mahout.classifier.sgd.PrintResourceOrFile = cat : Print a file or resource as the logistic regression models would see it
+org.apache.mahout.classifier.bayes.WikipediaXmlSplitter = wikipediaXMLSplitter : Reads Wikipedia data and creates chunks
+org.apache.mahout.classifier.bayes.WikipediaDatasetCreatorDriver = wikipediaDataSetCreator : Splits the Wikipedia data set by a feature such as country
\ No newline at end of file

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=1001402&r1=1001401&r2=1001402&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Sun Sep 26 10:18:19 2010
@@ -127,29 +127,39 @@ public final class TestClassifier {
         return;
       }
       
+      BayesParameters params = new BayesParameters();
+      // Setting all default values
       int gramSize = 1;
+      String classifierType = "bayes";      
+      String dataSource = "hdfs";
+      String defaultCat = "unknown";
+      String encoding = "UTF-8";
+      String alphaI = "1.0";
+      String classificationMethod = "sequential";
+
+      String modelBasePath = (String) cmdLine.getValue(pathOpt);
+      
       if (cmdLine.hasOption(gramSizeOpt)) {
         gramSize = Integer.parseInt((String) cmdLine.getValue(gramSizeOpt));
         
       }
-      BayesParameters params = new BayesParameters(gramSize);
       
-      String modelBasePath = (String) cmdLine.getValue(pathOpt);
+      if (cmdLine.hasOption(classifierType)) {
+        classifierType = (String) cmdLine.getValue(typeOpt);
+      }
       
-      String classifierType = (String) cmdLine.getValue(typeOpt);
-      String dataSource = (String) cmdLine.getValue(dataSourceOpt);
+      if (cmdLine.hasOption(dataSource)) {
+        dataSource = (String) cmdLine.getValue(dataSource);
+      }
       
-      String defaultCat = "unknown";
       if (cmdLine.hasOption(defaultCatOpt)) {
         defaultCat = (String) cmdLine.getValue(defaultCatOpt);
       }
       
-      String encoding = "UTF-8";
       if (cmdLine.hasOption(encodingOpt)) {
         encoding = (String) cmdLine.getValue(encodingOpt);
       }
       
-      String alphaI = "1.0";
       if (cmdLine.hasOption(alphaOpt)) {
         alphaI = (String) cmdLine.getValue(alphaOpt);
       }
@@ -158,11 +168,11 @@ public final class TestClassifier {
       
       String testDirPath = (String) cmdLine.getValue(dirOpt);
       
-      String classificationMethod = "sequential";
       if (cmdLine.hasOption(methodOpt)) {
         classificationMethod = (String) cmdLine.getValue(methodOpt);
       }
       
+      params.setGramSize(gramSize);
       params.set("verbose", Boolean.toString(verbose));
       params.set("basePath", modelBasePath);
       params.set("classifierType", classifierType);

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java?rev=1001402&r1=1001401&r2=1001402&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java Sun Sep 26 10:18:19 2010
@@ -76,7 +76,7 @@ public final class TrainClassifier {
       abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
       "The location of the model on the HDFS").withShortName("o").create();
     
-    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true).withArgument(
+    Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(false).withArgument(
       abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription(
       "Size of the n-gram. Default Value: 1 ").withShortName("ng").create();
     
@@ -92,11 +92,11 @@ public final class TrainClassifier {
       abuilder.withName("a").withMinimum(1).withMaximum(1).create()).withDescription(
       "Smoothing parameter Default Value: 1.0").withShortName("a").create();
     
-    Option typeOpt = obuilder.withLongName("classifierType").withRequired(true).withArgument(
+    Option typeOpt = obuilder.withLongName("classifierType").withRequired(false).withArgument(
       abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).withDescription(
       "Type of classifier: bayes|cbayes. Default: bayes").withShortName("type").create();
     
-    Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(true).withArgument(
+    Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(false).withArgument(
       abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create()).withDescription(
       "Location of model: hdfs|hbase. Default Value: hdfs").withShortName("source").create();
     
@@ -121,6 +121,12 @@ public final class TrainClassifier {
       String dataSourceType = (String) cmdLine.getValue(dataSourceOpt);
       
       BayesParameters params = new BayesParameters();
+      // Setting all the default parameter values
+      params.setGramSize(1);
+      params.setMinDF(1);
+      params.set("alpha_i","1.0");
+      params.set("dataSource", "hdfs");
+      
       if (cmdLine.hasOption(gramSizeOpt)) {
         params.setGramSize(Integer.parseInt((String) cmdLine.getValue(gramSizeOpt)));
       }
@@ -137,29 +143,23 @@ public final class TrainClassifier {
         params.setSkipCleanup(true);
       }
       
-      String alphaI = "1.0";
       if (cmdLine.hasOption(alphaOpt)) {
-        alphaI = (String) cmdLine.getValue(alphaOpt);
+        params.set("alpha_i",(String) cmdLine.getValue(alphaOpt));
       }
       
-      params.set("alpha_i", alphaI);
-      
-      if ("hbase".equals(dataSourceType)) {
-        params.set("dataSource", "hbase");
-      } else {
-        params.set("dataSource", "hdfs");
-      }
+      if (cmdLine.hasOption(dataSourceOpt)){
+        params.set("dataSource", dataSourceType);
+      } 
 
       Path inputPath = new Path((String) cmdLine.getValue(inputDirOpt));
       Path outputPath = new Path((String) cmdLine.getValue(outputOpt));
-      if ("bayes".equalsIgnoreCase(classifierType)) {
-        log.info("Training Bayes Classifier");
-        trainNaiveBayes(inputPath, outputPath, params);
-        
-      } else if ("cbayes".equalsIgnoreCase(classifierType)) {
+      if ("cbayes".equalsIgnoreCase(classifierType)) {
         log.info("Training Complementary Bayes Classifier");
-        // setup the HDFS and copy the files there, then run the trainer
         trainCNaiveBayes(inputPath, outputPath, params);
+      } else {
+        log.info("Training Bayes Classifier");
+        // setup the HDFS and copy the files there, then run the trainer
+        trainNaiveBayes(inputPath, outputPath, params);
       }
     } catch (OptionException e) {
       log.error("Error while parsing options", e);