You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2012/05/15 18:08:53 UTC

svn commit: r1338770 - in /mahout/trunk: core/src/main/java/org/apache/mahout/classifier/naivebayes/ core/src/main/java/org/apache/mahout/classifier/naivebayes/test/ core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ core/src/test/jav...

Author: robinanil
Date: Tue May 15 16:08:52 2012
New Revision: 1338770

URL: http://svn.apache.org/viewvc?rev=1338770&view=rev
Log:
MAHOUT-1014 Roll back the multilabel change; fixes some bugs

Removed:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
    mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
    mahout/trunk/examples/bin/classify-20newsgroups.sh

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java Tue May 15 16:08:52 2012
@@ -42,7 +42,8 @@ public abstract class AbstractNaiveBayes
     double result = 0.0;
     Iterator<Element> elements = instance.iterateNonZero();
     while (elements.hasNext()) {
-      result += getScoreForLabelFeature(label, elements.next().index());
+      Element e = elements.next();
+      result += e.get() * getScoreForLabelFeature(label, e.index());
     }
     return result / model.thetaNormalizer(label);
   }
@@ -53,17 +54,29 @@ public abstract class AbstractNaiveBayes
   }
 
   @Override
-  public Vector classify(Vector instance) {
+  public Vector classifyFull(Vector instance) {
+    System.out.println(1);
     Vector score = model.createScoringVector();
+    System.out.println(score.size());
     for (int label = 0; label < model.numLabels(); label++) {
       score.set(label, getScoreForLabelInstance(label, instance));
     }
     return score;
   }
+  
+  @Override
+  public Vector classifyFull(Vector r, Vector instance) {
+    r = classifyFull(instance);
+    return r;
+  }
 
   @Override
   public double classifyScalar(Vector instance) {
     throw new UnsupportedOperationException("Not supported in Naive Bayes");
   }
   
+  @Override
+  public Vector classify(Vector instance) {
+    throw new UnsupportedOperationException("probabilites not supported in Naive Bayes");
+  }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java Tue May 15 16:08:52 2012
@@ -36,4 +36,5 @@ public class ComplementaryNaiveBayesClas
 
     return Math.log(numerator / denominator);
   }
+
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java Tue May 15 16:08:52 2012
@@ -17,19 +17,21 @@
 
 package org.apache.mahout.classifier.naivebayes;
 
-import com.google.common.base.Preconditions;
-import com.google.common.io.Closeables;
+import java.io.IOException;
+
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FSDataInputStream;
 import org.apache.hadoop.fs.FSDataOutputStream;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.SparseRowMatrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
-import java.io.IOException;
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
 
 /** NaiveBayesModel holds the weight Matrix, the feature and label sums and the weight normalizer vectors.*/
 public class NaiveBayesModel {
@@ -105,10 +107,10 @@ public class NaiveBayesModel {
     try {
       alphaI = in.readFloat();
       weightsPerFeature = VectorWritable.readVector(in);
-      weightsPerLabel = VectorWritable.readVector(in);
-      perLabelThetaNormalizer = VectorWritable.readVector(in);
+      weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
+      perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
 
-      weightsPerLabelAndFeature = new SparseMatrix(weightsPerLabel.size(), weightsPerFeature.size() );
+      weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size() );
       for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
         weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
       }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java Tue May 15 16:08:52 2012
@@ -28,10 +28,10 @@ public class StandardNaiveBayesClassifie
   @Override
   public double getScoreForLabelFeature(int label, int feature) {
     NaiveBayesModel model = getModel();
-
     double numerator = model.weight(label, feature) + model.alphaI();
     double denominator = model.labelWeight(label) + model.alphaI() * model.numFeatures();
 
     return -Math.log(numerator / denominator);
   }
+  
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java Tue May 15 16:08:52 2012
@@ -42,6 +42,8 @@ public class BayesTestMapper extends Map
 
   @Override
   protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    System.out.println("Setup");
     Configuration conf = context.getConfiguration();
     Path modelPath = HadoopUtil.cachedFile(conf);
     NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
@@ -55,9 +57,8 @@ public class BayesTestMapper extends Map
 
   @Override
   protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
-    Vector result = classifier.classify(value.get());
+    Vector result = classifier.classifyFull(value.get());
     //the key is the expected value
     context.write(key, new VectorWritable(result));
-
   }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java Tue May 15 16:08:52 2012
