You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2012/06/04 02:04:20 UTC
svn commit: r1345807 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/naivebayes/
core/src/main/java/org/apache/mahout/classifier/naivebayes/training/
examples/bin/ src/conf/
Author: robinanil
Date: Mon Jun 4 00:04:20 2012
New Revision: 1345807
URL: http://svn.apache.org/viewvc?rev=1345807&view=rev
Log:
MAHOUT-1006 Add a 20 Newsgroups example using the new naivebayes package; it reaches 91% accuracy on a random 20% holdout split of the dataset
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
mahout/trunk/examples/bin/classify-20newsgroups.sh
mahout/trunk/src/conf/driver.classes.props
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java?rev=1345807&r1=1345806&r2=1345807&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java Mon Jun 4 00:04:20 2012
@@ -45,7 +45,7 @@ public abstract class AbstractNaiveBayes
Element e = elements.next();
result += e.get() * getScoreForLabelFeature(label, e.index());
}
- return result / model.thetaNormalizer(label);
+ return -result;
}
@Override
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java?rev=1345807&r1=1345806&r2=1345807&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java Mon Jun 4 00:04:20 2012
@@ -33,7 +33,6 @@ public class ComplementaryNaiveBayesClas
NaiveBayesModel model = getModel();
double numerator = model.featureWeight(feature) - model.weight(label, feature) + model.alphaI();
double denominator = model.totalWeightSum() - model.labelWeight(label) + model.alphaI() * model.numFeatures();
-
return Math.log(numerator / denominator);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java?rev=1345807&r1=1345806&r2=1345807&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/AbstractThetaTrainer.java Mon Jun 4 00:04:20 2012
@@ -61,7 +61,7 @@ public abstract class AbstractThetaTrain
protected double featureWeight(int feature) {
return weightsPerFeature.get(feature);
}
-
+
protected void updatePerLabelThetaNormalizer(int label, double weight) {
perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + weight);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java?rev=1345807&r1=1345806&r2=1345807&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java Mon Jun 4 00:04:20 2012
@@ -28,13 +28,13 @@ public class ComplementaryThetaTrainer e
}
@Override
- public void train(int label, Vector instance) {
+ public void train(int label, Vector perLabelWeight) {
double sigmaK = labelWeight(label);
- Iterator<Vector.Element> it = instance.iterateNonZero();
+ Iterator<Vector.Element> it = perLabelWeight.iterateNonZero();
while (it.hasNext()) {
Vector.Element e = it.next();
double numerator = featureWeight(e.index()) - e.get() + alphaI();
- double denominator = totalWeightSum() - sigmaK + alphaI() * numFeatures();
+ double denominator = totalWeightSum() - sigmaK + numFeatures() ;
double weight = Math.log(numerator / denominator);
updatePerLabelThetaNormalizer(label, weight);
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java?rev=1345807&r1=1345806&r2=1345807&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/naivebayes/training/StandardThetaTrainer.java Mon Jun 4 00:04:20 2012
@@ -28,13 +28,13 @@ public class StandardThetaTrainer extend
}
@Override
- public void train(int label, Vector instance) {
+ public void train(int label, Vector perLabelWeight) {
double sigmaK = labelWeight(label);
- Iterator<Vector.Element> it = instance.iterateNonZero();
+ Iterator<Vector.Element> it = perLabelWeight.iterateNonZero();
while (it.hasNext()) {
Vector.Element e = it.next();
double numerator = e.get() + alphaI();
- double denominator = sigmaK + alphaI() * numFeatures();
+ double denominator = sigmaK + numFeatures();
double weight = Math.log(numerator / denominator);
updatePerLabelThetaNormalizer(label, weight);
}
Modified: mahout/trunk/examples/bin/classify-20newsgroups.sh
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/bin/classify-20newsgroups.sh?rev=1345807&r1=1345806&r2=1345807&view=diff
==============================================================================
--- mahout/trunk/examples/bin/classify-20newsgroups.sh (original)
+++ mahout/trunk/examples/bin/classify-20newsgroups.sh Mon Jun 4 00:04:20 2012
@@ -23,7 +23,7 @@
# examples/bin/build-20news.sh
if [ "$1" = "--help" ] || [ "$1" = "--?" ]; then
- echo "This script runs the SGD classifier over the classic 20 News Groups."
+ echo "This script runs SGD and Bayes classifiers over the classic 20 News Groups."
exit
fi
@@ -34,13 +34,14 @@ fi
START_PATH=`pwd`
WORK_DIR=/tmp/mahout-work-${USER}
-algorithm=( sgd clean)
+algorithm=( naivebayes sgd clean)
if [ -n "$1" ]; then
choice=$1
else
echo "Please select a number to choose the corresponding task to run"
echo "1. ${algorithm[0]}"
- echo "2. ${algorithm[1]} -- cleans up the work area in $WORK_DIR"
+ echo "2. ${algorithm[1]}"
+ echo "3. ${algorithm[2]} -- cleans up the work area in $WORK_DIR"
read -p "Enter your choice : " choice
fi
@@ -67,7 +68,54 @@ cd ../..
set -e
-if [ "x$alg" == "xsgd" ]; then
+if [ "x$alg" == "xnaivebayes" ]; then
+ set -x
+ echo "Preparing Training Data"
+ rm -rf ${WORK_DIR}/20news-all
+ mkdir ${WORK_DIR}/20news-all
+ cp -R ${WORK_DIR}/20news-bydate/*/* ${WORK_DIR}/20news-all
+
+ echo "Creating sequence files from 20newsgroups data"
+ ./bin/mahout seqdirectory \
+ -i ${WORK_DIR}/20news-all \
+ -o ${WORK_DIR}/20news-seq
+
+ echo "Converting sequence files to vectors"
+ ./bin/mahout seq2sparse \
+ -i ${WORK_DIR}/20news-seq \
+ -o ${WORK_DIR}/20news-vectors -lnorm -nv -wt tfidf
+
+ echo "Creating training and holdout set with a random 20% split of whole dataset"
+ ./bin/mahout split \
+ -i ${WORK_DIR}/20news-vectors/tfidf-vectors \
+ --trainingOutput ${WORK_DIR}/20news-train-vectors \
+ --testOutput ${WORK_DIR}/20news-test-vectors \
+ --randomSelectionPct 20 --overwrite --sequenceFiles -xm sequential
+
+ echo "Training Naive Bayes model"
+ ./bin/mahout trainnb \
+ -i ${WORK_DIR}/20news-train-vectors -el \
+ -o ${WORK_DIR}/model \
+ -li ${WORK_DIR}/labelindex \
+ -ow -c
+
+ echo "Self testing on training set"
+
+ ./bin/mahout testnb \
+ -i ${WORK_DIR}/20news-train-vectors\
+ -m ${WORK_DIR}/model \
+ -l ${WORK_DIR}/labelindex \
+ -ow -o ${WORK_DIR}/20news-testing
+
+ echo "Testing on holdout set"
+
+ ./bin/mahout testnb \
+ -i ${WORK_DIR}/20news-test-vectors\
+ -m ${WORK_DIR}/model \
+ -l ${WORK_DIR}/labelindex \
+ -ow -o ${WORK_DIR}/20news-testing
+
+elif [ "x$alg" == "xsgd" ]; then
if [ ! -e "/tmp/news-group.model" ]; then
echo "Training on ${WORK_DIR}/20news-bydate/20news-bydate-train/"
./bin/mahout org.apache.mahout.classifier.sgd.TrainNewsGroups ${WORK_DIR}/20news-bydate/20news-bydate-train/
Modified: mahout/trunk/src/conf/driver.classes.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1345807&r1=1345806&r2=1345807&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.props (original)
+++ mahout/trunk/src/conf/driver.classes.props Mon Jun 4 00:04:20 2012
@@ -60,4 +60,7 @@ org.apache.mahout.cf.taste.hadoop.als.Fa
org.apache.mahout.cf.taste.hadoop.similarity.item.ItemSimilarityJob = itemsimilarity : Compute the item-item-similarities for item-based collaborative filtering
org.apache.mahout.cf.taste.hadoop.item.RecommenderJob = recommenditembased : Compute recommendations using item-based collaborative filtering
org.apache.mahout.cf.taste.hadoop.als.ParallelALSFactorizationJob = parallelALS : ALS-WR factorization of a rating matrix
-org.apache.mahout.cf.taste.hadoop.als.RecommenderJob = recommendfactorized : Compute recommendations using the factorization of a rating matrix
\ No newline at end of file
+org.apache.mahout.cf.taste.hadoop.als.RecommenderJob = recommendfactorized : Compute recommendations using the factorization of a rating matrix
+prepare20newsgroups = deprecated : Use new naivebayes classifier see examples/bin/classify-20newsgroups.sh
+trainclassifier = deprecated : Use new naivebayes classifier see examples/bin/classify-20newsgroups.sh
+testclassifier = deprecated : Use new naivebayes classifier see examples/bin/classify-20newsgroups.sh