You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2009/10/10 20:17:31 UTC
svn commit: r823911 [2/2] - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/classifier/bayes/datastore/
core/src/main/java/org/apache/mahout/classifier/bayes/io/
core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/bayes/
core/...
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfReducer.java?rev=823911&r1=823910&r2=823911&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesTfIdfReducer.java Sat Oct 10 18:17:30 2009
@@ -22,56 +22,61 @@
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.common.Parameters;
+import org.apache.mahout.common.StringTuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.Iterator;
-
/**
* Can also be used as a local Combiner because only two values should be there
* inside the values
*/
public class BayesTfIdfReducer extends MapReduceBase implements
- Reducer<Text, DoubleWritable, Text, DoubleWritable> {
+ Reducer<StringTuple, DoubleWritable, StringTuple, DoubleWritable> {
private static final Logger log = LoggerFactory
.getLogger(BayesTfIdfReducer.class);
private HTable table;
-
+ private HBaseConfiguration HBconf; //reloading configuration causes
+ //a new htable session to get
+ //created(from HBASE IRC)
+
boolean useHbase = false;
@Override
- public void reduce(Text key, Iterator<DoubleWritable> values,
- OutputCollector<Text, DoubleWritable> output, Reporter reporter)
+ public void reduce(StringTuple key, Iterator<DoubleWritable> values,
+ OutputCollector<StringTuple, DoubleWritable> output, Reporter reporter)
throws IOException {
// Key is label,word, value is the number of times we've seen this label
// word per local node. Output is the same
- String token = key.toString();
- if (token.startsWith("*vocabCount")) {
+
+ if (key.stringAt(0).equals(BayesConstants.FEATURE_SET_SIZE)) {
double vocabCount = 0.0;
+
while (values.hasNext()) {
reporter.setStatus("Bayes TfIdf Reducer: vocabCount " + vocabCount);
vocabCount += values.next().get();
}
- log.info("{}\t{}", token, vocabCount);
+
+ log.info("{}\t{}", key, vocabCount);
if (useHbase) {
- Put bu = new Put(Bytes.toBytes("*totalCounts"));
- bu.add(Bytes.toBytes("label"), Bytes.toBytes("vocabCount"), Bytes
+ Put bu = new Put(Bytes.toBytes(BayesConstants.HBASE_COUNTS_ROW));
+ bu.add(Bytes.toBytes(BayesConstants.HBASE_COLUMN_FAMILY), Bytes
+ .toBytes(BayesConstants.FEATURE_SET_SIZE), Bytes
.toBytes(vocabCount));
table.put(bu);
}
output.collect(key, new DoubleWritable(vocabCount));
- } else {
+ } else if (key.stringAt(0).equals(BayesConstants.WEIGHT)) {
double idfTimes_D_ij = 1.0;
int numberofValues = 0;
while (values.hasNext()) {
@@ -79,14 +84,12 @@
numberofValues++;
}
if (numberofValues == 2) { // Found TFIdf
-
- int comma = token.indexOf(',');
- String label = comma < 0 ? token : token.substring(0, comma);
- String feature = token.substring(label.length() + 1);
+ String label = key.stringAt(1);
+ String feature = key.stringAt(2);
if (useHbase) {
Put bu = new Put(Bytes.toBytes(feature));
- bu.add(Bytes.toBytes("label"), Bytes.toBytes(label), Bytes
- .toBytes(idfTimes_D_ij));
+ bu.add(Bytes.toBytes(BayesConstants.HBASE_COLUMN_FAMILY), Bytes
+ .toBytes(label), Bytes.toBytes(idfTimes_D_ij));
table.put(bu);
}
@@ -94,18 +97,22 @@
reporter
.setStatus("Bayes TfIdf Reducer: " + key + " => " + idfTimes_D_ij);
output.collect(key, new DoubleWritable(idfTimes_D_ij));
+ } else {
+ throw new RuntimeException("Unexpected StringTuple: " + key);
}
}
@Override
public void configure(JobConf job) {
try {
- Parameters params = Parameters.fromString(job.get(
- "bayes.parameters", ""));
- if(params.get("dataSource").equals("hbase"))useHbase = true;
- else return;
+ Parameters params = Parameters
+ .fromString(job.get("bayes.parameters", ""));
+ if (params.get("dataSource").equals("hbase"))
+ useHbase = true;
+ else
+ return;
- HBaseConfiguration HBconf = new HBaseConfiguration(job);
+ HBconf = new HBaseConfiguration(job);
table = new HTable(HBconf, job.get("output.table"));
@@ -118,8 +125,8 @@
@Override
public void close() throws IOException {
if (useHbase) {
- table.close();
- }
+ table.close();
+ }
super.close();
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerDriver.java?rev=823911&r1=823910&r2=823911&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerDriver.java Sat Oct 10 18:17:30 2009
@@ -20,13 +20,13 @@
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.FileInputFormat;
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.mahout.classifier.bayes.common.BayesParameters;
+import org.apache.mahout.common.StringTuple;
import java.io.IOException;
@@ -34,18 +34,6 @@
public class BayesWeightSummerDriver implements BayesJob {
/**
- * Takes in two arguments: <ol> <li>The input {@link org.apache.hadoop.fs.Path} where the input documents live</li>
- * <li>The output {@link org.apache.hadoop.fs.Path} where to write the the interim files as a {@link
- * org.apache.hadoop.io.SequenceFile}</li> </ol>
- *
- * @param args The args - should contain input and output path.
- */
- public static void main(String[] args) throws Exception {
- JobExecutor executor = new JobExecutor();
- executor.execute(args, new BayesWeightSummerDriver());
- }
-
- /**
* Run the job
*
* @param input the input pathname String
@@ -58,7 +46,7 @@
conf.setJobName("Bayes Weight Summer Driver running over input: " + input);
- conf.setOutputKeyClass(Text.class);
+ conf.setOutputKeyClass(StringTuple.class);
conf.setOutputValueClass(DoubleWritable.class);
FileInputFormat.addInputPath(conf, new Path(output + "/trainer-tfIdf/trainer-tfIdf"));
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerMapper.java?rev=823911&r1=823910&r2=823911&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerMapper.java Sat Oct 10 18:17:30 2009
@@ -18,37 +18,40 @@
package org.apache.mahout.classifier.bayes.mapreduce.common;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.common.StringTuple;
import java.io.IOException;
public class BayesWeightSummerMapper extends MapReduceBase implements
- Mapper<Text, DoubleWritable, Text, DoubleWritable> {
+ Mapper<StringTuple, DoubleWritable, StringTuple, DoubleWritable> {
/**
- * We need to calculate the idf of each feature in each label
- *
- * @param key The label,feature pair (can either be the freq Count or the term Document count
+ * We need to calculate the weight sums across each label and each feature
+ *
+ * @param key The label,feature tuple containing the tfidf value
*/
@Override
- public void map(Text key, DoubleWritable value,
- OutputCollector<Text, DoubleWritable> output, Reporter reporter)
+ public void map(StringTuple key, DoubleWritable value,
+ OutputCollector<StringTuple, DoubleWritable> output, Reporter reporter)
throws IOException {
-
- String labelFeaturePair = key.toString();
- int i = labelFeaturePair.indexOf(',');
-
-
- String label = labelFeaturePair.substring(0,i);
- String feature = labelFeaturePair.substring(i+1);
- reporter.setStatus("Bayes Weight Summer Mapper: " + labelFeaturePair);
- output.collect(new Text(',' + feature), value);//sum of weight for all labels for a feature Sigma_j
- output.collect(new Text('_' + label), value);//sum of weight for all features for a label Sigma_k
- output.collect(new Text("*"), value);//sum of weight of all features for all label Sigma_kSigma_j
+ String label = key.stringAt(1);
+ String feature = key.stringAt(2);
+ reporter.setStatus("Bayes Weight Summer Mapper: " + key);
+ StringTuple featureSum = new StringTuple(BayesConstants.FEATURE_SUM);
+ featureSum.add(feature);
+ output.collect(featureSum, value);// sum of weight for all labels for a
+ // feature Sigma_j
+ StringTuple labelSum = new StringTuple(BayesConstants.LABEL_SUM);
+ labelSum.add(label);
+ output.collect(labelSum, value);// sum of weight for all features for a
+ // label Sigma_k
+ StringTuple totalSum = new StringTuple(BayesConstants.TOTAL_SUM);
+ output.collect(totalSum, value);// sum of weight of all features for all
+ // label Sigma_kSigma_j
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerOutputFormat.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerOutputFormat.java?rev=823911&r1=823910&r2=823911&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerOutputFormat.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerOutputFormat.java Sat Oct 10 18:17:30 2009
@@ -18,7 +18,6 @@
package org.apache.mahout.classifier.bayes.mapreduce.common;
import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobConf;
@@ -26,6 +25,7 @@
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.hadoop.mapred.lib.MultipleOutputFormat;
import org.apache.hadoop.util.Progressable;
+import org.apache.mahout.common.StringTuple;
import java.io.IOException;
@@ -50,17 +50,22 @@
@Override
protected String generateFileNameForKeyValue(WritableComparable<?> k, Writable v,
String name) {
- Text key = (Text) k;
+ StringTuple key = (StringTuple) k;
- char firstChar = key.toString().charAt(0);
- if (firstChar == '*') { //sum of weight of all features for all label Sigma_kSigma_j
+ if(key.length() == 1 && key.stringAt(0).equals(BayesConstants.TOTAL_SUM))
+ {
return "Sigma_kSigma_j/" + name;
- } else if (firstChar == ',') { //sum of weight for all labels for a feature Sigma_j
- return "Sigma_j/" + name;
- } else if (firstChar == '_') { //sum of weights for all features for a label Sigma_k
- return "Sigma_k/" + name;
}
- return "JunkFileThisShouldNotHappen";
+ else{
+ if(key.stringAt(0).equals(BayesConstants.FEATURE_SUM))
+ {
+ return "Sigma_j/" + name;
+ }
+ else if(key.stringAt(0).equals(BayesConstants.LABEL_SUM))
+ return "Sigma_k/" + name;
+ else
+ throw new RuntimeException("Unexpected StringTuple: " + key);
+ }
}
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerReducer.java?rev=823911&r1=823910&r2=823911&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/mapreduce/common/BayesWeightSummerReducer.java Sat Oct 10 18:17:30 2009
@@ -22,13 +22,13 @@
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.MapReduceBase;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.common.Parameters;
+import org.apache.mahout.common.StringTuple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,7 +37,7 @@
/** Can also be used as a local Combiner */
public class BayesWeightSummerReducer extends MapReduceBase implements
- Reducer<Text, DoubleWritable, Text, DoubleWritable> {
+ Reducer<StringTuple, DoubleWritable, StringTuple, DoubleWritable> {
private static final Logger log = LoggerFactory
.getLogger(BayesWeightSummerReducer.class);
@@ -47,8 +47,8 @@
boolean useHbase = false;
@Override
- public void reduce(Text key, Iterator<DoubleWritable> values,
- OutputCollector<Text, DoubleWritable> output, Reporter reporter)
+ public void reduce(StringTuple key, Iterator<DoubleWritable> values,
+ OutputCollector<StringTuple, DoubleWritable> output, Reporter reporter)
throws IOException {
// Key is a feature-sum, label-sum or total-sum tuple, value is the tfidf
// weight seen for that key per local node. Output is the same
@@ -61,26 +61,28 @@
reporter.setStatus("Bayes Weight Summer Reducer: " + key + " => " + sum);
char firstChar = key.toString().charAt(0);
if (useHbase) {
- if (firstChar == ',') { // sum of weight for all labels for a feature
+ if (key.stringAt(0).equals(BayesConstants.FEATURE_SUM)) { // sum of weight
+ // for all
+ // labels for a
+ // feature
// Sigma_j
- String feature = key.toString().substring(1);
+ String feature = key.stringAt(1);
Put bu = new Put(Bytes.toBytes(feature));
- bu.add(Bytes.toBytes("label"), Bytes.toBytes("Sigma_j"), Bytes
- .toBytes(sum));
+ bu.add(Bytes.toBytes(BayesConstants.HBASE_COLUMN_FAMILY), Bytes
+ .toBytes(BayesConstants.FEATURE_SUM), Bytes.toBytes(sum));
table.put(bu);
- } else if (firstChar == '_') {
- String label = key.toString().substring(1);
- Put bu = new Put(Bytes.toBytes("*labelWeight"));
- bu.add(Bytes.toBytes("label"), Bytes.toBytes(label), Bytes
- .toBytes(sum));
+ } else if (key.stringAt(0).equals(BayesConstants.LABEL_SUM)) {
+ String label = key.stringAt(1);
+ Put bu = new Put(Bytes.toBytes(BayesConstants.LABEL_SUM));
+ bu.add(Bytes.toBytes(BayesConstants.HBASE_COLUMN_FAMILY), Bytes
+ .toBytes(label), Bytes.toBytes(sum));
table.put(bu);
- }
- else if (firstChar == '*') {
- Put bu = new Put(Bytes.toBytes("*totalCounts"));
- bu.add(Bytes.toBytes("label"), Bytes.toBytes("sigma_jSigma_k"), Bytes
- .toBytes(sum));
+ } else if (key.stringAt(0).equals(BayesConstants.TOTAL_SUM)) {
+ Put bu = new Put(Bytes.toBytes(BayesConstants.HBASE_COUNTS_ROW));
+ bu.add(Bytes.toBytes(BayesConstants.HBASE_COLUMN_FAMILY), Bytes
+ .toBytes(BayesConstants.TOTAL_SUM), Bytes.toBytes(sum));
table.put(bu);
}
}
@@ -91,8 +93,8 @@
@Override
public void configure(JobConf job) {
try {
- Parameters params = Parameters.fromString(job.get(
- "bayes.parameters", ""));
+ Parameters params = Parameters
+ .fromString(job.get("bayes.parameters", ""));
if (params.get("dataSource").equals("hbase"))
useHbase = true;
else
@@ -108,9 +110,9 @@
@Override
public void close() throws IOException {
- if (useHbase) {
+ if (useHbase) {
table.close();
- }
+ }
super.close();
}
}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/StringTuple.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/StringTuple.java?rev=823911&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/StringTuple.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/StringTuple.java Sat Oct 10 18:17:30 2009
@@ -0,0 +1,168 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.hbase.util.Bytes;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+
+/**
+ * An ordered list of Strings which can be used in a Hadoop Map/Reduce job.
+ * Serialized as an entry count followed by length-prefixed byte arrays;
+ * ordering compares entries position by position, then by tuple length.
+ */
+public class StringTuple implements Writable, WritableComparable<StringTuple> {
+
+ private List<String> tuple = new ArrayList<String>(); // backing list; replaced wholesale by readFields
+
+ public StringTuple() {
+ }
+
+ public StringTuple(String firstEntry) {
+ add(firstEntry);
+ }
+
+ public StringTuple(Collection<String> entries) {
+ for(String entry: entries)
+ add(entry);
+ }
+
+ public StringTuple(String[] entries) {
+ for(String entry: entries)
+ add(entry);
+ }
+
+ /**
+ * Appends an entry to the end of the tuple.
+ * 
+ * @param entry the string to append
+ * @return the result of the backing list's add (true if the entry was added)
+ */
+ public boolean add(String entry) {
+ return tuple.add(entry);
+ }
+
+ /**
+ * Fetches the string at the given position.
+ * 
+ * @param index position in the backing list (passed straight to List.get)
+ * @return String value at the given position in the tuple
+ */
+ public String stringAt(int index) {
+ return tuple.get(index);
+ }
+
+ /**
+ * Replaces the string at the given index with the given newString.
+ * 
+ * @param index position in the backing list (passed straight to List.set)
+ * @param newString the value to store at that position
+ * @return The previous value at that location
+ */
+ public String replaceAt(int index, String newString) {
+ return tuple.set(index, newString);
+ }
+
+ /**
+ * Fetches the list of entries from the tuple.
+ * 
+ * @return an unmodifiable view of the backing list, in tuple order
+ */
+ public List<String> getEntries() {
+ return Collections.unmodifiableList(this.tuple);
+ }
+
+ /**
+ * Returns the length of the tuple.
+ * 
+ * @return the number of entries currently in the tuple
+ */
+ public int length() {
+ return this.tuple.size();
+ }
+
+ @Override
+ public String toString() {
+ return tuple.toString();
+ }; // NOTE(review): the ';' after the brace is a redundant empty declaration
+
+ @Override
+ public int hashCode() {
+ return tuple.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj)
+ return true;
+ if (obj == null)
+ return false;
+ if (getClass() != obj.getClass())
+ return false;
+ StringTuple other = (StringTuple) obj;
+ if (tuple == null) { // NOTE(review): tuple is never assigned null in this class; defensive check
+ if (other.tuple != null)
+ return false;
+ } else if (!tuple.equals(other.tuple))
+ return false;
+ return true;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int len = in.readInt(); // entry count, as written first by write()
+ tuple = new ArrayList<String>(len); // discard any previous entries
+ for (int i = 0; i < len; i++) {
+ int fieldLen = in.readInt(); // byte length of the next entry
+ byte[] entry = new byte[fieldLen];
+ in.readFully(entry);
+ tuple.add(Bytes.toString(entry));
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(tuple.size()); // entry count first, mirrored by readFields()
+ for (String entry : tuple) {
+ byte[] data = Bytes.toBytes(entry);
+ out.writeInt(data.length); // length prefix for this entry's bytes
+ out.write(data);
+ }
+ }
+
+ @Override
+ public int compareTo(StringTuple otherTuple) {
+ int min = Math.min(this.length(), otherTuple.length()); // only positions both tuples have
+ for (int i = 0; i < min; i++) {
+ int ret = this.tuple.get(i).compareTo(otherTuple.stringAt(i));
+ if (ret == 0)
+ continue;
+ return ret; // first differing entry decides the order
+ }
+ return this.length() - otherTuple.length(); // shared positions equal: shorter tuple sorts first
+ }
+
+}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java?rev=823911&r1=823910&r2=823911&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFeatureMapperTest.java Sat Oct 10 18:17:30 2009
@@ -25,6 +25,7 @@
import org.apache.mahout.classifier.bayes.mapreduce.common.BayesFeatureMapper;
import org.apache.mahout.classifier.bayes.common.BayesParameters;
import org.apache.mahout.common.DummyOutputCollector;
+import org.apache.mahout.common.StringTuple;
import java.util.List;
import java.util.Map;
@@ -39,7 +40,7 @@
conf.set("bayes.parameters", new BayesParameters(3).toString());
mapper.configure(conf);
- DummyOutputCollector<Text, DoubleWritable> output = new DummyOutputCollector<Text, DoubleWritable>();
+ DummyOutputCollector<StringTuple, DoubleWritable> output = new DummyOutputCollector<StringTuple, DoubleWritable>();
mapper.map(new Text("foo"), new Text("big brown shoe"), output, Reporter.NULL);
Map<String, List<DoubleWritable>> outMap = output.getData();
System.out.println("Map: " + outMap);
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=823911&r1=823910&r2=823911&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Sat Oct 10 18:17:30 2009
@@ -52,57 +52,85 @@
public class TestClassifier {
- private static final Logger log = LoggerFactory.getLogger(TestClassifier.class);
+ private static final Logger log = LoggerFactory
+ .getLogger(TestClassifier.class);
private TestClassifier() {
// do nothing
}
- public static void main(String[] args) throws IOException, OptionException, InvalidDatastoreException {
+ public static void main(String[] args) throws IOException, OptionException,
+ InvalidDatastoreException, ClassNotFoundException,
+ InstantiationException, IllegalAccessException {
DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
- Option pathOpt = obuilder.withLongName("model").withRequired(true).withArgument(
- abuilder.withName("model").withMinimum(1).withMaximum(1).create()).withDescription(
- "The path on HDFS / Name of Hbase Table as defined by the -source parameter").withShortName("m").create();
-
- Option dirOpt = obuilder.withLongName("testDir").withRequired(true).withArgument(
- abuilder.withName("testDir").withMinimum(1).withMaximum(1).create()).withDescription(
- "The directory where test documents resides in").withShortName("d").create();
+ Option pathOpt = obuilder
+ .withLongName("model")
+ .withRequired(true)
+ .withArgument(
+ abuilder.withName("model").withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "The path on HDFS / Name of Hbase Table as defined by the -source parameter")
+ .withShortName("m").create();
+
+ Option dirOpt = obuilder
+ .withLongName("testDir")
+ .withRequired(true)
+ .withArgument(
+ abuilder.withName("testDir").withMinimum(1).withMaximum(1).create())
+ .withDescription("The directory where test documents resides in")
+ .withShortName("d").create();
Option encodingOpt = obuilder.withLongName("encoding").withArgument(
- abuilder.withName("encoding").withMinimum(1).withMaximum(1).create()).withDescription(
- "The file encoding. Defaults to UTF-8").withShortName("e").create();
+ abuilder.withName("encoding").withMinimum(1).withMaximum(1).create())
+ .withDescription("The file encoding. Defaults to UTF-8")
+ .withShortName("e").create();
Option analyzerOpt = obuilder.withLongName("analyzer").withArgument(
- abuilder.withName("analyzer").withMinimum(1).withMaximum(1).create()).withDescription("The Analyzer to use")
- .withShortName("a").create();
+ abuilder.withName("analyzer").withDefault(
+ "org.apache.lucene.analysis.standard.StandardAnalyzer")
+ .withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Analyzer to use").withShortName("a").create();
Option defaultCatOpt = obuilder.withLongName("defaultCat").withArgument(
- abuilder.withName("defaultCat").withMinimum(1).withMaximum(1).create()).withDescription("The default category")
- .withShortName("default").create();
-
- Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true).withArgument(
- abuilder.withName("gramSize").withMinimum(1).withMaximum(1).create()).withDescription("Size of the n-gram")
- .withShortName("ng").create();
- Option verboseOutputOpt = obuilder.withLongName("verbose").withRequired(false).withDescription(
- "Output which values were correctly and incorrectly classified").withShortName("v").create();
- Option typeOpt = obuilder.withLongName("classifierType").withRequired(true).withArgument(
- abuilder.withName("classifierType").withMinimum(1).withMaximum(1).create()).withDescription(
- "Type of classifier: bayes|cbayes").withShortName("type").create();
-
- Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(true).withArgument(
- abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create()).withDescription(
- "Location of model: hdfs|hbase").withShortName("source").create();
-
- Option methodOpt = obuilder.withLongName("method").withRequired(true).withArgument(
- abuilder.withName("method").withMinimum(1).withMaximum(1).create()).withDescription(
- "Method of Classification: sequential|mapreduce").withShortName("method").create();
-
- Group group = gbuilder.withName("Options").withOption(analyzerOpt).withOption(defaultCatOpt).withOption(dirOpt)
- .withOption(encodingOpt).withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt).withOption(
- dataSourceOpt).withOption(methodOpt).withOption(verboseOutputOpt).create();
+ abuilder.withName("defaultCat").withMinimum(1).withMaximum(1).create())
+ .withDescription("The default category").withShortName("default")
+ .create();
+
+ Option gramSizeOpt = obuilder.withLongName("gramSize").withRequired(true)
+ .withArgument(
+ abuilder.withName("gramSize").withMinimum(1).withMaximum(1)
+ .create()).withDescription("Size of the n-gram").withShortName(
+ "ng").create();
+ Option verboseOutputOpt = obuilder.withLongName("verbose").withRequired(
+ false).withDescription(
+ "Output which values were correctly and incorrectly classified")
+ .withShortName("v").create();
+ Option typeOpt = obuilder.withLongName("classifierType").withRequired(true)
+ .withArgument(
+ abuilder.withName("classifierType").withMinimum(1).withMaximum(1)
+ .create()).withDescription("Type of classifier: bayes|cbayes")
+ .withShortName("type").create();
+
+ Option dataSourceOpt = obuilder.withLongName("dataSource").withRequired(
+ true).withArgument(
+ abuilder.withName("dataSource").withMinimum(1).withMaximum(1).create())
+ .withDescription("Location of model: hdfs|hbase").withShortName(
+ "source").create();
+
+ Option methodOpt = obuilder.withLongName("method").withRequired(true)
+ .withArgument(
+ abuilder.withName("method").withMinimum(1).withMaximum(1).create())
+ .withDescription("Method of Classification: sequential|mapreduce")
+ .withShortName("method").create();
+
+ Group group = gbuilder.withName("Options").withOption(analyzerOpt)
+ .withOption(defaultCatOpt).withOption(dirOpt).withOption(encodingOpt)
+ .withOption(gramSizeOpt).withOption(pathOpt).withOption(typeOpt)
+ .withOption(dataSourceOpt).withOption(methodOpt).withOption(
+ verboseOutputOpt).create();
Parser parser = new Parser();
parser.setGroup(group);
@@ -131,15 +159,8 @@
}
boolean verbose = cmdLine.hasOption(verboseOutputOpt);
- // Analyzer analyzer = null;
- // if (cmdLine.hasOption(analyzerOpt)) {
- // String className = (String) cmdLine.getValue(analyzerOpt);
- // Class clazz = Class.forName(className);
- // analyzer = (Analyzer) clazz.newInstance();
- // }
- // if (analyzer == null) {
- // analyzer = new StandardAnalyzer();
- // }
+
+ String className = (String) cmdLine.getValue(analyzerOpt);
String testDirPath = (String) cmdLine.getValue(dirOpt);
@@ -150,6 +171,7 @@
params.set("classifierType", classifierType);
params.set("dataSource", dataSource);
params.set("defaultCat", defaultCat);
+ params.set("analyzer", className);
params.set("encoding", encoding);
params.set("testDirPath", testDirPath);
if (classificationMethod.equalsIgnoreCase("sequential"))
@@ -158,7 +180,8 @@
classifyParallel(params);
}
- public static void classifySequential(BayesParameters params) throws IOException, InvalidDatastoreException {
+ public static void classifySequential(BayesParameters params)
+ throws IOException, InvalidDatastoreException {
log.info("Loading model from: {}", params.print());
boolean verbose = Boolean.valueOf(params.get("verbose"));
File dir = new File(params.get("testDirPath"));
@@ -182,7 +205,8 @@
algorithm = new CBayesAlgorithm();
datastore = new InMemoryBayesDatastore(params);
} else {
- throw new IllegalArgumentException("Unrecognized classifier type: " + params.get("classifierType"));
+ throw new IllegalArgumentException("Unrecognized classifier type: "
+ + params.get("classifierType"));
}
} else if (params.get("dataSource").equals("hbase")) {
@@ -195,15 +219,18 @@
algorithm = new CBayesAlgorithm();
datastore = new HBaseBayesDatastore(params.get("basePath"), params);
} else {
- throw new IllegalArgumentException("Unrecognized classifier type: " + params.get("classifierType"));
+ throw new IllegalArgumentException("Unrecognized classifier type: "
+ + params.get("classifierType"));
}
} else {
- throw new IllegalArgumentException("Unrecognized dataSource type: " + params.get("dataSource"));
+ throw new IllegalArgumentException("Unrecognized dataSource type: "
+ + params.get("dataSource"));
}
ClassifierContext classifier = new ClassifierContext(algorithm, datastore);
classifier.initialize();
- ResultAnalyzer resultAnalyzer = new ResultAnalyzer(classifier.getLabels(), params.get("defaultCat"));
+ ResultAnalyzer resultAnalyzer = new ResultAnalyzer(classifier.getLabels(),
+ params.get("defaultCat"));
final TimingStatistics totalStatistics = new TimingStatistics();
if (subdirs != null) {
@@ -214,24 +241,29 @@
final TimingStatistics operationStats = new TimingStatistics();
long lineNum = 0;
- for (String line : new FileLineIterable(new File(file.getPath()), Charset.forName(params.get("encoding")), false)) {
+ for (String line : new FileLineIterable(new File(file.getPath()),
+ Charset.forName(params.get("encoding")), false)) {
- Map<String, List<String>> document = new NGrams(line, Integer.parseInt(params.get("gramSize")))
- .generateNGrams();
- for (Map.Entry<String, List<String>> stringListEntry : document.entrySet()) {
+ Map<String, List<String>> document = new NGrams(line, Integer
+ .parseInt(params.get("gramSize"))).generateNGrams();
+ for (Map.Entry<String, List<String>> stringListEntry : document
+ .entrySet()) {
List<String> strings = stringListEntry.getValue();
TimingStatistics.Call call = operationStats.newCall();
TimingStatistics.Call outercall = totalStatistics.newCall();
- ClassifierResult classifiedLabel = classifier.classifyDocument(strings
- .toArray(new String[strings.size()]), params.get("defaultCat"));
+ ClassifierResult classifiedLabel = classifier.classifyDocument(
+ strings.toArray(new String[strings.size()]), params
+ .get("defaultCat"));
call.end();
outercall.end();
- boolean correct = resultAnalyzer.addInstance(correctLabel, classifiedLabel);
+ boolean correct = resultAnalyzer.addInstance(correctLabel,
+ classifiedLabel);
if (verbose) {
// We have one document per line
log.info("Line Number: " + lineNum + " Line(30): "
- + (line.length() > 30 ? line.substring(0, 30) : line) + " Expected Label: " + correctLabel
- + " Classified Label: " + classifiedLabel.getLabel() + " Correct: " + correct);
+ + (line.length() > 30 ? line.substring(0, 30) : line)
+ + " Expected Label: " + correctLabel + " Classified Label: "
+ + classifiedLabel.getLabel() + " Correct: " + correct);
}
// log.info("{} {}", correctLabel, classifiedLabel);
@@ -250,7 +282,8 @@
log.info(resultAnalyzer.summarize());
}
- public static void classifyParallel(BayesParameters params) throws IOException {
+ public static void classifyParallel(BayesParameters params)
+ throws IOException {
BayesClassifierDriver.runJob(params);
}
}