@@ -17,8 +17,15 @@
 
 package org.apache.mahout.classifier.naivebayes.test;
 
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.Reader;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapreduce.Job;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
@@ -26,7 +33,11 @@ import org.apache.hadoop.mapreduce.lib.o
 import org.apache.hadoop.util.ToolRunner;
 import org.apache.mahout.classifier.ClassifierResult;
 import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
 import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
 import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.Pair;
@@ -39,9 +50,6 @@ import org.apache.mahout.math.VectorWrit
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.List;
-import java.util.Map;
-
 /**
  * Test the (Complementary) Naive Bayes model that was built during training
  * by running the iterating the test set and comparing it to the model
@@ -64,6 +72,7 @@ public class TestNaiveBayesDriver extend
     addOption(addOption(DefaultOptionCreator.overwriteOption().create()));
     addOption("model", "m", "The path to the model built during training", true);
     addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
+    addOption(buildOption("runSequential", "seq", "run sequential?", true, false, String.valueOf(false)));
     addOption("labelIndex", "l", "The path to the location of the label index", true);
     Map<String, List<String>> parsedArgs = parseArguments(args);
     if (parsedArgs == null) {
@@ -72,18 +81,35 @@ public class TestNaiveBayesDriver extend
     if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
       HadoopUtil.delete(getConf(), getOutputPath());
     }
-    Path model = new Path(getOption("model"));
-    HadoopUtil.cacheFiles(model, getConf());
-    //the output key is the expected value, the output value are the scores for all the labels
-    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
-            Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
-    //testJob.getConfiguration().set(LABEL_KEY, getOption("--labels"));
+    
     boolean complementary = parsedArgs.containsKey("testComplementary");
-    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
-    boolean succeeded = testJob.waitForCompletion(true);
-    if (!succeeded) {
-      return -1;
+    boolean sequential = Boolean.parseBoolean(getOption("runSequential"));
+    if (sequential) {
+      FileSystem fs = FileSystem.get(getConf());
+      NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+      AbstractNaiveBayesClassifier classifier;
+      if (complementary) {
+        classifier = new ComplementaryNaiveBayesClassifier(model);
+      } else {
+        classifier = new StandardNaiveBayesClassifier(model);
+      }
+      SequenceFile.Writer writer =
+          new SequenceFile.Writer(fs, getConf(), getOutputPath(), Text.class, VectorWritable.class);
+      SequenceFile.Reader reader = new Reader(fs, getInputPath(), getConf());
+      Text key = new Text();
+      VectorWritable vw = new VectorWritable();
+      while (reader.next(key, vw)) {
+        writer.append(key, new VectorWritable(classifier.classifyFull(vw.get())));
+      }
+      writer.close();
+      reader.close();
+    } else {
+      boolean succeeded = runMapReduce(parsedArgs);
+      if (!succeeded) {
+        return -1;
+      }
     }
+    
     //load the labels
     Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));
 
@@ -100,6 +126,20 @@ public class TestNaiveBayesDriver extend
     return 0;
   }
 
+  private boolean runMapReduce(Map<String, List<String>> parsedArgs) throws IOException,
+      InterruptedException, ClassNotFoundException {
+    Path model = new Path(getOption("model"));
+    HadoopUtil.cacheFiles(model, getConf());
+    //the output key is the expected value, the output value are the scores for all the labels
+    Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
+            Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    //testJob.getConfiguration().set(LABEL_KEY, getOption("--labels"));
+    boolean complementary = parsedArgs.containsKey("testComplementary");
+    testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
+    boolean succeeded = testJob.waitForCompletion(true);
+    return succeeded;
+  }
+
   private static void analyzeResults(Map<Integer, String> labelMap,
                                      SequenceFileDirIterable<Text, VectorWritable> dirIterable,
                                      ResultAnalyzer analyzer) {

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java Tue May 15 16:08:52 2012
@@ -1,16 +1,18 @@
 /**
- * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
- * agreements. See the NOTICE file distributed with this work for additional information regarding
- * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance with the License. You may obtain a
- * copy of the License at
- * 
- * http://www.apache.org/licenses/LICENSE-2.0
- * 
- * Unless required by applicable law or agreed to in writing, software distributed under the License
- * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
- * or implied. See the License for the specific language governing permissions and limitations under
- * the License.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
  */
 
 package org.apache.mahout.classifier.naivebayes.training;
