You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@mahout.apache.org by ro...@apache.org on 2012/05/15 18:08:53 UTC
svn commit: r1338770 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/naivebayes/
core/src/main/java/org/apache/mahout/classifier/naivebayes/test/
core/src/main/java/org/apache/mahout/classifier/naivebayes/training/
core/src/test/java/... (subject line truncated in the archive; the complete list of modified files follows below)
Author: robinanil
Date: Tue May 15 16:08:52 2012
New Revision: 1338770
URL: http://svn.apache.org/viewvc?rev=1338770&view=rev
Log:
MAHOUT-1014 Rollback the multilabel change, fixes some bugs
Removed:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/naivebayes/
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
mahout/trunk/examples/bin/classify-20newsgroups.sh
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note (amended): the debug System.out.println calls added to classifyFull(Vector) are dropped.
classifyFull(Vector r, Vector instance) now copies the label scores into r via r.assign(...); the
committed version reassigned the parameter, so the caller's vector was never filled. This assumes r
has one entry per label (Vector.assign expects matching sizes; confirm against callers).
The "probabilites" typo in the classify() exception message is fixed.
The second hunk's new-side length is reduced from 29 to 27 lines accordingly.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java Tue May 15 16:08:52 2012
@@ -42,7 +42,8 @@ public abstract class AbstractNaiveBayes
double result = 0.0;
Iterator<Element> elements = instance.iterateNonZero();
while (elements.hasNext()) {
- result += getScoreForLabelFeature(label, elements.next().index());
+ Element e = elements.next();
+ result += e.get() * getScoreForLabelFeature(label, e.index());
}
return result / model.thetaNormalizer(label);
}
@@ -53,17 +54,27 @@ public abstract class AbstractNaiveBayes
}
@Override
- public Vector classify(Vector instance) {
+ public Vector classifyFull(Vector instance) {
Vector score = model.createScoringVector();
for (int label = 0; label < model.numLabels(); label++) {
score.set(label, getScoreForLabelInstance(label, instance));
}
return score;
}
+
+ @Override
+ public Vector classifyFull(Vector r, Vector instance) {
+ r.assign(classifyFull(instance));
+ return r;
+ }
@Override
public double classifyScalar(Vector instance) {
throw new UnsupportedOperationException("Not supported in Naive Bayes");
}
+ @Override
+ public Vector classify(Vector instance) {
+ throw new UnsupportedOperationException("probabilities not supported in Naive Bayes");
+ }
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: whitespace-only change (one blank line added before the class's closing brace); no functional effect.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java Tue May 15 16:08:52 2012
@@ -36,4 +36,5 @@ public class ComplementaryNaiveBayesClas
return Math.log(numerator / denominator);
}
+
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: imports are regrouped (java.io first, Guava last), with DenseVector added and SparseMatrix replaced by SparseRowMatrix.
In the model reader, weightsPerLabel and perLabelThetaNormalizer are now copied into new DenseVector instances after being read,
and weightsPerLabelAndFeature is allocated as a SparseRowMatrix before its rows are assigned.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java Tue May 15 16:08:52 2012
@@ -17,19 +17,21 @@
package org.apache.mahout.classifier.naivebayes;
-import com.google.common.base.Preconditions;
-import com.google.common.io.Closeables;
+import java.io.IOException;
+
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FSDataOutputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
-import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.SparseRowMatrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import java.io.IOException;
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
/** NaiveBayesModel holds the weight Matrix, the feature and label sums and the weight normalizer vectors.*/
public class NaiveBayesModel {
@@ -105,10 +107,10 @@ public class NaiveBayesModel {
try {
alphaI = in.readFloat();
weightsPerFeature = VectorWritable.readVector(in);
- weightsPerLabel = VectorWritable.readVector(in);
- perLabelThetaNormalizer = VectorWritable.readVector(in);
+ weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
+ perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
- weightsPerLabelAndFeature = new SparseMatrix(weightsPerLabel.size(), weightsPerFeature.size() );
+ weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size() );
for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: whitespace-only change (a blank line is removed inside getScoreForLabelFeature and one is added before the class's closing brace); no functional effect.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java Tue May 15 16:08:52 2012
@@ -28,10 +28,10 @@ public class StandardNaiveBayesClassifie
@Override
public double getScoreForLabelFeature(int label, int feature) {
NaiveBayesModel model = getModel();
-
double numerator = model.weight(label, feature) + model.alphaI();
double denominator = model.labelWeight(label) + model.alphaI() * model.numFeatures();
return -Math.log(numerator / denominator);
}
+
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note (amended): the debug System.out.println("Setup") added to setup() is dropped; it writes to the task's
stdout on every mapper start. Use a logger if setup tracing is needed. Hunk headers are adjusted for the removed line.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java Tue May 15 16:08:52 2012
@@ -42,6 +42,7 @@ public class BayesTestMapper extends Map
@Override
protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
Configuration conf = context.getConfiguration();
Path modelPath = HadoopUtil.cachedFile(conf);
NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
@@ -55,9 +56,8 @@ public class BayesTestMapper extends Map
@Override
protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
- Vector result = classifier.classify(value.get());
+ Vector result = classifier.classifyFull(value.get());
//the key is the expected value
context.write(key, new VectorWritable(result));
-
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note (amended): in the new --runSequential path, the SequenceFile writer and reader were closed only on the
success path, so an exception while reading, classifying or appending leaked both handles. They are now closed in
nested finally blocks (the reader is closed first, and the writer is closed even if opening the reader fails).
The second hunk grows by 6 lines (35 -> 41), so the third hunk's new-side start moves from 126 to 132.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java Tue May 15 16:08:52 2012
@@ -17,8 +17,15 @@
package org.apache.mahout.classifier.naivebayes.test;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.Reader;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
@@ -26,7 +33,11 @@ import org.apache.hadoop.mapreduce.lib.o
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.classifier.ClassifierResult;
import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.Pair;
@@ -39,9 +50,6 @@ import org.apache.mahout.math.VectorWrit
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.List;
-import java.util.Map;
-
/**
* Test the (Complementary) Naive Bayes model that was built during training
* by running the iterating the test set and comparing it to the model
@@ -64,6 +72,7 @@ public class TestNaiveBayesDriver extend
addOption(addOption(DefaultOptionCreator.overwriteOption().create()));
addOption("model", "m", "The path to the model built during training", true);
addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
+ addOption(buildOption("runSequential", "seq", "run sequential?", true, false, String.valueOf(false)));
addOption("labelIndex", "l", "The path to the location of the label index", true);
Map<String, List<String>> parsedArgs = parseArguments(args);
if (parsedArgs == null) {
@@ -72,18 +81,41 @@ public class TestNaiveBayesDriver extend
if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
HadoopUtil.delete(getConf(), getOutputPath());
}
- Path model = new Path(getOption("model"));
- HadoopUtil.cacheFiles(model, getConf());
- //the output key is the expected value, the output value are the scores for all the labels
- Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
- Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
- //testJob.getConfiguration().set(LABEL_KEY, getOption("--labels"));
+
boolean complementary = parsedArgs.containsKey("testComplementary");
- testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
- boolean succeeded = testJob.waitForCompletion(true);
- if (!succeeded) {
- return -1;
+ boolean sequential = Boolean.parseBoolean(getOption("runSequential"));
+ if (sequential) {
+ FileSystem fs = FileSystem.get(getConf());
+ NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+ AbstractNaiveBayesClassifier classifier;
+ if (complementary) {
+ classifier = new ComplementaryNaiveBayesClassifier(model);
+ } else {
+ classifier = new StandardNaiveBayesClassifier(model);
+ }
+ SequenceFile.Writer writer =
+ new SequenceFile.Writer(fs, getConf(), getOutputPath(), Text.class, VectorWritable.class);
+ try {
+ SequenceFile.Reader reader = new Reader(fs, getInputPath(), getConf());
+ try {
+ Text key = new Text();
+ VectorWritable vw = new VectorWritable();
+ while (reader.next(key, vw)) {
+ writer.append(key, new VectorWritable(classifier.classifyFull(vw.get())));
+ }
+ } finally {
+ reader.close();
+ }
+ } finally {
+ writer.close();
+ }
+ } else {
+ boolean succeeded = runMapReduce(parsedArgs);
+ if (!succeeded) {
+ return -1;
+ }
}
+
//load the labels
Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));
@@ -100,6 +132,20 @@ public class TestNaiveBayesDriver extend
return 0;
}
+ private boolean runMapReduce(Map<String, List<String>> parsedArgs) throws IOException,
+ InterruptedException, ClassNotFoundException {
+ Path model = new Path(getOption("model"));
+ HadoopUtil.cacheFiles(model, getConf());
+ //the output key is the expected value, the output value are the scores for all the labels
+ Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
+ Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+ //testJob.getConfiguration().set(LABEL_KEY, getOption("--labels"));
+ boolean complementary = parsedArgs.containsKey("testComplementary");
+ testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
+ boolean succeeded = testJob.waitForCompletion(true);
+ return succeeded;
+ }
+
private static void analyzeResults(Map<Integer, String> labelMap,
SequenceFileDirIterable<Text, VectorWritable> dirIterable,
ResultAnalyzer analyzer) {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: the license header is only re-wrapped. The mapper now consumes Text label keys with VectorWritable values
(instead of IntWritable / MultiLabelVectorWritable records). It loads the label index in setup() via
BayesUtils.readIndexFromCache, emits the label's integer id with the unchanged instance, and increments
Counter.SKIPPED_INSTANCES, without emitting, for labels that are not in the index.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java Tue May 15 16:08:52 2012
@@ -1,16 +1,18 @@
/**
- * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
- * agreements. See the NOTICE file distributed with this work for additional information regarding
- * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance with the License. You may obtain a
- * copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software distributed under the License
- * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
- * or implied. See the License for the specific language governing permissions and limitations under
- * the License.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*/
package org.apache.mahout.classifier.naivebayes.training;
@@ -18,16 +20,31 @@ package org.apache.mahout.classifier.nai
import java.io.IOException;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.MultiLabelVectorWritable;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+ public enum Counter { SKIPPED_INSTANCES }
+
+ private OpenObjectIntHashMap<String> labelIndex;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
+ }
-public class IndexInstancesMapper
- extends Mapper<IntWritable, MultiLabelVectorWritable, IntWritable, VectorWritable> {
@Override
- protected void map(IntWritable key, MultiLabelVectorWritable instance, Context ctx)
- throws IOException, InterruptedException {
- VectorWritable vw = new VectorWritable(instance.getVector());
- ctx.write(new IntWritable(instance.getLabels()[0]), vw);
+ protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+ String label = labelText.toString();
+ if (labelIndex.containsKey(label)) {
+ ctx.write(new IntWritable(labelIndex.get(label)), instance);
+ } else {
+ ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+ }
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note (amended): train() now computes the denominator sigmaK + alphaI() * numFeatures() once, before the loop over
the instance's non-zero elements, because it does not depend on the element. The per-element numerator and log weight
are unchanged, and so is the line count. This assumes alphaI() and numFeatures() are side-effect-free getters; confirm
in AbstractThetaTrainer.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java Tue May 15 16:08:52 2012
@@ -17,6 +17,8 @@
package org.apache.mahout.classifier.naivebayes.training;
+import java.util.Iterator;
+
import org.apache.mahout.math.Vector;
public class StandardThetaTrainer extends AbstractThetaTrainer {
@@ -27,7 +29,14 @@ public class StandardThetaTrainer extend
@Override
public void train(int label, Vector instance) {
- double weight = Math.log((instance.zSum() + alphaI()) / (labelWeight(label) + alphaI() * numFeatures()));
- updatePerLabelThetaNormalizer(label, weight);
+ double sigmaK = labelWeight(label);
+ double denominator = sigmaK + alphaI() * numFeatures();
+ Iterator<Vector.Element> it = instance.iterateNonZero();
+ while (it.hasNext()) {
+ Vector.Element e = it.next();
+ double numerator = e.get() + alphaI();
+ double weight = Math.log(numerator / denominator);
+ updatePerLabelThetaNormalizer(label, weight);
+ }
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: setup() now calls super.setup(ctx) before reading ALPHA_I, and cleanup() calls super.cleanup(ctx) after writing the LABEL_THETA_NORMALIZER vector.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java Tue May 15 16:08:52 2012
@@ -37,6 +37,7 @@ public class ThetaMapper extends Mapper<
@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
Configuration conf = ctx.getConfiguration();
float alphaI = conf.getFloat(ALPHA_I, 1.0f);
@@ -60,5 +61,6 @@ public class ThetaMapper extends Mapper<
protected void cleanup(Context ctx) throws IOException, InterruptedException {
ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER),
new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
+ super.cleanup(ctx);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note (amended): createLabelIndex() now throws IllegalArgumentException when neither --labels nor --extractLabels
is given. The committed version silently returned 0 without writing a label index, while run() still cached the
never-written labPath and continued. Label names passed via --labels are trimmed and empty entries are dropped, so
"a, b" and "a,b" produce the same index. Line counts are unchanged.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java Tue May 15 16:08:52 2012
@@ -17,10 +17,9 @@
package org.apache.mahout.classifier.naivebayes.training;
-import java.util.List;
-import java.util.Map;
-
+import com.google.common.base.Splitter;
import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
@@ -32,9 +31,16 @@ import org.apache.mahout.classifier.naiv
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
import org.apache.mahout.common.mapreduce.VectorSumReducer;
import org.apache.mahout.math.VectorWritable;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
/**
* This class trains a Naive Bayes Classifier (Parameters for both Naive Bayes and Complementary Naive Bayes)
*/
@@ -54,11 +60,15 @@ public final class TrainNaiveBayesJob ex
@Override
public int run(String[] args) throws Exception {
+
addInputOption();
addOutputOption();
- addOption("labelSize", "ls", "Number of labels in the input data", String.valueOf(2));
+ addOption("labels", "l", "comma-separated list of labels to include in training", false);
+
+ addOption(buildOption("extractLabels", "el", "Extract the labels from the input", false, false, ""));
addOption("alphaI", "a", "smoothing parameter", String.valueOf(1.0f));
addOption(buildOption("trainComplementary", "c", "train complementary?", false, false, String.valueOf(false)));
+ addOption("labelIndex", "li", "The path to store the label index in", false);
addOption(DefaultOptionCreator.overwriteOption().create());
Map<String, List<String>> parsedArgs = parseArguments(args);
if (parsedArgs == null) {
@@ -68,12 +78,21 @@ public final class TrainNaiveBayesJob ex
HadoopUtil.delete(getConf(), getOutputPath());
HadoopUtil.delete(getConf(), getTempPath());
}
- int labelSize = Integer.parseInt(getOption("labelSize"));
+ Path labPath;
+ String labPathStr = getOption("labelIndex");
+ if (labPathStr != null) {
+ labPath = new Path(labPathStr);
+ } else {
+ labPath = getTempPath("labelIndex");
+ }
+ long labelSize = createLabelIndex(labPath);
float alphaI = Float.parseFloat(getOption("alphaI"));
boolean trainComplementary = Boolean.parseBoolean(getOption("trainComplementary"));
+
HadoopUtil.setSerializations(getConf());
-
+ HadoopUtil.cacheFiles(labPath, getConf());
+
//add up all the vectors with the same labels, while mapping the labels into our index
Job indexInstances = prepareJob(getInputPath(), getTempPath(SUMMED_OBSERVATIONS), SequenceFileInputFormat.class,
IndexInstancesMapper.class, IntWritable.class, VectorWritable.class, VectorSumReducer.class, IntWritable.class,
@@ -113,4 +132,18 @@ public final class TrainNaiveBayesJob ex
return 0;
}
+
+ private long createLabelIndex(Path labPath) throws IOException {
+ if (hasOption("labels")) {
+ Iterable<String> labels = Splitter.on(",").trimResults().omitEmptyStrings().split(getOption("labels"));
+ return BayesUtils.writeLabelIndex(getConf(), labels, labPath);
+ }
+ if (hasOption("extractLabels")) {
+ SequenceFileDirIterable<Text, IntWritable> iterable =
+ new SequenceFileDirIterable<Text, IntWritable>(getInputPath(), PathType.LIST, PathFilters.logsCRCFilter(), getConf());
+ return BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
+ }
+ throw new IllegalArgumentException("Either --labels or --extractLabels must be specified");
+ }
+
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: setup() now calls super.setup(ctx) before parsing NUM_LABELS; the numLabels > 0 precondition is unchanged.
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java Tue May 15 16:08:52 2012
@@ -37,6 +37,7 @@ public class WeightsMapper extends Mappe
@Override
protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS));
Preconditions.checkArgument(numLabels > 0);
weightsPerLabel = new RandomAccessSparseVector(numLabels);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: the assertions switch from classify() to classifyFull(). In this commit, AbstractNaiveBayesClassifier.classify() throws UnsupportedOperationException.
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java Tue May 15 16:08:52 2012
@@ -37,10 +37,10 @@ public final class ComplementaryNaiveBay
@Test
public void testNaiveBayes() throws Exception {
assertEquals(4, classifier.numCategories());
- assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
- assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
- assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
- assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
+ assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+ assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+ assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+ assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
Review note: the training fixture is now written as Text label keys with VectorWritable values (the labels become Text
constants and trainingInstance no longer takes a label). TrainNaiveBayesJob is invoked with --labels stolen,not_stolen
instead of --labelSize 2, and predictions are computed with classifyFull(...) on VectorWritable.get().
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java Tue May 15 16:08:52 2012
@@ -17,24 +17,23 @@
package org.apache.mahout.classifier.naivebayes;
-import java.io.File;
-
+import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.MultiLabelVectorWritable;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.MathHelper;
import org.junit.Before;
import org.junit.Test;
-import com.google.common.io.Closeables;
+import java.io.File;
public class NaiveBayesTest extends MahoutTestCase {
@@ -43,8 +42,8 @@ public class NaiveBayesTest extends Maho
private File outputDir;
private File tempDir;
- static final String LABEL_STOLEN = "stolen";
- static final String LABEL_NOT_STOLEN = "not_stolen";
+ static final Text LABEL_STOLEN = new Text("stolen");
+ static final Text LABEL_NOT_STOLEN = new Text("not_stolen");
static final Vector.Element COLOR_RED = MathHelper.elem(0, 1);
static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
@@ -67,19 +66,19 @@ public class NaiveBayesTest extends Maho
tempDir = getTestTempDir("tmp");
SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf), conf,
- new Path(inputFile.getAbsolutePath()), IntWritable.class, MultiLabelVectorWritable.class);
+ new Path(inputFile.getAbsolutePath()), Text.class, VectorWritable.class);
try {
- writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
- writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
- writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
- writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
- writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
- writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
- writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
- writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
- writer.append(new IntWritable(0), trainingInstance(LABEL_NOT_STOLEN, COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
- writer.append(new IntWritable(0), trainingInstance(LABEL_STOLEN, COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
} finally {
Closeables.closeQuietly(writer);
}
@@ -90,7 +89,7 @@ public class NaiveBayesTest extends Maho
TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
trainNaiveBayes.setConf(conf);
trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
- "--labelSize", "2", "--tempDir", tempDir.getAbsolutePath() });
+ "--labels", "stolen,not_stolen", "--tempDir", tempDir.getAbsolutePath() });
NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);
@@ -98,7 +97,7 @@ public class NaiveBayesTest extends Maho
assertEquals(2, classifier.numCategories());
- Vector prediction = classifier.classify(trainingInstance("", COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).getVector());
+ Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
// should be classified as not stolen
assertTrue(prediction.get(0) < prediction.get(1));
@@ -109,7 +108,7 @@ public class NaiveBayesTest extends Maho
TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
trainNaiveBayes.setConf(conf);
trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
- "--labelSize", "2", "--trainComplementary",
+ "--labels", "stolen,not_stolen", "--trainComplementary",
"--tempDir", tempDir.getAbsolutePath() });
@@ -118,18 +117,18 @@ public class NaiveBayesTest extends Maho
assertEquals(2, classifier.numCategories());
- Vector prediction = classifier.classify(trainingInstance("", COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).getVector());
+ Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
// should be classified as not stolen
assertTrue(prediction.get(0) < prediction.get(1));
}
- static MultiLabelVectorWritable trainingInstance(String label, Vector.Element... elems) {
+ static VectorWritable trainingInstance(Vector.Element... elems) {
DenseVector trainingInstance = new DenseVector(6);
for (Vector.Element elem : elems) {
trainingInstance.set(elem.index(), elem.get());
}
- return new MultiLabelVectorWritable(trainingInstance, new int[] {label.equals("stolen") ? 0 : 1});
+ return new VectorWritable(trainingInstance);
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java Tue May 15 16:08:52 2012
@@ -38,10 +38,10 @@ public final class StandardNaiveBayesCla
@Test
public void testNaiveBayes() throws Exception {
assertEquals(4, classifier.numCategories());
- assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
- assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
- assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
- assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
+ assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+ assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+ assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+ assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java Tue May 15 16:08:52 2012
@@ -18,20 +18,22 @@
package org.apache.mahout.classifier.naivebayes.training;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
-import org.apache.mahout.math.MultiLabelVectorWritable;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.easymock.EasyMock;
import org.junit.Before;
import org.junit.Test;
-@SuppressWarnings("unchecked")
public class IndexInstancesMapperTest extends MahoutTestCase {
- private static final DenseVector VECTOR = new DenseVector(new double[] { 1, 0, 1, 1, 0 });
+
private Mapper.Context ctx;
- private MultiLabelVectorWritable instance;
+ private OpenObjectIntHashMap<String> labelIndex;
+ private VectorWritable instance;
@Override
@Before
@@ -39,16 +41,45 @@ public class IndexInstancesMapperTest ex
super.setUp();
ctx = EasyMock.createMock(Mapper.Context.class);
- instance = new MultiLabelVectorWritable(VECTOR,
- new int[] {0});
+ instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 }));
+
+ labelIndex = new OpenObjectIntHashMap<String>();
+ labelIndex.put("bird", 0);
+ labelIndex.put("cat", 1);
}
-
+
+
@Test
public void index() throws Exception {
- ctx.write(new IntWritable(0), new VectorWritable(VECTOR));
+
+ ctx.write(new IntWritable(0), instance);
+
EasyMock.replay(ctx);
+
IndexInstancesMapper indexInstances = new IndexInstancesMapper();
- indexInstances.map(new IntWritable(-1), instance, ctx);
+ setField(indexInstances, "labelIndex", labelIndex);
+
+ indexInstances.map(new Text("bird"), instance, ctx);
+
EasyMock.verify(ctx);
}
+
+ @Test
+ public void skip() throws Exception {
+
+ Counter skippedInstances = EasyMock.createMock(Counter.class);
+
+ EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances);
+ skippedInstances.increment(1);
+
+ EasyMock.replay(ctx, skippedInstances);
+
+ IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+ setField(indexInstances, "labelIndex", labelIndex);
+
+ indexInstances.map(new Text("fish"), instance, ctx);
+
+ EasyMock.verify(ctx, skippedInstances);
+ }
+
}
Modified: mahout/trunk/examples/bin/classify-20newsgroups.sh
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/bin/classify-20newsgroups.sh?rev=1338770&r1=1338769&r2=1338770&view=diff
==============================================================================
--- mahout/trunk/examples/bin/classify-20newsgroups.sh (original)
+++ mahout/trunk/examples/bin/classify-20newsgroups.sh Tue May 15 16:08:52 2012
@@ -23,7 +23,7 @@
# examples/bin/build-20news.sh
if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script runs SGD and Bayes classifiers over the classic 20 News Groups."
+ echo "This script runs the SGD classifier over the classic 20 News Groups."
exit
fi
@@ -34,14 +34,13 @@ fi
START_PATH=`pwd`
WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( naivebayes sgd clean)
+algorithm=( sgd clean)
if [ -n "$1" ]; then
choice=$1
else
echo "Please select a number to choose the corresponding task to run"
echo "1. ${algorithm[0]}"
- echo "2. ${algorithm[1]}"
- echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
+ echo "2. ${algorithm[1]} -- cleans up the work area in $WORK_DIR"
read -p "Enter your choice : " choice
fi
@@ -68,15 +67,7 @@ cd ../..
set -e
-if [ "x$alg" == "xnaivebayes" ]; then
- if [ ! -e "/tmp/news-group.model" ]; then
- echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
- ./bin/mahout org.apache.mahout.classifier.naivebayes.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/ 0 \
- --input /tmp/news-group-train/ --output ${WORK_DIR}/news-group.naivebayes.model -ls 20 --tempDir ${WORK_DIR}/tmp/ -ow
- fi
- echo "Testing on ${WORK_DIR}/20news-bydate/20news-bydate-test/ with model: /tmp/news-group.model"
- # ./bin/mahout org.apache.mahout.classifier.sgd.TestNewsGroups --input ${WORK_DIR}/20news-bydate/20news-bydate-test/ --model /tmp/news-group.model
-elif [ "x$alg" == "xsgd" ]; then
+if [ "x$alg" == "xsgd" ]; then
if [ ! -e "/tmp/news-group.model" ]; then
echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/