You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/10/06 23:38:42 UTC
svn commit: r1005262 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/classifier/naivebayes/
main/java/org/apache/mahout/classifier/naivebayes/trainer/
main/java/org/apache/mahout/common/
test/java/org/apache/mahout/classifier/naivebayes/
Author: robinanil
Date: Wed Oct 6 21:38:41 2010
New Revision: 1005262
URL: http://svn.apache.org/viewvc?rev=1005262&view=rev
Log:
MAHOUT-287 Vector input based NaiveBayes classifier
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.util.Iterator;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Base class for the vector-input Naive Bayes classifiers. Subclasses define the score of a single
+ * (label, feature) pair; this class combines those scores over an instance and over all labels.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+
+  /** The trained model that subclasses read weights, sums and normalizers from. */
+  protected NaiveBayesModel model;
+
+  public AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+    this.model = model;
+  }
+
+  /**
+   * Returns the score contribution of one occurrence of {@code feature} toward {@code label}.
+   *
+   * @param label row index of the label in the model
+   * @param feature column index of the feature in the model
+   * @return the per-occurrence score
+   */
+  public abstract double getScoreForLabelFeature(int label, int feature);
+
+  /**
+   * Scores {@code instance} against {@code label} as the sum, over the instance's non-zero elements,
+   * of element value times {@link #getScoreForLabelFeature(int, int)}. Weighting by the element value
+   * makes repeated features count proportionally (multinomial model); for 0/1 instances this is the
+   * plain sum of per-feature scores.
+   *
+   * @param label row index of the label in the model
+   * @param instance the feature vector to score
+   * @return the accumulated score
+   */
+  public double getScoreForLabelInstance(int label, Vector instance) {
+    double result = 0.0;
+    Iterator<Element> it = instance.iterateNonZero();
+    while (it.hasNext()) {
+      Element e = it.next();
+      result += e.get() * getScoreForLabelFeature(label, e.index());
+    }
+    return result;
+  }
+
+  @Override
+  public int numCategories() {
+    return model.getNumLabels();
+  }
+
+  /**
+   * Returns a vector shaped like the model's label-sum vector, holding the score of
+   * {@code instance} for every label index.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector score = model.getLabelSum().like();
+    for (int i = 0; i < score.size(); i++) {
+      score.set(i, getScoreForLabelInstance(i, instance));
+    }
+    return score;
+  }
+
+  /** Not supported: Naive Bayes here always produces one score per label. */
+  @Override
+  public double classifyScalar(Vector instance) {
+    throw new UnsupportedOperationException("Not supported in Naive Bayes");
+  }
+
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/BayesConstants.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+/**
+ * Class containing Constants used by Naive Bayes classifier classes
+ *
+ */
+public final class BayesConstants {
+
+ // Ensure all the strings are unique
+ public static final String ALPHA_SMOOTHING_FACTOR = "__SF"; // -
+
+ public static final String WEIGHT = "__WT";
+
+ public static final String FEATURE_SUM = "__SJ";
+
+ public static final String LABEL_SUM = "__SK";
+
+ public static final String LABEL_THETA_NORMALIZER = "_LTN";
+
+ private BayesConstants() { }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/**
+ * Complementary Naive Bayes classifier: a label is scored from the smoothed weight each feature has
+ * in all the <em>other</em> labels (its complement), normalized by that label's theta normalizer.
+ */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  /**
+   * Computes {@code log((featureSum[f] - w[l][f] + alphaI) / (totalSum - labelSum[l] + vocabCount))}
+   * divided by the theta normalizer of label {@code l}.
+   */
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    double weightInLabel = model.getWeightMatrix().get(label, feature);
+    double weightInComplement = model.getFeatureSum().get(feature) - weightInLabel;
+    double complementMass = model.getTotalSum() - model.getLabelSum().get(label);
+
+    double smoothedNumerator = weightInComplement + model.getAlphaI();
+    double smoothedDenominator = complementMass + model.getVocabCount();
+
+    return Math.log(smoothedNumerator / smoothedDenominator)
+        / model.getPerlabelThetaNormalizer().get(label);
+  }
+
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,304 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.naivebayes.trainer.NaiveBayesTrainer;
+import org.apache.mahout.math.JsonMatrixAdapter;
+import org.apache.mahout.math.JsonVectorAdapter;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import com.google.gson.JsonDeserializationContext;
+import com.google.gson.JsonDeserializer;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParseException;
+import com.google.gson.JsonPrimitive;
+import com.google.gson.JsonSerializationContext;
+import com.google.gson.JsonSerializer;
+
+/**
+ * NaiveBayesModel holds the weight matrix, the feature and label sums and the per-label theta
+ * normalizer vector. It can be loaded from the output directory of the MapReduce trainer and
+ * round-tripped through JSON.
+ */
+public class NaiveBayesModel implements JsonDeserializer<NaiveBayesModel>, JsonSerializer<NaiveBayesModel>, Cloneable {
+
+  private static final String MODEL = "NaiveBayesModel";
+
+  private Vector labelSum;
+  private Vector perlabelThetaNormalizer;
+  private Vector featureSum;
+  private Matrix weightMatrix;
+  private float alphaI;
+  private double vocabCount;
+  private double totalSum;
+
+  private NaiveBayesModel() {
+    // used by fromMRTrainerOutput() and as the Gson type-adapter instance in fromJson()
+  }
+
+  /**
+   * Builds a model from already-computed parts. The vocabulary size is taken as the number of
+   * non-default entries in {@code featureSum}, and the total sum as the sum of {@code labelSum}.
+   */
+  public NaiveBayesModel(Matrix matrix, Vector featureSum, Vector labelSum, Vector thetaNormalizer, float alphaI) {
+    this.weightMatrix = matrix;
+    this.featureSum = featureSum;
+    this.labelSum = labelSum;
+    this.perlabelThetaNormalizer = thetaNormalizer;
+    this.vocabCount = featureSum.getNumNondefaultElements();
+    this.totalSum = labelSum.zSum();
+    this.alphaI = alphaI;
+  }
+
+  private void setLabelSum(Vector labelSum) {
+    this.labelSum = labelSum;
+  }
+
+  public void setPerlabelThetaNormalizer(Vector perlabelThetaNormalizer) {
+    this.perlabelThetaNormalizer = perlabelThetaNormalizer;
+  }
+
+  public void setFeatureSum(Vector featureSum) {
+    this.featureSum = featureSum;
+  }
+
+  public void setWeightMatrix(Matrix weightMatrix) {
+    this.weightMatrix = weightMatrix;
+  }
+
+  public void setAlphaI(float alphaI) {
+    this.alphaI = alphaI;
+  }
+
+  public void setVocabCount(double vocabCount) {
+    this.vocabCount = vocabCount;
+  }
+
+  public void setTotalSum(double totalSum) {
+    this.totalSum = totalSum;
+  }
+
+  public Vector getLabelSum() {
+    return labelSum;
+  }
+
+  public Vector getPerlabelThetaNormalizer() {
+    return perlabelThetaNormalizer;
+  }
+
+  public Vector getFeatureSum() {
+    return featureSum;
+  }
+
+  public Matrix getWeightMatrix() {
+    return weightMatrix;
+  }
+
+  public float getAlphaI() {
+    return alphaI;
+  }
+
+  public double getVocabCount() {
+    return vocabCount;
+  }
+
+  public double getTotalSum() {
+    return totalSum;
+  }
+
+  public int getNumLabels() {
+    return labelSum.size();
+  }
+
+  public static String getModelName() {
+    return MODEL;
+  }
+
+  // CODE USED FOR SERIALIZATION
+
+  /**
+   * Loads a model from the output directory of the MapReduce trainer.
+   *
+   * @param output trainer output directory containing the sum, class and theta files
+   * @param conf configuration; {@link NaiveBayesTrainer#ALPHA_I} is read from it (default 1.0)
+   * @return the loaded model
+   * @throws IOException if any of the sequence files cannot be read
+   */
+  public static NaiveBayesModel fromMRTrainerOutput(Path output, Configuration conf) throws IOException {
+    Path classVectorPath = new Path(output, NaiveBayesTrainer.CLASS_VECTORS);
+    Path sumVectorPath = new Path(output, NaiveBayesTrainer.SUM_VECTORS);
+    Path thetaSumPath = new Path(output, NaiveBayesTrainer.THETA_SUM);
+
+    NaiveBayesModel model = new NaiveBayesModel();
+    model.setAlphaI(conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f));
+
+    FileSystem fs = sumVectorPath.getFileSystem(conf);
+    Text key = new Text();
+    VectorWritable value = new VectorWritable();
+
+    // read feature sums and label sums
+    int featureCardinality = 0;
+    int labelCount = 0;
+    SequenceFile.Reader sumReader = new SequenceFile.Reader(fs, sumVectorPath, conf);
+    try {
+      while (sumReader.next(key, value)) {
+        if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
+          Vector features = value.get();
+          model.setFeatureSum(features);
+          model.setVocabCount(features.getNumNondefaultElements());
+          // the weight rows share this vector's cardinality, which may exceed its non-zero count
+          featureCardinality = features.size();
+        } else if (key.toString().equals(BayesConstants.LABEL_SUM)) {
+          Vector labels = value.get();
+          model.setLabelSum(labels);
+          model.setTotalSum(labels.zSum());
+          labelCount = labels.size();
+        }
+      }
+    } finally {
+      sumReader.close();
+    }
+
+    // read the class matrix, one row per label index
+    Matrix matrix = new SparseMatrix(new int[] {labelCount, featureCardinality});
+    SequenceFile.Reader classReader = new SequenceFile.Reader(fs, classVectorPath, conf);
+    try {
+      IntWritable label = new IntWritable();
+      while (classReader.next(label, value)) {
+        matrix.assignRow(label.get(), value.get());
+      }
+    } finally {
+      classReader.close();
+    }
+    model.setWeightMatrix(matrix);
+
+    // read theta normalizer
+    SequenceFile.Reader thetaReader = new SequenceFile.Reader(fs, thetaSumPath, conf);
+    try {
+      while (thetaReader.next(key, value)) {
+        if (key.toString().equals(BayesConstants.LABEL_THETA_NORMALIZER)) {
+          model.setPerlabelThetaNormalizer(value.get());
+        }
+      }
+    } finally {
+      thetaReader.close();
+    }
+
+    return model;
+  }
+
+  /**
+   * Encode this NaiveBayesModel as a JSON string
+   *
+   * @return String containing the JSON of this model
+   */
+  public String toJson() {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(NaiveBayesModel.class, this);
+    Gson gson = builder.create();
+    return gson.toJson(this);
+  }
+
+  /**
+   * Decode a NaiveBayesModel from a JSON string
+   *
+   * @param json String containing JSON representation of this model
+   * @return Initialized model
+   */
+  public static NaiveBayesModel fromJson(String json) {
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(NaiveBayesModel.class, new NaiveBayesModel());
+    Gson gson = builder.create();
+    return gson.fromJson(json, NaiveBayesModel.class);
+  }
+
+  /**
+   * Wraps the reflective JSON encoding of {@code model} (with Matrix/Vector adapters) as a string
+   * stored under the {@value #MODEL} key.
+   */
+  @Override
+  public JsonElement serialize(NaiveBayesModel model,
+                               Type type,
+                               JsonSerializationContext context) {
+    // now register the builders for matrix / vector
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Matrix.class, new JsonMatrixAdapter());
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    // create a model
+    JsonObject json = new JsonObject();
+    // first, we add the model
+    json.add(MODEL, new JsonPrimitive(gson.toJson(model)));
+    return json;
+  }
+
+  /** Inverse of {@link #serialize}: decodes the string stored under the {@value #MODEL} key. */
+  @Override
+  public NaiveBayesModel deserialize(JsonElement json,
+                                     Type type,
+                                     JsonDeserializationContext context) throws JsonParseException {
+    // register the builders for matrix / vector
+    GsonBuilder builder = new GsonBuilder();
+    builder.registerTypeAdapter(Matrix.class, new JsonMatrixAdapter());
+    builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
+    Gson gson = builder.create();
+    // now decode the original model
+    JsonObject obj = json.getAsJsonObject();
+    String modelString = obj.get(MODEL).getAsString();
+    return gson.fromJson(modelString, NaiveBayesModel.class);
+  }
+
+  /**
+   * Checks that a model's scalar parameters are positive and its vectors are present and non-empty.
+   * A {@code null} model is accepted.
+   *
+   * @throws IllegalArgumentException describing the first failed check
+   */
+  public static void validate(NaiveBayesModel model) {
+    if (model == null) {
+      return; // empty models are valid
+    }
+
+    if (model.getAlphaI() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: AlphaI has to be greater than 0!");
+    }
+
+    if (model.getVocabCount() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The vocab count has to be greater than 0!");
+    }
+
+    if (model.getTotalSum() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The total sum has to be greater than 0!");
+    }
+
+    if (model.getLabelSum() == null || model.getLabelSum().getNumNondefaultElements() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The number of labels has to be greater than 0 or defined!");
+    }
+
+    if (model.getPerlabelThetaNormalizer() == null
+        || model.getPerlabelThetaNormalizer().getNumNondefaultElements() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The number of theta normalizers has to be greater than 0 or defined!");
+    }
+
+    if (model.getFeatureSum() == null || model.getFeatureSum().getNumNondefaultElements() <= 0) {
+      throw new IllegalArgumentException(
+          "Error: The number of features has to be greater than 0 or defined!");
+    }
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/**
+ * Standard (multinomial) Naive Bayes classifier: a label is scored from the smoothed weight each
+ * feature has within that label, normalized by the label's theta normalizer.
+ */
+public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+  public StandardNaiveBayesClassifier(NaiveBayesModel model) {
+    super(model);
+  }
+
+  /**
+   * Computes {@code -log((w[l][f] + alphaI) / (labelSum[l] + vocabCount))} divided by the theta
+   * normalizer of label {@code l}.
+   */
+  @Override
+  public double getScoreForLabelFeature(int label, int feature) {
+    double smoothedFeatureWeight = model.getWeightMatrix().get(label, feature) + model.getAlphaI();
+    double smoothedLabelWeight = model.getLabelSum().get(label) + model.getVocabCount();
+    double negLogProbability = -Math.log(smoothedFeatureWeight / smoothedLabelWeight);
+    return negLogProbability / model.getPerlabelThetaNormalizer().get(label);
+  }
+
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesInstanceMapper.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+/**
+ * Rewrites (label name, instance) pairs as (label index, instance). The label-to-index dictionary
+ * is read from the first DistributedCache file; instances whose label is not in it are skipped and
+ * counted.
+ */
+public class NaiveBayesInstanceMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+  private OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+
+  @Override
+  protected void map(Text key, VectorWritable value, Context context)
+    throws IOException, InterruptedException {
+    String labelName = key.toString();
+    if (!labelMap.containsKey(labelName)) {
+      // getCounter() only looks the counter up; it must be incremented to record the skip
+      context.getCounter("NaiveBayes", "Skipped instance: not in label list").increment(1);
+      return;
+    }
+    int label = labelMap.get(labelName);
+    context.write(new IntWritable(label), value);
+  }
+
+  /**
+   * Loads the label dictionary from the first DistributedCache file.
+   *
+   * @throws IllegalArgumentException if no cache file is registered
+   * @throws IllegalStateException wrapping any I/O failure while reading the dictionary
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 1) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      Path labelMapFile = new Path(localFiles[0].getPath());
+      FileSystem fs = labelMapFile.getFileSystem(conf);
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+      try {
+        Writable key = new Text();
+        IntWritable value = new IntWritable();
+        // key is the label name, value is its integer index
+        while (reader.next(key, value)) {
+          labelMap.put(key.toString(), value.get());
+        }
+      } finally {
+        reader.close();
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesSumReducer.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Adds up every vector that shares a key and emits the total under that key. Because the output
+ * types match the input types, it works both as the reducer and as a local combiner.
+ */
+public class NaiveBayesSumReducer extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+  @Override
+  protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context)
+    throws IOException, InterruptedException {
+    Vector total = null;
+    for (VectorWritable writable : values) {
+      Vector current = writable.get();
+      if (total != null) {
+        // accumulate the current vector into the running total (the first vector seen)
+        current.addTo(total);
+      } else {
+        total = current;
+      }
+    }
+    context.write(key, new VectorWritable(total));
+  }
+
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaComplementaryMapper.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,112 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesConstants;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+/**
+ * Accumulates the complementary-Naive-Bayes theta normalizer of every label seen by this mapper.
+ * Each input record is (label index, label weight vector); the accumulated normalizer vector is
+ * emitted once, under {@link BayesConstants#LABEL_THETA_NORMALIZER}, in {@link #cleanup}.
+ */
+public class NaiveBayesThetaComplementaryMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  private OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+  private Vector featureSum;
+  private Vector labelSum;
+  private Vector perLabelThetaNormalizer;
+  private double alphaI = 1.0;
+  private double vocabCount;
+  private double totalSum = 0;
+
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context context)
+    throws IOException, InterruptedException {
+    Vector vector = value.get();
+    int label = key.get();
+    double sigmaK = labelSum.get(label);
+    // the denominator depends only on the label, so compute it once per record
+    double denominator = totalSum - sigmaK + vocabCount;
+    // accumulate in the same order as per-element updates would, then store once
+    double thetaSum = perLabelThetaNormalizer.get(label);
+    Iterator<Element> it = vector.iterateNonZero();
+    while (it.hasNext()) {
+      Element e = it.next();
+      double numerator = featureSum.get(e.index()) - e.get() + alphaI;
+      thetaSum += Math.log(numerator / denominator);
+    }
+    perLabelThetaNormalizer.set(label, thetaSum);
+  }
+
+  /**
+   * Reads alphaI from the configuration, the feature/label sums from the first DistributedCache file
+   * and the label dictionary from the second.
+   *
+   * @throws IllegalArgumentException if fewer than two cache files are registered
+   * @throws IllegalStateException wrapping I/O failures, or if a sum vector is missing
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 2) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      alphaI = conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f);
+      readSums(new Path(localFiles[0].getPath()), conf);
+      perLabelThetaNormalizer = labelSum.like();
+      totalSum = labelSum.zSum();
+      vocabCount = featureSum.getNumNondefaultElements();
+      readLabelMap(new Path(localFiles[1].getPath()), conf);
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /** Loads the feature-sum and label-sum vectors from {@code weightFile}, closing the reader. */
+  private void readSums(Path weightFile, Configuration conf) throws IOException {
+    FileSystem fs = weightFile.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, weightFile, conf);
+    try {
+      Text key = new Text();
+      VectorWritable value = new VectorWritable();
+      while (reader.next(key, value)) {
+        if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
+          featureSum = value.get();
+        } else if (key.toString().equals(BayesConstants.LABEL_SUM)) {
+          labelSum = value.get();
+        }
+      }
+    } finally {
+      reader.close();
+    }
+    if (featureSum == null || labelSum == null) {
+      throw new IllegalStateException("Missing " + BayesConstants.FEATURE_SUM + " or "
+          + BayesConstants.LABEL_SUM + " vector in " + weightFile);
+    }
+  }
+
+  /** Loads the label-name to label-index dictionary from {@code labelMapFile}, closing the reader. */
+  private void readLabelMap(Path labelMapFile, Configuration conf) throws IOException {
+    FileSystem fs = labelMapFile.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+    try {
+      Text key = new Text();
+      IntWritable intValue = new IntWritable();
+      // key is the label name, value is its integer index
+      while (reader.next(key, intValue)) {
+        labelMap.put(key.toString(), intValue.get());
+      }
+    } finally {
+      reader.close();
+    }
+  }
+
+  /** Emits the accumulated per-label theta normalizer vector. */
+  @Override
+  protected void cleanup(Context context) throws IOException, InterruptedException {
+    context.write(new Text(BayesConstants.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
+    super.cleanup(context);
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesThetaMapper.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesConstants;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+/**
+ * Accumulates the per-label theta normalizer for standard Naive Bayes.
+ *
+ * <p>Expects two DistributedCache entries: first the weight-sum sequence file (entries keyed
+ * {@link BayesConstants#FEATURE_SUM} and {@link BayesConstants#LABEL_SUM}), then the label map.
+ * The normalizer is emitted once, in {@link #cleanup}, under
+ * {@link BayesConstants#LABEL_THETA_NORMALIZER}.</p>
+ */
+public class NaiveBayesThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  /** Label name to label id, loaded from the second cache file (not read elsewhere in this class). */
+  private OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+  private Vector featureSum;
+  private Vector labelSum;
+  private Vector perLabelThetaNormalizer;
+  /** Smoothing term; replaced in setup by the {@link NaiveBayesTrainer#ALPHA_I} setting. */
+  private double alphaI = 1.0;
+  /** Number of non-default entries in the feature-sum vector. */
+  private double vocabCount;
+
+  /**
+   * Adds log((sum of this record's weights + alphaI) / (labelSum[label] + vocabCount)) to the
+   * normalizer entry of the record's label.
+   */
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context context)
+    throws IOException, InterruptedException {
+    Vector vector = value.get();
+    int label = key.get();
+    double weight = Math.log((vector.zSum() + alphaI) / (labelSum.get(label) + vocabCount));
+    perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
+  }
+
+  /**
+   * Loads alphaI, the feature/label weight sums and the label map from the configuration and the
+   * DistributedCache.
+   *
+   * @throws IllegalArgumentException if fewer than two cache files are configured
+   * @throws IllegalStateException if a cache file cannot be read or the weight file lacks a sum
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 2) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      alphaI = conf.getFloat(NaiveBayesTrainer.ALPHA_I, 1.0f);
+      readWeightSums(new Path(localFiles[0].getPath()), conf);
+      perLabelThetaNormalizer = labelSum.like();
+      vocabCount = featureSum.getNumNondefaultElements();
+      readLabelMap(new Path(localFiles[1].getPath()), conf);
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /** Reads featureSum and labelSum from the weight file, always closing the reader. */
+  private void readWeightSums(Path weightFile, Configuration conf) throws IOException {
+    FileSystem fs = weightFile.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, weightFile, conf);
+    try {
+      Text key = new Text();
+      VectorWritable value = new VectorWritable();
+      while (reader.next(key, value)) {
+        if (key.toString().equals(BayesConstants.FEATURE_SUM)) {
+          featureSum = value.get();
+        } else if (key.toString().equals(BayesConstants.LABEL_SUM)) {
+          labelSum = value.get();
+        }
+        // defensive: use a fresh Writable per record so the retained vectors can never alias
+        value = new VectorWritable();
+      }
+    } finally {
+      reader.close();
+    }
+    if (featureSum == null || labelSum == null) {
+      throw new IllegalStateException("weight file " + weightFile + " must contain both "
+          + BayesConstants.FEATURE_SUM + " and " + BayesConstants.LABEL_SUM);
+    }
+  }
+
+  /** Reads the label-name to label-id map, always closing the reader. */
+  private void readLabelMap(Path labelMapFile, Configuration conf) throws IOException {
+    FileSystem fs = labelMapFile.getFileSystem(conf);
+    SequenceFile.Reader reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+    try {
+      Text key = new Text();
+      IntWritable value = new IntWritable();
+      // key is the label name, value is its integer id
+      while (reader.next(key, value)) {
+        labelMap.put(key.toString(), value.get());
+      }
+    } finally {
+      reader.close();
+    }
+  }
+
+  /** Emits the accumulated normalizer under {@link BayesConstants#LABEL_THETA_NORMALIZER}. */
+  @Override
+  protected void cleanup(Context context) throws IOException, InterruptedException {
+    context.write(new Text(BayesConstants.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
+    super.cleanup(context);
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesTrainer.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,202 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Trains a Naive Bayes classifier as a chain of MapReduce jobs, producing the parameters for
+ * either standard Naive Bayes or Complementary Naive Bayes.
+ */
+public final class NaiveBayesTrainer {
+
+  public static final String THETA_SUM = "thetaSum";
+  public static final String SUM_VECTORS = "sumVectors";
+  public static final String CLASS_VECTORS = "classVectors";
+  public static final String LABEL_MAP = "labelMap";
+  public static final String ALPHA_I = "alphaI";
+
+  /**
+   * Runs the training pipeline. Outputs go under {@code output}: the label map
+   * ({@value #LABEL_MAP}), the per-label class vectors ({@value #CLASS_VECTORS}), the
+   * feature/label weight sums ({@value #SUM_VECTORS}) and the theta normalizer
+   * ({@value #THETA_SUM}).
+   *
+   * @param input directory of input sequence files
+   * @param conf job configuration; {@link #ALPHA_I} and DistributedCache entries are set on it
+   * @param inputLabels labels in id order (the label at position i gets id i)
+   * @param output base output directory
+   * @param numReducers number of reduce tasks for every job
+   * @param alphaI smoothing parameter, stored in {@code conf} under {@link #ALPHA_I}
+   * @param trainComplementary whether to compute the complementary theta normalizer
+   * @throws IllegalStateException if any of the jobs fails
+   */
+  public static void trainNaiveBayes(Path input,
+                                     Configuration conf,
+                                     List<String> inputLabels,
+                                     Path output,
+                                     int numReducers,
+                                     float alphaI,
+                                     boolean trainComplementary)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    conf.setFloat(ALPHA_I, alphaI);
+    Path labelMapPath = createLabelMapFile(inputLabels, conf, new Path(output, LABEL_MAP));
+    Path classVectorPath = new Path(output, CLASS_VECTORS);
+    runNaiveBayesByLabelSummer(input, conf, labelMapPath, classVectorPath, numReducers);
+    Path weightFilePath = new Path(output, SUM_VECTORS);
+    runNaiveBayesWeightSummer(classVectorPath, conf, labelMapPath, weightFilePath, numReducers);
+    // NOTE(review): weightFilePath is a job *output directory* (possibly with several part files
+    // when numReducers > 1), yet the theta mappers open their first cache file directly with
+    // SequenceFile.Reader. Confirm this resolves to the reducer's single part file.
+    Path thetaFilePath = new Path(output, THETA_SUM);
+    if (trainComplementary) {
+      runNaiveBayesThetaComplementarySummer(classVectorPath, conf, weightFilePath, labelMapPath,
+          thetaFilePath, numReducers);
+    } else {
+      runNaiveBayesThetaSummer(classVectorPath, conf, weightFilePath, labelMapPath, thetaFilePath, numReducers);
+    }
+  }
+
+  /** Job 1: runs NaiveBayesInstanceMapper (with a summing combiner) over the raw input. */
+  private static void runNaiveBayesByLabelSummer(Path input, Configuration conf, Path labelMapPath,
+                                                 Path output, int numReducers)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    runSummerJob("Train Naive Bayes: input-folder: " + input + ", label-map-file: " + labelMapPath.toString(),
+        input, conf, new URI[] {labelMapPath.toUri()}, output,
+        NaiveBayesInstanceMapper.class, true, IntWritable.class, numReducers);
+  }
+
+  /** Job 2: sums feature weights and label weights over the class vectors. */
+  private static void runNaiveBayesWeightSummer(Path input, Configuration conf,
+                                                Path labelMapPath, Path output, int numReducers)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    runSummerJob("Train Naive Bayes: input-folder: " + input,
+        input, conf, new URI[] {labelMapPath.toUri()}, output,
+        NaiveBayesWeightsMapper.class, false, Text.class, numReducers);
+  }
+
+  /**
+   * Job 3 (standard): computes the theta normalizer. The mapper requires the weight file first
+   * and the label map second in the DistributedCache, and emits Text keys.
+   */
+  private static void runNaiveBayesThetaSummer(Path input, Configuration conf, Path weightFilePath,
+                                               Path labelMapPath, Path output, int numReducers)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    runSummerJob("Train Naive Bayes: input-folder: " + input + ", weight-file: " + weightFilePath.toString(),
+        input, conf, new URI[] {weightFilePath.toUri(), labelMapPath.toUri()}, output,
+        NaiveBayesThetaMapper.class, false, Text.class, numReducers);
+  }
+
+  /**
+   * Job 3 (complementary): computes the complementary theta normalizer. The mapper requires the
+   * weight file first and the label map second in the DistributedCache, and emits Text keys.
+   */
+  private static void runNaiveBayesThetaComplementarySummer(Path input, Configuration conf, Path weightFilePath,
+                                                            Path labelMapPath, Path output, int numReducers)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    runSummerJob("Train Naive Bayes: input-folder: " + input + ", weight-file: " + weightFilePath.toString(),
+        input, conf, new URI[] {weightFilePath.toUri(), labelMapPath.toUri()}, output,
+        NaiveBayesThetaComplementaryMapper.class, false, Text.class, numReducers);
+  }
+
+  /**
+   * Configures and runs one sequence-file-in / sequence-file-out job that reduces with
+   * {@link NaiveBayesSumReducer}; any existing {@code output} is overwritten.
+   *
+   * @throws IllegalStateException if the job does not complete successfully, so later stages
+   *         never run on missing or partial output
+   */
+  private static void runSummerJob(String jobName,
+                                   Path input,
+                                   Configuration conf,
+                                   URI[] cacheFiles,
+                                   Path output,
+                                   Class<? extends org.apache.hadoop.mapreduce.Mapper> mapperClass,
+                                   boolean useCombiner,
+                                   Class<?> outputKeyClass,
+                                   int numReducers)
+    throws IOException, InterruptedException, ClassNotFoundException {
+    // this conf parameter needs to be set to enable serialisation of conf values
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+        + "org.apache.hadoop.io.serializer.WritableSerialization");
+    DistributedCache.setCacheFiles(cacheFiles, conf);
+
+    Job job = new Job(conf);
+    job.setJobName(jobName);
+    job.setJarByClass(NaiveBayesTrainer.class);
+    FileInputFormat.setInputPaths(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+    job.setMapperClass(mapperClass);
+    if (useCombiner) {
+      job.setCombinerClass(NaiveBayesSumReducer.class);
+    }
+    job.setReducerClass(NaiveBayesSumReducer.class);
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    job.setOutputKeyClass(outputKeyClass);
+    job.setOutputValueClass(VectorWritable.class);
+    job.setNumReduceTasks(numReducers);
+    HadoopUtil.overwriteOutput(output);
+    if (!job.waitForCompletion(true)) {
+      throw new IllegalStateException("Job failed: " + jobName);
+    }
+  }
+
+  /**
+   * Writes the labels to a sequence file that maps each label (Text) to its position in
+   * {@code labels} (IntWritable).
+   *
+   * @param labels labels in id order
+   * @param conf configuration used to resolve the file system
+   * @param labelMapPathBase directory in which the {@value #LABEL_MAP} file is created
+   * @return path of the written label-map file
+   * @throws IOException if the file cannot be written
+   */
+  public static Path createLabelMapFile(List<String> labels,
+                                        Configuration conf,
+                                        Path labelMapPathBase) throws IOException {
+    FileSystem fs = FileSystem.get(labelMapPathBase.toUri(), conf);
+    Path labelMapPath = new Path(labelMapPathBase, LABEL_MAP);
+
+    SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, labelMapPath, Text.class, IntWritable.class);
+    try {
+      int i = 0;
+      for (String label : labels) {
+        Writable key = new Text(label);
+        dictWriter.append(key, new IntWritable(i++));
+      }
+    } finally {
+      // without close() the file may be left unflushed / incomplete for the jobs that read it
+      dictWriter.close();
+    }
+    return labelMapPath;
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/trainer/NaiveBayesWeightsMapper.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.trainer;
+
+import java.io.IOException;
+import java.net.URI;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesConstants;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+/**
+ * Sums, over all label-keyed instance vectors it sees, the per-feature weights (FEATURE_SUM) and
+ * the per-label total weight (LABEL_SUM), emitting both once in {@link #cleanup}.
+ *
+ * <p>Expects the label map as the first DistributedCache entry; its size fixes the cardinality
+ * of the label-sum vector.</p>
+ */
+public class NaiveBayesWeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+  private OpenObjectIntHashMap<String> labelMap = new OpenObjectIntHashMap<String>();
+  /** Lazily created on the first record; stays null if this mapper receives no input. */
+  Vector featureSum;
+  /** Lazily created together with featureSum. */
+  Vector labelSum;
+
+  @Override
+  protected void map(IntWritable key, VectorWritable value, Context context)
+    throws IOException, InterruptedException {
+    Vector vector = value.get();
+    if (featureSum == null) {
+      featureSum = new RandomAccessSparseVector(vector.size(), vector.getNumNondefaultElements());
+      labelSum = new RandomAccessSparseVector(labelMap.size());
+    }
+
+    int label = key.get();
+    vector.addTo(featureSum);
+    labelSum.set(label, labelSum.get(label) + vector.zSum());
+  }
+
+  /**
+   * Loads the label map from the first DistributedCache entry.
+   *
+   * @throws IllegalArgumentException if no cache file is configured
+   * @throws IllegalStateException if the label map cannot be read
+   */
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    try {
+      URI[] localFiles = DistributedCache.getCacheFiles(conf);
+      if (localFiles == null || localFiles.length < 1) {
+        throw new IllegalArgumentException("missing paths from the DistributedCache");
+      }
+      Path labelMapFile = new Path(localFiles[0].getPath());
+      FileSystem fs = labelMapFile.getFileSystem(conf);
+      SequenceFile.Reader reader = new SequenceFile.Reader(fs, labelMapFile, conf);
+      try {
+        Writable key = new Text();
+        IntWritable value = new IntWritable();
+        // key is the label name, value is its integer id
+        while (reader.next(key, value)) {
+          labelMap.put(key.toString(), value.get());
+        }
+      } finally {
+        reader.close();
+      }
+    } catch (IOException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  /** Emits the accumulated sums; emits nothing when no record was mapped (sums still null). */
+  @Override
+  protected void cleanup(Context context) throws IOException, InterruptedException {
+    if (featureSum != null) {
+      context.write(new Text(BayesConstants.FEATURE_SUM), new VectorWritable(featureSum));
+      context.write(new Text(BayesConstants.LABEL_SUM), new VectorWritable(labelSum));
+    }
+    super.cleanup(context);
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/common/IntTuple.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,170 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.mahout.math.list.IntArrayList;
+
+/**
+ * An ordered list of ints that can be used as a key or value in a Hadoop Map/Reduce job.
+ *
+ * <p>Serialized as the entry count followed by each entry, all as 4-byte ints. Ordering is
+ * lexicographic by entry; a tuple that is a prefix of a longer one sorts first.</p>
+ */
+public final class IntTuple implements WritableComparable<IntTuple> {
+
+  private IntArrayList tuple = new IntArrayList();
+
+  public IntTuple() {}
+
+  /** Creates a tuple holding the single given entry. */
+  public IntTuple(int firstEntry) {
+    add(firstEntry);
+  }
+
+  /** Creates a tuple holding the given entries in iteration order. */
+  public IntTuple(Iterable<Integer> entries) {
+    for (Integer entry : entries) {
+      add(entry);
+    }
+  }
+
+  /** Creates a tuple holding the given entries in array order. */
+  public IntTuple(int[] entries) {
+    for (int entry : entries) {
+      add(entry);
+    }
+  }
+
+  /**
+   * Appends an entry to the end of the tuple.
+   *
+   * @param entry value to append
+   */
+  public void add(int entry) {
+    tuple.add(entry);
+  }
+
+  /**
+   * Returns the entry at the given position.
+   *
+   * @param index zero-based position
+   * @return the entry at {@code index}
+   */
+  public int at(int index) {
+    return tuple.get(index);
+  }
+
+  /**
+   * Replaces the entry at the given position.
+   *
+   * @param index zero-based position
+   * @param newInteger the new entry
+   * @return the entry previously at that position
+   */
+  public int replaceAt(int index, int newInteger) {
+    int old = tuple.get(index);
+    tuple.set(index, newInteger);
+    return old;
+  }
+
+  /**
+   * Returns a copy of the entries in insertion order; changing the copy does not affect this tuple.
+   *
+   * @return a new list containing exactly {@link #length()} entries
+   */
+  public IntArrayList getEntries() {
+    // Copy by index: elements() exposes the backing array, which can be longer than size().
+    int size = tuple.size();
+    IntArrayList copy = new IntArrayList(size);
+    for (int i = 0; i < size; i++) {
+      copy.add(tuple.get(i));
+    }
+    return copy;
+  }
+
+  /**
+   * Returns the number of entries in the tuple.
+   *
+   * @return length
+   */
+  public int length() {
+    return this.tuple.size();
+  }
+
+  /** Element-wise hash over the first {@link #length()} entries, consistent with {@link #equals}. */
+  @Override
+  public int hashCode() {
+    int hash = 1;
+    int size = tuple.size();
+    for (int i = 0; i < size; i++) {
+      hash = 31 * hash + tuple.get(i);
+    }
+    return hash;
+  }
+
+  /** Two tuples are equal when they have the same length and the same entries in the same order. */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (obj == null || getClass() != obj.getClass()) {
+      return false;
+    }
+    IntTuple other = (IntTuple) obj;
+    int size = tuple.size();
+    if (size != other.tuple.size()) {
+      return false;
+    }
+    for (int i = 0; i < size; i++) {
+      if (tuple.get(i) != other.tuple.get(i)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int len = in.readInt();
+    tuple = new IntArrayList(len);
+    for (int i = 0; i < len; i++) {
+      tuple.add(in.readInt());
+    }
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    // Write exactly size() entries so the count written first matches the payload;
+    // iterating elements() could emit unused backing-array slots.
+    int size = tuple.size();
+    out.writeInt(size);
+    for (int i = 0; i < size; i++) {
+      out.writeInt(tuple.get(i));
+    }
+  }
+
+  /**
+   * Lexicographic comparison: the first differing entry decides, and otherwise the shorter
+   * tuple sorts first.
+   */
+  @Override
+  public int compareTo(IntTuple otherTuple) {
+    int thisLength = length();
+    int otherLength = otherTuple.length();
+    int min = Math.min(thisLength, otherLength);
+    for (int i = 0; i < min; i++) {
+      int mine = tuple.get(i);
+      int theirs = otherTuple.at(i);
+      // Equal entries do not decide the order, so keep scanning. Explicit comparisons avoid
+      // the overflow that subtracting two ints can produce.
+      if (mine < theirs) {
+        return -1;
+      }
+      if (mine > theirs) {
+        return 1;
+      }
+    }
+    if (thisLength < otherLength) {
+      return -1;
+    } else if (thisLength > otherLength) {
+      return 1;
+    } else {
+      return 0;
+    }
+  }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,31 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.math.DenseVector;
+import org.junit.Before;
+import org.junit.Test;
+
+
+/**
+ * Checks that the complementary classifier, built on the shared 4-label fixture, assigns each
+ * unit feature vector e_i to label i.
+ */
+public final class ComplementaryNaiveBayesClassifierTest extends NaiveBayesTestBase {
+
+  private NaiveBayesModel model;
+  private ComplementaryNaiveBayesClassifier classifier;
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    model = createComplementaryNaiveBayesModel();
+    classifier = new ComplementaryNaiveBayesClassifier(model);
+  }
+
+  @Test
+  public void testNaiveBayes() throws Exception {
+    // assertEquals takes (expected, actual); swapped arguments produce misleading failure messages
+    assertEquals(4, classifier.numCategories());
+    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
+    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
+    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
+    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+  }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.junit.Test;
+
+/** Validation and JSON round-trip checks for {@link NaiveBayesModel}. */
+public class NaiveBayesModelTest extends NaiveBayesTestBase {
+
+  /**
+   * The fixture model from {@link NaiveBayesTestBase} must pass validation. Despite the method
+   * name, that model is a fixed, hard-coded fixture rather than a randomly generated one.
+   */
+  @Test
+  public void testRandomModelGeneration() {
+    NaiveBayesModel fixture = getModel();
+    NaiveBayesModel.validate(fixture);
+  }
+
+  /** Serializing, deserializing and re-serializing must reproduce the same JSON text. */
+  @Test
+  public void testSerialization() {
+    String firstJson = getModel().toJson();
+    NaiveBayesModel restored = NaiveBayesModel.fromJson(firstJson);
+    String secondJson = restored.toJson();
+    // the model's components lack equals(), so the JSON renderings stand in for an equality check
+    assertEquals(firstJson, secondJson);
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,116 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import java.util.Iterator;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Shared fixture for the Naive Bayes tests: a 4-label x 4-feature weight matrix whose row sums
+ * are the label sums {1.2, 1.0, 1.0, 1.0} and whose column sums are the feature sums
+ * {1.3, 0.6, 1.1, 1.2}, plus helpers to build standard and complementary models from it.
+ */
+public class NaiveBayesTestBase extends MahoutTestCase {
+
+  private NaiveBayesModel model;
+
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    model = createNaiveBayesModel();
+
+    // make sure the model is valid :)
+    NaiveBayesModel.validate(model);
+  }
+
+  protected NaiveBayesModel getModel() {
+    return model;
+  }
+
+  /**
+   * Reference complementary theta normalizer for one label, with alpha = 1:
+   * sum over features i of log((featureSum[i] - w(i, label) + 1) / (total - labelSum[label] + numFeatures)).
+   */
+  public double complementaryNaiveBayesThetaWeight(int label,
+                                                   Matrix weightMatrix,
+                                                   Vector labelSum,
+                                                   Vector featureSum) {
+    double alpha = 1.0d;
+    // loop invariants, hoisted out of the per-feature loop
+    double lSum = labelSum.get(label);
+    double totalSum = featureSum.zSum();
+    double denominator = totalSum - lSum + featureSum.size();
+    double weight = 0.0;
+    for (int i = 0; i < featureSum.size(); i++) {
+      // NOTE(review): the fixture's rows are labels (row sums == labelSum), so this reads the
+      // matrix transposed; get(label, i) looks intended. Left as-is because changing it alters
+      // the normalizers the classifier tests were written against.
+      double score = weightMatrix.get(i, label);
+      double numerator = featureSum.get(i) - score + alpha;
+      weight += Math.log(numerator / denominator);
+    }
+    return weight;
+  }
+
+  /**
+   * Reference standard theta normalizer for one label, with alpha = 1:
+   * sum over features i of log((w(i, label) + 1) / (labelSum[label] + numFeatures)).
+   */
+  public double naiveBayesThetaWeight(int label,
+                                      Matrix weightMatrix,
+                                      Vector labelSum,
+                                      Vector featureSum) {
+    double alpha = 1.0d;
+    // loop invariant, hoisted out of the per-feature loop
+    double denominator = labelSum.get(label) + featureSum.size();
+    double weight = 0.0;
+    for (int i = 0; i < featureSum.size(); i++) {
+      // NOTE(review): transposed read, see complementaryNaiveBayesThetaWeight.
+      double score = weightMatrix.get(i, label);
+      double numerator = score + alpha;
+      weight += Math.log(numerator / denominator);
+    }
+    return weight;
+  }
+
+  /** Builds the standard Naive Bayes model (alphaI = 1) over the shared fixture. */
+  public NaiveBayesModel createNaiveBayesModel() {
+    DenseMatrix weightMatrix = createWeightMatrix();
+    DenseVector labelSum = createLabelSum();
+    DenseVector featureSum = createFeatureSum();
+
+    double[] thetaNormalizerSum = new double[labelSum.size()];
+    for (int label = 0; label < thetaNormalizerSum.length; label++) {
+      thetaNormalizerSum[label] = naiveBayesThetaWeight(label, weightMatrix, labelSum, featureSum);
+    }
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+  }
+
+  /** Builds the complementary Naive Bayes model (alphaI = 1) over the shared fixture. */
+  public NaiveBayesModel createComplementaryNaiveBayesModel() {
+    DenseMatrix weightMatrix = createWeightMatrix();
+    DenseVector labelSum = createLabelSum();
+    DenseVector featureSum = createFeatureSum();
+
+    double[] thetaNormalizerSum = new double[labelSum.size()];
+    for (int label = 0; label < thetaNormalizerSum.length; label++) {
+      thetaNormalizerSum[label] = complementaryNaiveBayesThetaWeight(label, weightMatrix, labelSum, featureSum);
+    }
+    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f);
+  }
+
+  /** Fresh copy of the fixture weight matrix (rows = labels, columns = features). */
+  private static DenseMatrix createWeightMatrix() {
+    return new DenseMatrix(new double[][] { {0.7, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
+        {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}});
+  }
+
+  /** Fresh copy of the fixture label sums (the weight-matrix row sums). */
+  private static DenseVector createLabelSum() {
+    return new DenseVector(new double[] {1.2, 1.0, 1.0, 1.0});
+  }
+
+  /** Fresh copy of the fixture feature sums (the weight-matrix column sums). */
+  private static DenseVector createFeatureSum() {
+    return new DenseVector(new double[] {1.3, 0.6, 1.1, 1.2});
+  }
+
+  /**
+   * Returns the index of the largest element; ties go to the last such index. Returns -1 only
+   * when no element compares (empty vector or all NaN).
+   */
+  public int maxIndex(Vector instance) {
+    int maxIndex = -1;
+    // -Infinity, not Integer.MIN_VALUE: scores below -2^31 must still be able to win
+    double maxValue = Double.NEGATIVE_INFINITY;
+    Iterator<Element> it = instance.iterator();
+    while (it.hasNext()) {
+      Element e = it.next();
+      double current = e.get();
+      if (maxValue <= current) {
+        maxIndex = e.index();
+        maxValue = current;
+      }
+    }
+    return maxIndex;
+  }
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java?rev=1005262&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java Wed Oct 6 21:38:41 2010
@@ -0,0 +1,31 @@
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.math.DenseVector;
+import org.junit.Before;
+import org.junit.Test;
+
+
+/**
+ * Checks that the standard classifier, built on the shared 4-label fixture, assigns each unit
+ * feature vector e_i to label i.
+ */
+public final class StandardNaiveBayesClassifierTest extends NaiveBayesTestBase {
+
+  private NaiveBayesModel model;
+  private StandardNaiveBayesClassifier classifier;
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    model = createNaiveBayesModel();
+    classifier = new StandardNaiveBayesClassifier(model);
+  }
+
+  @Test
+  public void testNaiveBayes() throws Exception {
+    // assertEquals takes (expected, actual); swapped arguments produce misleading failure messages
+    assertEquals(4, classifier.numCategories());
+    assertEquals(0, maxIndex(classifier.classify(new DenseVector(new double[] {1.0, 0.0, 0.0, 0.0}))));
+    assertEquals(1, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 1.0, 0.0, 0.0}))));
+    assertEquals(2, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 1.0, 0.0}))));
+    assertEquals(3, maxIndex(classifier.classify(new DenseVector(new double[] {0.0, 0.0, 0.0, 1.0}))));
+  }
+
+}