@@ -18,16 +20,31 @@ package org.apache.mahout.classifier.nai
 import java.io.IOException;
 
 import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.MultiLabelVectorWritable;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+  public enum Counter { SKIPPED_INSTANCES }
+
+  private OpenObjectIntHashMap<String> labelIndex;
+
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
+    labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
+  }
 
-public class IndexInstancesMapper
-    extends Mapper<IntWritable, MultiLabelVectorWritable, IntWritable, VectorWritable> {
   @Override
-  protected void map(IntWritable key, MultiLabelVectorWritable instance, Context ctx)
-      throws IOException, InterruptedException {
-    VectorWritable vw = new VectorWritable(instance.getVector());
-    ctx.write(new IntWritable(instance.getLabels()[0]), vw);
+  protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+    String label = labelText.toString();
+    if (labelIndex.containsKey(label)) {
+      ctx.write(new IntWritable(labelIndex.get(label)), instance);
+    } else {
+      ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+    }
   }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java Tue May 15 16:08:52 2012
@@ -17,6 +17,8 @@
 
  package org.apache.mahout.classifier.naivebayes.training;
 
+import java.util.Iterator;
+
 import org.apache.mahout.math.Vector;
 
 public class StandardThetaTrainer extends AbstractThetaTrainer {
@@ -27,7 +29,14 @@ public class StandardThetaTrainer extend
 
   @Override
   public void train(int label, Vector instance) {
-    double weight = Math.log((instance.zSum() + alphaI()) / (labelWeight(label) + alphaI() * numFeatures()));
-    updatePerLabelThetaNormalizer(label, weight);
+    double sigmaK = labelWeight(label);
+    Iterator<Vector.Element> it = instance.iterateNonZero();
+    while (it.hasNext()) {
+      Vector.Element e = it.next();
+      double numerator = e.get() + alphaI();
+      double denominator = sigmaK + alphaI() * numFeatures();
+      double weight = Math.log(numerator / denominator);
+      updatePerLabelThetaNormalizer(label, weight);
+    }
   }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java Tue May 15 16:08:52 2012
@@ -37,6 +37,7 @@ public class ThetaMapper extends Mapper<
 
   @Override
   protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
     Configuration conf = ctx.getConfiguration();
 
     float alphaI = conf.getFloat(ALPHA_I, 1.0f);
@@ -60,5 +61,6 @@ public class ThetaMapper extends Mapper<
   protected void cleanup(Context ctx) throws IOException, InterruptedException {
     ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER),
         new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
+    super.cleanup(ctx);
   }
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java Tue May 15 16:08:52 2012
@@ -17,10 +17,9 @@
 
 package org.apache.mahout.classifier.naivebayes.training;
 
-import java.util.List;
-import java.util.Map;
-
+import com.google.common.base.Splitter;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapreduce.Job;
@@ -32,9 +31,16 @@ import org.apache.mahout.classifier.naiv
 import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.HadoopUtil;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
 import org.apache.mahout.common.mapreduce.VectorSumReducer;
 import org.apache.mahout.math.VectorWritable;
 
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
 /**
  * This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes)
  */
@@ -54,11 +60,15 @@ public final class TrainNaiveBayesJob ex
 
   @Override
   public int run(String[] args) throws Exception {
+
     addInputOption();
     addOutputOption();
-    addOption("labelSize", "ls", "Number of labels in the input data", String.valueOf(2));
+    addOption("labels", "l", "comma-separated list of labels to include in training", false);
+
+    addOption(buildOption("extractLabels", "el", "Extract the labels from the input", false, false, ""));
     addOption("alphaI", "a", "smoothing parameter", String.valueOf(1.0f));
     addOption(buildOption("trainComplementary", "c", "train complementary?", false, false, String.valueOf(false)));
+    addOption("labelIndex", "li", "The path to store the label index in", false);
     addOption(DefaultOptionCreator.overwriteOption().create());
     Map<String, List<String>> parsedArgs = parseArguments(args);
     if (parsedArgs == null) {
@@ -68,12 +78,21 @@ public final class TrainNaiveBayesJob ex
       HadoopUtil.delete(getConf(), getOutputPath());
       HadoopUtil.delete(getConf(), getTempPath());
     }
-    int labelSize = Integer.parseInt(getOption("labelSize"));
+    Path labPath;
+    String labPathStr = getOption("labelIndex");
+    if (labPathStr != null) {
+      labPath = new Path(labPathStr);
+    } else {
+      labPath = getTempPath("labelIndex");
+    }
+    long labelSize = createLabelIndex(labPath);
     float alphaI = Float.parseFloat(getOption("alphaI"));
     boolean trainComplementary = Boolean.parseBoolean(getOption("trainComplementary"));
 
+
     HadoopUtil.setSerializations(getConf());
-    
+    HadoopUtil.cacheFiles(labPath, getConf());
+
     //add up all the vectors with the same labels, while mapping the labels into our index
     Job indexInstances = prepareJob(getInputPath(), getTempPath(SUMMED_OBSERVATIONS), SequenceFileInputFormat.class,
             IndexInstancesMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class, IntWritable.class,
@@ -113,4 +132,18 @@ public final class TrainNaiveBayesJob ex
 
     return 0;
   }
+
+  private long createLabelIndex(Path labPath) throws IOException {
+    long labelSize = 0;
+    if (hasOption("labels")) {
+      Iterable<String> labels = Splitter.on(",").split(getOption("labels"));
+      labelSize = BayesUtils.writeLabelIndex(getConf(), labels, labPath);
+    } else if (hasOption("extractLabels")) {
+      SequenceFileDirIterable<Text, IntWritable> iterable =
+              new SequenceFileDirIterable<Text, IntWritable>(getInputPath(), PathType.LIST, PathFilters.logsCRCFilter(), getConf());
+      labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
+    }
+    return labelSize;
+  }
+
 }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java Tue May 15 16:08:52 2012
@@ -37,6 +37,7 @@ public class WeightsMapper extends Mappe
 
   @Override
   protected void setup(Context ctx) throws IOException, InterruptedException {
+    super.setup(ctx);
     int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS));
     Preconditions.checkArgument(numLabels > 0);
     weightsPerLabel = new RandomAccessSparseVector(numLabels);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java Tue May 15 16:08:52 2012
@@ -37,10 +37,10 @@ public final class ComplementaryNaiveBay
   @Test
   public void testNaiveBayes() throws Exception {
     assertEquals(4, classifier.numCategories());
-    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
-    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
-    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
-    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
+    assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+    assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+    assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+    assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
     
   }
   

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java Tue May 15 16:08:52 2012
@@ -17,24 +17,23 @@
 
 package org.apache.mahout.classifier.naivebayes;
 
-import java.io.File;
-
+import com.google.common.io.Closeables;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
 import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.MultiLabelVectorWritable;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.hadoop.MathHelper;
 import org.junit.Before;
 import org.junit.Test;
 
-import com.google.common.io.Closeables;
+import java.io.File;
 
 public class NaiveBayesTest extends MahoutTestCase {
 
@@ -43,8 +42,8 @@ public class NaiveBayesTest extends Maho
   private File outputDir;
   private File tempDir;
 
-  static final String LABEL_STOLEN = "stolen";
-  static final String LABEL_NOT_STOLEN = "not_stolen";
+  static final Text LABEL_STOLEN = new Text("stolen");
+  static final Text LABEL_NOT_STOLEN = new Text("not_stolen");
 
   static final Vector.Element COLOR_RED = MathHelper.elem(0, 1);
   static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
@@ -67,19 +66,19 @@ public class NaiveBayesTest extends Maho
     tempDir = getTestTempDir("tmp");
 
     SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf), conf,
-        new Path(inputFile.getAbsolutePath()), IntWritable.class, MultiLabelVectorWritable.class);
+        new Path(inputFile.getAbsolutePath()), Text.class, VectorWritable.class);
 
     try {
-      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
-      writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
+      writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
+      writer.append(LABEL_STOLEN,      trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
     } finally {
       Closeables.closeQuietly(writer);
     }
@@ -90,7 +89,7 @@ public class NaiveBayesTest extends Maho
     TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
     trainNaiveBayes.setConf(conf);
     trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
-        "--labelSize", "2", "--tempDir", tempDir.getAbsolutePath() });
+        "--labels", "stolen,not_stolen", "--tempDir", tempDir.getAbsolutePath() });
 
     NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);
 
@@ -98,7 +97,7 @@ public class NaiveBayesTest extends Maho
 
     assertEquals(2, classifier.numCategories());
 
-    Vector prediction = classifier.classify(trainingInstance("", COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).getVector());
+    Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
 
     // should be classified as not stolen
     assertTrue(prediction.get(0) < prediction.get(1));
@@ -109,7 +108,7 @@ public class NaiveBayesTest extends Maho
     TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
     trainNaiveBayes.setConf(conf);
     trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
-        "--labelSize", "2", "--trainComplementary",
+        "--labels", "stolen,not_stolen", "--trainComplementary",
         "--tempDir", tempDir.getAbsolutePath() });
 
     NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);
@@ -118,18 +117,18 @@ public class NaiveBayesTest extends Maho
 
     assertEquals(2, classifier.numCategories());
 
-    Vector prediction = classifier.classify(trainingInstance("", COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).getVector());
+    Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
 
     // should be classified as not stolen
     assertTrue(prediction.get(0) < prediction.get(1));
   }
 
-  static MultiLabelVectorWritable trainingInstance(String label, Vector.Element... elems) {
+  static VectorWritable trainingInstance(Vector.Element... elems) {
     DenseVector trainingInstance = new DenseVector(6);
     for (Vector.Element elem : elems) {
       trainingInstance.set(elem.index(), elem.get());
     }
-    return new MultiLabelVectorWritable(trainingInstance, new int[] {label.equals("stolen") ? 0 : 1});
+    return new VectorWritable(trainingInstance);
   }
 
 

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java Tue May 15 16:08:52 2012
@@ -38,10 +38,10 @@ public final class StandardNaiveBayesCla
   @Test
   public void testNaiveBayes() throws Exception {
     assertEquals(4, classifier.numCategories());
-    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
-    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
-    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
-    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
+    assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+    assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+    assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+    assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
     
   }
   

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java Tue May 15 16:08:52 2012
@@ -18,20 +18,22 @@
 package org.apache.mahout.classifier.naivebayes.training;
 
 import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Counter;
 import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.MultiLabelVectorWritable;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
 import org.easymock.EasyMock;
 import org.junit.Before;
 import org.junit.Test;
 
-@SuppressWarnings("unchecked")
 public class IndexInstancesMapperTest extends MahoutTestCase {
-  private static final DenseVector VECTOR = new DenseVector(new double[] { 1, 0, 1, 1, 0 });
+
   private Mapper.Context ctx;
-  private MultiLabelVectorWritable instance;
+  private OpenObjectIntHashMap<String> labelIndex;
+  private VectorWritable instance;
 
   @Override
   @Before
@@ -39,16 +41,45 @@ public class IndexInstancesMapperTest ex
     super.setUp();
 
     ctx = EasyMock.createMock(Mapper.Context.class);
-    instance = new MultiLabelVectorWritable(VECTOR,
-      new int[] {0});
+    instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 }));
+
+    labelIndex = new OpenObjectIntHashMap<String>();
+    labelIndex.put("bird", 0);
+    labelIndex.put("cat", 1);
   }
-  
+
+
   @Test
   public void index() throws Exception {
-    ctx.write(new IntWritable(0), new VectorWritable(VECTOR));
+
+    ctx.write(new IntWritable(0), instance);
+
     EasyMock.replay(ctx);
+
     IndexInstancesMapper indexInstances = new IndexInstancesMapper();
-    indexInstances.map(new IntWritable(-1), instance, ctx);
+    setField(indexInstances, "labelIndex", labelIndex);
+
+    indexInstances.map(new Text("bird"), instance, ctx);
+
     EasyMock.verify(ctx);
   }
+
+  @Test
+  public void skip() throws Exception {
+
+    Counter skippedInstances = EasyMock.createMock(Counter.class);
+
+    EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances);
+    skippedInstances.increment(1);
+
+    EasyMock.replay(ctx, skippedInstances);
+
+    IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+    setField(indexInstances, "labelIndex", labelIndex);
+
+    indexInstances.map(new Text("fish"), instance, ctx);
+
+    EasyMock.verify(ctx, skippedInstances);
+  }
+
 }

Modified: mahout/trunk/examples/bin/classify-20newsgroups.sh
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/bin/classify-20newsgroups.sh?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/examples/bin/classify-20newsgroups.sh (original)
+++ mahout/trunk/examples/bin/classify-20newsgroups.sh Tue May 15 16:08:52 2012
@@ -23,7 +23,7 @@
 #  examples/bin/build-20news.sh
 
 if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
-  echo "This script runs SGD and Bayes classifiers over the classic 20 News Groups."
+  echo "This script runs the SGD classifier over the classic 20 News Groups."
   exit
 fi
 
@@ -34,14 +34,13 @@ fi
 START_PATH=`pwd`
 
 WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( naivebayes sgd clean)
+algorithm=( sgd clean)
 if [ -n "$1" ]; then
   choice=$1
 else
   echo "Please select a number to choose the corresponding task to run"
   echo "1. ${algorithm[0]}"
-  echo "2. ${algorithm[1]}"
-  echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
+  echo "2. ${algorithm[1]} -- cleans up the work area in $WORK_DIR"
   read -p "Enter your choice : " choice
 fi
 
@@ -68,15 +67,7 @@ cd ../..
 
 set -e
 
-if [ "x$alg" == "xnaivebayes" ]; then
-  if [ ! -e "/tmp/news-group.model" ]; then
-    echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
-    ./bin/mahout org.apache.mahout.classifier.naivebayes.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/ 0 \
-	--input /tmp/news-group-train/ --output ${WORK_DIR}/news-group.naivebayes.model -ls 20 --tempDir ${WORK_DIR}/tmp/ -ow
-  fi
-  echo "Testing on ${WORK_DIR}/20news-bydate/20news-bydate-test/ with model: /tmp/news-group.model"
-  # ./bin/mahout org.apache.mahout.classifier.sgd.TestNewsGroups --input ${WORK_DIR}/20news-bydate/20news-bydate-test/ --model /tmp/news-group.model
-elif [ "x$alg" == "xsgd" ]; then
+if [ "x$alg" == "xsgd" ]; then
   if [ ! -e "/tmp/news-group.model" ]; then
     echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
     ./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/