You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2008/08/19 14:55:48 UTC
svn commit: r687042 [3/4] - in /lucene/mahout/trunk: core/
core/src/main/java/org/apache/mahout/classifier/
core/src/main/java/org/apache/mahout/classifier/bayes/
core/src/main/java/org/apache/mahout/classifier/bayes/common/
core/src/main/java/org/apac...
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaMapper.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaMapper.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,112 @@
+package org.apache.mahout.classifier.cbayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.util.GenericsUtil;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Mapper of the complement naive Bayes theta job: scales every incoming weight by
+ * 1 / (sigma_jSigma_k - labelWeightSum[label] + vocabCount), with the totals taken from the job.
+ */
+public class CBayesThetaMapper extends MapReduceBase implements
+    Mapper<Text, FloatWritable, Text, FloatWritable> {
+
+  public HashMap<String, Float> labelWeightSum = null;
+  String labelWeightSumString = " ";
+  Float sigma_jSigma_k = 0f;
+  String sigma_jSigma_kString = " ";
+  Float vocabCount = 0f;
+  String vocabCountString = " ";
+
+  /**
+   * Emits the normalized weight(s) derived from one record.
+   *
+   * @param key either ",feature" (a feature weight sum) or a "label,feature" pair
+   * @param value the weight stored under <code>key</code>
+   * @param output for a ",feature" key receives one smoothed weight per known label, keyed
+   *        "label,feature"; otherwise receives the negated weight under the unchanged key
+   * @param reporter not used
+   * @throws IOException if the collector fails
+   */
+  public void map(Text key, FloatWritable value,
+      OutputCollector<Text, FloatWritable> output, Reporter reporter)
+      throws IOException {
+    String labelFeaturePair = key.toString();
+    float alpha_i = 1.0f; // smoothing term added to every feature weight sum
+    if (labelFeaturePair.startsWith(",")) { // if it is from the Sigma_j folder (feature weight Sum)
+      String feature = labelFeaturePair.substring(1);
+      float smoothed = value.get() + alpha_i;
+      for (Map.Entry<String, Float> entry : labelWeightSum.entrySet()) {
+        double inverseDenominator = 1.0d / (sigma_jSigma_k - entry.getValue() + vocabCount);
+        FloatWritable weight = new FloatWritable((float) (smoothed * inverseDenominator));
+        output.collect(new Text((entry.getKey() + "," + feature).trim()), weight); // output Sigma_j
+      }
+    } else {
+      String label = labelFeaturePair.split(",")[0];
+      Float labelSum = labelWeightSum.get(label);
+      if (labelSum == null) { // used to surface as a bare NullPointerException while unboxing
+        throw new IllegalStateException("No weight sum for label '" + label + "' in key " + key);
+      }
+      double inverseDenominator = 1.0d / (sigma_jSigma_k - labelSum + vocabCount);
+      FloatWritable weight = new FloatWritable((float) (-1 * value.get() * inverseDenominator));
+      output.collect(key, weight); // output -D_ij
+    }
+  }
+
+  /**
+   * Restores labelWeightSum, sigma_jSigma_k and vocabCount from the "cnaivebayes.*" job
+   * properties; the serialized current value is passed as the fallback for an unset property.
+   *
+   * @throws IllegalStateException if a value cannot be (de)serialized
+   */
+  @Override
+  public void configure(JobConf job) {
+    if (labelWeightSum != null) {
+      return; // already configured
+    }
+    try {
+      labelWeightSum = new HashMap<String, Float>();
+      DefaultStringifier<HashMap<String, Float>> mapStringifier =
+          new DefaultStringifier<HashMap<String, Float>>(job, GenericsUtil.getClass(labelWeightSum));
+      labelWeightSumString = job.get("cnaivebayes.sigma_k", mapStringifier.toString(labelWeightSum));
+      labelWeightSum = mapStringifier.fromString(labelWeightSumString);
+
+      DefaultStringifier<Float> floatStringifier =
+          new DefaultStringifier<Float>(job, GenericsUtil.getClass(sigma_jSigma_k));
+      sigma_jSigma_kString =
+          job.get("cnaivebayes.sigma_jSigma_k", floatStringifier.toString(sigma_jSigma_k));
+      sigma_jSigma_k = floatStringifier.fromString(sigma_jSigma_kString);
+      vocabCountString = job.get("cnaivebayes.vocabCount", floatStringifier.toString(vocabCount));
+      vocabCount = floatStringifier.fromString(vocabCountString);
+    } catch (IOException ex) {
+      // A half-configured mapper would emit wrong weights without any sign, so fail the task.
+      throw new IllegalStateException("Could not read the cnaivebayes.* job properties", ex);
+    }
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaMapper.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerDriver.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerDriver.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerDriver.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,127 @@
+package org.apache.mahout.classifier.cbayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.HashMap;
+
+import org.apache.hadoop.mapred.JobClient;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.util.GenericsUtil;
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader;
+
+
+/**
+ * Sets up and runs the theta normalizer job of the complement naive Bayes trainer, handing the
+ * Sigma_k, Sigma_kSigma_j and vocabulary count values to the mappers through the job conf.
+ *
+ **/
+public class CBayesThetaNormalizerDriver {
+  /**
+   * Takes in two arguments:
+   * <ol>
+   * <li>The input path; it is handed to {@link #runJob(String, String)}, which does not read it</li>
+   * <li>The output directory of the earlier trainer jobs; the job inputs are read from its
+   * sub-directories and the result is written to <code>trainer-thetaNormalizer</code> below it</li>
+   * </ol>
+   * @param args The args
+   * @throws IllegalArgumentException if fewer than two arguments are given
+   */
+  public static void main(String[] args) {
+    if (args == null || args.length < 2) {
+      throw new IllegalArgumentException("Usage: CBayesThetaNormalizerDriver <input> <output>");
+    }
+    runJob(args[0], args[1]);
+  }
+
+  /**
+   * Run the job
+   *
+   * @param input the input pathname String (not read by this job)
+   * @param output the pathname String under which the trainer results are read and under which
+   *        the <code>trainer-thetaNormalizer</code> directory is deleted and written again
+   * @throws RuntimeException wrapping any failure while preparing or running the job
+   */
+  public static void runJob(String input, String output) {
+    JobClient client = new JobClient();
+    JobConf conf = new JobConf(CBayesThetaNormalizerDriver.class);
+
+    conf.setOutputKeyClass(Text.class);
+    conf.setOutputValueClass(FloatWritable.class);
+    SequenceFileInputFormat.addInputPath(conf, new Path(output + "/trainer-weights/Sigma_j"));
+    SequenceFileInputFormat.addInputPath(conf, new Path(output + "/trainer-tfIdf/trainer-tfIdf"));
+    Path outPath = new Path(output + "/trainer-thetaNormalizer");
+    SequenceFileOutputFormat.setOutputPath(conf, outPath);
+    conf.setNumMapTasks(100);
+    conf.setMapperClass(CBayesThetaNormalizerMapper.class);
+    conf.setInputFormat(SequenceFileInputFormat.class);
+    conf.setCombinerClass(CBayesThetaNormalizerReducer.class);
+    conf.setReducerClass(CBayesThetaNormalizerReducer.class);
+    conf.setOutputFormat(SequenceFileOutputFormat.class);
+    // Do not drop this setting: presumably JavaSerialization is what lets DefaultStringifier
+    // handle the HashMap and Float values below - confirm before changing the list.
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,org.apache.hadoop.io.serializer.WritableSerialization");
+
+    try {
+      FileSystem dfs = FileSystem.get(conf);
+      if (dfs.exists(outPath)) {
+        dfs.delete(outPath, true); // start from a clean output directory
+      }
+
+      SequenceFileModelReader reader = new SequenceFileModelReader();
+
+      Path Sigma_kFiles = new Path(output + "/trainer-weights/Sigma_k/*");
+      HashMap<String, Float> labelWeightSum = reader.readLabelSums(dfs, Sigma_kFiles, conf);
+      DefaultStringifier<HashMap<String, Float>> mapStringifier =
+          new DefaultStringifier<HashMap<String, Float>>(conf, GenericsUtil.getClass(labelWeightSum));
+      String labelWeightSumString = mapStringifier.toString(labelWeightSum);
+
+      // Every value is deserialized again before it is printed, i.e. the round trip is logged.
+      System.out.println("Sigma_k for Each Label");
+      System.out.println(mapStringifier.fromString(labelWeightSumString));
+      conf.set("cnaivebayes.sigma_k", labelWeightSumString);
+
+      Path sigma_kSigma_jFile = new Path(output + "/trainer-weights/Sigma_kSigma_j/*");
+      Float sigma_jSigma_k = reader.readSigma_jSigma_k(dfs, sigma_kSigma_jFile, conf);
+      DefaultStringifier<Float> floatStringifier = new DefaultStringifier<Float>(conf, Float.class);
+      String sigma_jSigma_kString = floatStringifier.toString(sigma_jSigma_k);
+
+      System.out.println("Sigma_kSigma_j for each Label and for each Features");
+      System.out.println(floatStringifier.fromString(sigma_jSigma_kString));
+      conf.set("cnaivebayes.sigma_jSigma_k", sigma_jSigma_kString);
+
+      Path vocabCountFile = new Path(output + "/trainer-tfIdf/trainer-vocabCount/*");
+      Float vocabCount = reader.readVocabCount(dfs, vocabCountFile, conf);
+      String vocabCountString = floatStringifier.toString(vocabCount);
+
+      System.out.println("Vocabulary Count");
+      System.out.println(floatStringifier.fromString(vocabCountString));
+      conf.set("cnaivebayes.vocabCount", vocabCountString);
+
+      client.setConf(conf);
+      JobClient.runJob(conf);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerDriver.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerMapper.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerMapper.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,121 @@
+package org.apache.mahout.classifier.cbayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.util.GenericsUtil;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Mapper of the theta normalizer job: converts every incoming weight into a log term that is
+ * emitted under the key "_" + label.
+ */
+public class CBayesThetaNormalizerMapper extends MapReduceBase implements
+    Mapper<Text, FloatWritable, Text, FloatWritable> {
+
+  public HashMap<String, Float> labelWeightSum = null;
+
+  String labelWeightSumString = " ";
+
+  Float sigma_jSigma_k = 0f;
+
+  String sigma_jSigma_kString = " ";
+
+  Float vocabCount = 0f;
+
+  String vocabCountString = " ";
+
+  /**
+   * Emits the log term(s) contributed by one record.
+   *
+   * @param key either ",feature" (a feature weight sum) or a "label,feature" pair
+   * @param value the weight stored under <code>key</code>
+   * @param output for a ",feature" key receives, once per known label,
+   *        log((value + 1) / (sigma_jSigma_k - labelWeightSum[label] + vocabCount)); otherwise
+   *        receives log(1 - value / denominator) for the label named in the key
+   * @param reporter not used
+   * @throws IOException if the collector fails
+   */
+  public void map(Text key, FloatWritable value,
+      OutputCollector<Text, FloatWritable> output, Reporter reporter)
+      throws IOException {
+
+    String labelFeaturePair = key.toString();
+    float alpha_i = 1.0f; // smoothing term added to every feature weight sum
+
+    if (labelFeaturePair.startsWith(",")) { // if it is from the Sigma_j folder
+      float smoothed = value.get() + alpha_i;
+      for (Map.Entry<String, Float> entry : labelWeightSum.entrySet()) {
+        float weight =
+            (float) Math.log(smoothed / (sigma_jSigma_k - entry.getValue() + vocabCount));
+        output.collect(new Text(("_" + entry.getKey()).trim()), new FloatWritable(weight));
+      }
+    } else {
+      String label = labelFeaturePair.split(",")[0];
+      float D_ij = value.get();
+      // NOTE(review): the log argument is not guarded; 1 - D_ij / denominator <= 0 gives -Infinity or NaN.
+      float denominator =
+          0.5f * ((sigma_jSigma_k / vocabCount) + (D_ij * (float) labelWeightSum.size()));
+      float weight = (float) Math.log(1 - D_ij / denominator);
+      output.collect(new Text(("_" + label).trim()), new FloatWritable(weight));
+    }
+  }
+
+  /**
+   * Restores labelWeightSum, sigma_jSigma_k and vocabCount from the "cnaivebayes.*" job
+   * properties; the serialized current value is passed as the fallback for an unset property.
+   *
+   * @throws IllegalStateException if a value cannot be (de)serialized
+   */
+  @Override
+  public void configure(JobConf job) {
+    if (labelWeightSum != null) {
+      return; // already configured
+    }
+    try {
+      labelWeightSum = new HashMap<String, Float>();
+
+      DefaultStringifier<HashMap<String, Float>> mapStringifier =
+          new DefaultStringifier<HashMap<String, Float>>(job, GenericsUtil.getClass(labelWeightSum));
+      labelWeightSumString = job.get("cnaivebayes.sigma_k", mapStringifier.toString(labelWeightSum));
+      labelWeightSum = mapStringifier.fromString(labelWeightSumString);
+
+      DefaultStringifier<Float> floatStringifier =
+          new DefaultStringifier<Float>(job, GenericsUtil.getClass(sigma_jSigma_k));
+      sigma_jSigma_kString =
+          job.get("cnaivebayes.sigma_jSigma_k", floatStringifier.toString(sigma_jSigma_k));
+      sigma_jSigma_k = floatStringifier.fromString(sigma_jSigma_kString);
+
+      vocabCountString = job.get("cnaivebayes.vocabCount", floatStringifier.toString(vocabCount));
+      vocabCount = floatStringifier.fromString(vocabCountString);
+    } catch (IOException ex) {
+      // A half-configured mapper would emit wrong weights without any sign, so fail the task.
+      throw new IllegalStateException("Could not read the cnaivebayes.* job properties", ex);
+    }
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerMapper.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerReducer.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerReducer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerReducer.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,69 @@
+package org.apache.mahout.classifier.cbayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+
+/**
+ * Sums all the values that arrive for a key and emits the total under that key. Being a plain
+ * sum, it can also be used as a local Combiner.
+ *
+ */
+
+public class CBayesThetaNormalizerReducer extends MapReduceBase implements
+    Reducer<Text, FloatWritable, Text, FloatWritable> {
+  // None of the fields below is read by reduce().
+  public HashMap<String, Float> labelWeightSum = null;
+
+  String labelWeightSumString = " ";
+
+  Float sigma_jSigma_k = 0f;
+
+  String sigma_jSigma_kString = " ";
+
+  Float vocabCount = 0f;
+
+  String vocabCountString = " ";
+
+  /**
+   * @param key the key shared by all <code>values</code>; written out unchanged
+   * @param values the partial weights to add up
+   * @param output receives a single record: the key and the total
+   * @param reporter not used
+   * @throws IOException if the collector fails
+   */
+  public void reduce(Text key, Iterator<FloatWritable> values,
+      OutputCollector<Text, FloatWritable> output, Reporter reporter)
+      throws IOException {
+    float total = 0.0f;
+    for (Iterator<FloatWritable> it = values; it.hasNext();) {
+      total += it.next().get();
+    }
+    output.collect(key, new FloatWritable(total));
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerReducer.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaReducer.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaReducer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaReducer.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,55 @@
+package org.apache.mahout.classifier.cbayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+
+/**
+ * Adds up the weights of a key and emits the sum, but only for keys that arrive with at least two values.
+ * NOTE(review): the original comment says this can also be used as a local Combiner because only two
+ * values should be there per key - confirm before doing so, as single-valued keys are dropped here.
+ **/
+public class CBayesThetaReducer extends MapReduceBase implements Reducer<Text, FloatWritable, Text, FloatWritable> {
+  /**
+   * @param key the label,word pair; written out unchanged
+   * @param values the weights recorded for the key
+   * @param output receives the key with the summed weight, unless fewer than two values were seen
+   * @param reporter not used
+   */
+  public void reduce(Text key, Iterator<FloatWritable> values, OutputCollector<Text, FloatWritable> output, Reporter reporter) throws IOException {
+    float sum = 0.0f;
+    int seen = 0;
+    for (; values.hasNext(); seen++) {
+      sum += values.next().get();
+    }
+    if (seen >= 2) {
+      if (sum <= 0.0f) { // non-positive sums are reported on stdout and still emitted
+        System.out.println(key.toString() + "=>" + sum);
+      }
+      output.collect(key, new FloatWritable(sum));
+    }
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaReducer.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Classifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Classifier.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Classifier.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Classifier.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,64 @@
+package org.apache.mahout.common;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.common.Model;
+
+import java.util.Collection;
+
+/**
+ * Classifies documents based on a {@link org.apache.mahout.common.Model}.
+ */
+public interface Classifier {
+
+  /**
+   * Ranks the categories for a document and returns at most <code>numResults</code> of them.
+   *
+   * @param model           the model to classify with
+   * @param document        the document to classify
+   * @param defaultCategory the default category to assign
+   * @param numResults      the maximum number of results to return, ranked by score; ties are broken
+   *                        by comparing the category
+   * @return a Collection of {@link org.apache.mahout.classifier.ClassifierResult}s
+   */
+  Collection<ClassifierResult> classify(Model model, String[] document, String defaultCategory, int numResults);
+
+  /**
+   * Picks the single best category for a document according to the
+   * {@link org.apache.mahout.common.Model}.
+   *
+   * @param model           the trained {@link org.apache.mahout.common.Model}
+   * @param document        the document to classify
+   * @param defaultCategory the category to assign if one cannot be determined
+   * @return the single best category
+   */
+  ClassifierResult classify(Model model, String[] document, String defaultCategory);
+
+  /**
+   * Computes the document probability as the multiplication of the
+   * {@link org.apache.mahout.common.Model#FeatureWeight(String, String)} of each word given the label.
+   *
+   * @param model    the {@link org.apache.mahout.common.Model}
+   * @param label    the label to calculate the probability of
+   * @param document the document
+   * @return the probability
+   * @see Model#FeatureWeight(String, String)
+   */
+  float documentProbability(Model model, String label, String[] document);
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Classifier.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Model.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Model.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Model.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,270 @@
+package org.apache.mahout.common;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.mahout.cf.taste.impl.common.FastMap;
+
+
+import java.util.Map;
+import java.util.HashMap;
+import java.util.List;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.StringTokenizer;
+
+/**
+ * The Base Model Class. Currently there are some Bayes Model elements which have to be refactored out later.
+ * Labels and features are mapped to integer ids, and all weights are stored against those ids.
+ */
+public abstract class Model {
+
+  /** Indexed by feature id; each entry maps a label id to the weight of that feature for the label. */
+  protected List<Map<Integer, Float>> featureLabelWeights = new ArrayList<Map<Integer, Float>>();
+
+  /** Feature name to feature id; an id is the size of this map when the name is first seen. */
+  protected Map<String, Integer> featureList = new FastMap<String, Integer>();
+
+  /** Label name to label id; an id is the size of this map when the name is first seen. */
+  protected Map<String, Integer> labelList = new HashMap<String, Integer>();
+
+  /** Indexed by feature id: the weight sum of that feature. */
+  protected List<Float> sumFeatureWeight = new ArrayList<Float>();
+
+  /** Label id to the weight sum of that label. */
+  protected Map<Integer, Float> sumLabelWeight = new HashMap<Integer, Float>();
+
+  /** Label id to the theta normalizer of that label. */
+  protected Map<Integer, Float> thetaNormalizer = new HashMap<Integer, Float>();
+
+  protected Float sigma_jSigma_k = 0f;
+
+  protected Float alpha_i = 1.0f; // alpha_i can be improved upon for increased smoothing
+
+  public static float DEFAULT_PROBABILITY = 0.5f;
+
+  /** Weight of the feature for the label as returned by {@link #FeatureWeight(String, String)}. */
+  protected abstract float FeatureWeight(Integer label, Integer feature);
+
+  protected abstract float getWeight(Integer label, Integer feature);
+
+  protected abstract float getWeightUnprocessed(Integer label, Integer feature);
+
+  public abstract void InitializeNormalizer();
+
+  public abstract void GenerateModel();
+
+  /** @return the weight sum stored for the label id, or 0 when none is stored */
+  protected float getSumLabelWeight(Integer label) {
+    Float numSeen = sumLabelWeight.get(label);
+    return numSeen == null ? 0.0f : numSeen;
+  }
+
+  /** @return the theta normalizer stored for the label id, or 0 when none is stored */
+  protected float getThetaNormalizer(Integer label) {
+    Float numSeen = thetaNormalizer.get(label);
+    return numSeen == null ? 0.0f : numSeen;
+  }
+
+  /** @return the weight sum stored for the feature id, or 0 when none is stored */
+  protected float getSumFeatureWeight(Integer feature) {
+    // The sums live in a List, where get() throws for an id that was never set instead of
+    // returning null, so the range is checked first to honour the "0 when unknown" contract.
+    if (feature < 0 || feature >= sumFeatureWeight.size()) {
+      return 0.0f;
+    }
+    Float numSeen = sumFeatureWeight.get(feature);
+    return numSeen == null ? 0.0f : numSeen;
+  }
+
+  /** @return the id of the label; a label that is not known yet is registered with the next free id */
+  protected Integer getLabel(String label) {
+    Integer labelId = labelList.get(label);
+    if (labelId == null) {
+      labelId = Integer.valueOf(labelList.size());
+      labelList.put(label, labelId);
+    }
+    return labelId;
+  }
+
+  /** @return the id of the feature; a feature that is not known yet is registered with the next free id */
+  protected Integer getFeature(String feature) {
+    Integer featureId = featureList.get(feature);
+    if (featureId == null) {
+      featureId = Integer.valueOf(featureList.size());
+      featureList.put(feature, featureId);
+    }
+    return featureId;
+  }
+
+  /** Resolves both names to ids (registering unknown ones) and stores the weight for the pair. */
+  protected void setWeight(String labelString, String featureString, Float weight)
+      throws Exception {
+    Integer feature = getFeature(featureString);
+    Integer label = getLabel(labelString);
+    setWeight(label, feature, weight);
+  }
+
+  /**
+   * Stores the weight of the feature for the label.
+   *
+   * @throws Exception if the weight matrix has no row for the feature id, i.e. the id is not
+   *         smaller than the number of rows added by {@link #initializeWeightMatrix()}
+   */
+  protected void setWeight(Integer label, Integer feature, Float weight) throws Exception {
+    if (featureLabelWeights.size() <= feature) {
+      throw new Exception("This should not happen");
+    }
+    featureLabelWeights.get(feature).put(label, Float.valueOf(weight));
+  }
+
+  /** Appends the sum for the feature id, which must be the next unused index of the list. */
+  protected void setSumFeatureWeight(Integer feature, float sum) throws Exception {
+    if (sumFeatureWeight.size() != feature) {
+      throw new Exception("This should not happen");
+    }
+    sumFeatureWeight.add(feature, Float.valueOf(sum));
+  }
+
+  /** Stores the sum for the label id, which must equal the number of sums stored so far. */
+  protected void setSumLabelWeight(Integer label, float sum) throws Exception {
+    if (sumLabelWeight.size() != label) {
+      throw new Exception("This should not happen");
+    }
+    sumLabelWeight.put(label, Float.valueOf(sum));
+  }
+
+  protected void setThetaNormalizer(Integer label, float sum) {
+    thetaNormalizer.put(label, Float.valueOf(sum));
+  }
+
+  /** Adds one empty label-to-weight row per known feature after printing the feature count. */
+  public void initializeWeightMatrix() {
+    System.out.println(featureList.size());
+
+    for (int i = 0; i < featureList.size(); i++) {
+      featureLabelWeights.add(new HashMap<Integer, Float>(1));
+    }
+  }
+
+  public void setSigma_jSigma_k(Float sigma_jSigma_k) {
+    this.sigma_jSigma_k = sigma_jSigma_k;
+  }
+
+  /** Stores a weight by label and feature name; a failure is printed and otherwise ignored. */
+  public void loadFeatureWeight(String labelString, String featureString,
+      float weight) {
+    try {
+      setWeight(labelString, featureString, weight);
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  /** Stores the weight sum of a feature by name; a failure is printed and otherwise ignored. */
+  public void setSumFeatureWeight(String feature, float sum) {
+    try {
+      setSumFeatureWeight(getFeature(feature), sum);
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  /** Stores the weight sum of a label by name; a failure is printed and otherwise ignored. */
+  public void setSumLabelWeight(String label, float sum) {
+    try {
+      setSumLabelWeight(getLabel(label), sum);
+    } catch (Exception e) {
+      e.printStackTrace();
+    }
+  }
+
+  public void setThetaNormalizer(String label, float sum) {
+    setThetaNormalizer(getLabel(label), sum);
+  }
+
+  /**
+   * Get the weighted probability of the feature.
+   *
+   * @param labelString The label of the feature
+   * @param featureString The feature to calc. the prob. for
+   * @return The weighted probability, or 0 if the feature is not known to the model
+   */
+  public float FeatureWeight(String labelString, String featureString) {
+    if (!featureList.containsKey(featureString)) {
+      return 0.0f;
+    }
+    Integer feature = getFeature(featureString);
+    Integer label = getLabel(labelString);
+    return FeatureWeight(label, feature);
+  }
+
+  /** @return the names of all labels registered so far, as the key set of the label map */
+  public Collection<String> getLabels() {
+    return labelList.keySet();
+  }
+
+  /**
+   * Splits a line into its label (the first whitespace separated token) and the n-grams built
+   * from the remaining tokens.
+   *
+   * @param line the label followed by the document text
+   * @param gramSize the maximum number of tokens per n-gram
+   * @return a single entry map from the label to the n-grams of the remaining tokens
+   * @throws java.util.NoSuchElementException if the line contains no token
+   */
+  public static Map<String, List<String>> generateNGrams(String line, int gramSize) {
+    Map<String, List<String>> returnDocument = new HashMap<String, List<String>>();
+    StringTokenizer tokenizer = new StringTokenizer(line);
+    String labelName = tokenizer.nextToken();
+    returnDocument.put(labelName, collectNGrams(tokenizer, gramSize));
+    return returnDocument;
+  }
+
+  /**
+   * Builds the n-grams of a line that carries no label.
+   *
+   * @param line the document text, split on whitespace
+   * @param gramSize the maximum number of tokens per n-gram
+   * @return for every token position, the space joined prefixes of the window holding the last
+   *         <code>gramSize</code> tokens, starting with the oldest token of the window
+   */
+  public static List<String> generateNGramsWithoutLabel(String line, int gramSize) {
+    return collectNGrams(new StringTokenizer(line), gramSize);
+  }
+
+  /** Shared n-gram loop of both generators; consumes the remaining tokens of the tokenizer. */
+  private static List<String> collectNGrams(StringTokenizer tokenizer, int gramSize) {
+    List<String> tokens = new ArrayList<String>();
+    List<String> window = new ArrayList<String>();
+    while (tokenizer.hasMoreTokens()) {
+      String nextToken = tokenizer.nextToken();
+      if (window.size() == gramSize) {
+        window.remove(0); // slide the window: drop the oldest token before adding the new one
+      }
+      window.add(nextToken);
+      StringBuilder gramBuilder = new StringBuilder();
+      for (String gram : window) {
+        gramBuilder.append(gram);
+        tokens.add(gramBuilder.toString());
+        gramBuilder.append(" ");
+      }
+    }
+    return tokens;
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Model.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Summarizable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Summarizable.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Summarizable.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Summarizable.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,27 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+/**
+ * Contract for classes that can describe their own contents. Every implementing
+ * class must provide a {@link #summarize()} method that produces a string
+ * summary of the data it holds.
+ */
+public interface Summarizable {
+
+  /**
+   * @return summary of the data inside the implementing class
+   * @throws Exception if the implementation fails while building the summary
+   */
+  String summarize() throws Exception;
+}
\ No newline at end of file
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Summarizable.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableFloat.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableFloat.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableFloat.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableFloat.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,9 @@
+package org.apache.mahout.utils;
+
+/**
+ * Minimal mutable holder for a single {@code float}. The wrapped number is
+ * exposed as a public field so it can be read and updated in place.
+ */
+public class UpdatableFloat {
+
+  /** The current number held by this object; callers may assign it directly. */
+  public float value;
+
+  /**
+   * @param value initial content of {@link #value}
+   */
+  public UpdatableFloat(float value) {
+    this.value = value;
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableFloat.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableLong.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableLong.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableLong.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableLong.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,10 @@
+package org.apache.mahout.utils;
+
+/**
+ * Minimal mutable holder for a single {@code long}. The wrapped number is
+ * exposed as a public field so it can be read and updated in place.
+ */
+public class UpdatableLong {
+
+  /** The current number held by this object; callers may assign it directly. */
+  public long value;
+
+  /**
+   * @param value initial content of {@link #value}
+   */
+  public UpdatableLong(long value) {
+    this.value = value;
+  }
+
+}
+
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableLong.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,104 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Copyright 2004 The Apache Software Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import junit.framework.TestCase;
+
+import java.util.Collection;
+
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.common.Model;
+
+/**
+ * Exercises {@link BayesClassifier} against a small hand-built {@link BayesModel}
+ * holding five labels ("a".."e") and five features ("aa".."ee").
+ */
+public class BayesClassifierTest extends TestCase {
+
+  /** Default category passed to every classify() call in this test. */
+  private static final String DEFAULT_CATEGORY = "unknown";
+
+  /** Model populated in {@link #setUp()} and shared by the test methods. */
+  protected Model model;
+
+  public BayesClassifierTest(String s) {
+    super(s);
+  }
+
+  /** Builds the fixture model: global sums first, then the per-label feature weights. */
+  @Override
+  protected void setUp() {
+    String[] labels = {"a", "b", "c", "d", "e"};
+    String[] features = {"aa", "bb", "cc", "dd", "ee"};
+
+    model = new BayesModel();
+    model.setSigma_jSigma_k(100f);
+
+    for (String feature : features) {
+      model.setSumFeatureWeight(feature, 100);
+    }
+    for (String label : labels) {
+      model.setSumLabelWeight(label, 1);
+    }
+
+    model.initializeWeightMatrix();
+
+    model.loadFeatureWeight("a", "aa", 5);
+    model.loadFeatureWeight("a", "bb", 1);
+
+    model.loadFeatureWeight("b", "bb", 20);
+
+    model.loadFeatureWeight("c", "cc", 30);
+    model.loadFeatureWeight("c", "aa", 25);
+    model.loadFeatureWeight("c", "dd", 5);
+
+    model.loadFeatureWeight("d", "dd", 60);
+    model.loadFeatureWeight("d", "cc", 40);
+
+    model.loadFeatureWeight("e", "ee", 100);
+    model.loadFeatureWeight("e", "aa", 50);
+    model.loadFeatureWeight("e", "dd", 50);
+  }
+
+  @Override
+  protected void tearDown() {
+    // nothing to release: the fixture lives entirely in memory
+  }
+
+  /** Checks the winning label for three documents, including one with only an unseen feature. */
+  public void test() {
+    BayesClassifier classifier = new BayesClassifier();
+
+    assertClassifiedAs(classifier, "e", "aa", "ff");
+    // "ff" was never loaded into the model, so the default category is expected
+    assertClassifiedAs(classifier, DEFAULT_CATEGORY, "ff");
+    assertClassifiedAs(classifier, "d", "cc");
+  }
+
+  /** Smoke test: classification yields a non-null result, which is printed for inspection. */
+  public void testResults() throws Exception {
+    BayesClassifier classifier = new BayesClassifier();
+    String[] document = {"aa", "ff"};
+    ClassifierResult result = classifier.classify(model, document, DEFAULT_CATEGORY);
+    assertNotNull("category is null and it shouldn't be", result);
+    System.out.println("Result: " + result);
+  }
+
+  /** Classifies {@code document} and asserts that the resulting label is {@code expectedLabel}. */
+  private void assertClassifiedAs(BayesClassifier classifier, String expectedLabel, String... document) {
+    ClassifierResult result = classifier.classify(model, document, DEFAULT_CATEGORY);
+    assertNotNull("category is null and it shouldn't be", result);
+    assertEquals(result + " is not equal to " + expectedLabel, expectedLabel, result.getLabel());
+  }
+}
\ No newline at end of file
Propchange: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFileFormatterTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFileFormatterTest.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFileFormatterTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFileFormatterTest.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,102 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Copyright 2004 The Apache Software Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import junit.framework.TestCase;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.WhitespaceAnalyzer;
+import org.apache.mahout.classifier.BayesFileFormatter;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.charset.Charset;
+
+/**
+ * Tests {@link BayesFileFormatter} using a handful of one-word input files created
+ * under {@code java.io.tmpdir}/bayes.
+ */
+public class BayesFileFormatterTest extends TestCase {
+  /** Directory holding one input file per entry of {@link #words}. */
+  protected File input;
+  /** Directory the formatter writes to; emptied in {@link #setUp()}. */
+  protected File out;
+  /** File names (and file contents) of the generated input documents. */
+  protected String[] words;
+
+  public BayesFileFormatterTest(String s) {
+    super(s);
+  }
+
+  /** Creates the input/output directories, clears old output and writes the input files. */
+  @Override
+  protected void setUp() throws IOException {
+    File tmpDir = new File(System.getProperty("java.io.tmpdir"));
+    input = new File(tmpDir, "bayes/input");
+    out = new File(tmpDir, "bayes/out");
+    input.mkdirs();
+    out.mkdirs();
+
+    // listFiles() returns null when the directory cannot be read; fail with a clear message
+    File[] staleFiles = out.listFiles();
+    assertNotNull("could not list output directory: " + out, staleFiles);
+    for (File file : staleFiles) {
+      file.delete();
+    }
+
+    words = new String[]{"dog", "cat", "fish", "snake", "zebra"};
+    for (String word : words) {
+      FileWriter writer = new FileWriter(new File(input, word));
+      try {
+        writer.write(word);
+      } finally {
+        writer.close();
+      }
+    }
+  }
+
+  @Override
+  protected void tearDown() {
+    // generated files are left in place; setUp() clears the output directory before each test
+  }
+
+  /** format() is expected to produce one output file per input file. */
+  public void test() throws IOException {
+    Analyzer analyzer = new WhitespaceAnalyzer();
+    File[] files = out.listFiles();
+    assertEquals("files Size: " + files.length + " is not: " + 0, 0, files.length);
+    Charset charset = Charset.forName("UTF-8");
+    BayesFileFormatter.format("animal", analyzer, input, charset, out);
+
+    files = out.listFiles();
+    assertEquals("files Size: " + files.length + " is not: " + words.length, words.length, files.length);
+    for (File file : files) {
+      // the first line should be: label <tab> file name
+      BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file), charset));
+      try {
+        String line = reader.readLine();
+        assertNotNull("output file is empty: " + file, line);
+        line = line.trim();
+        String label = "animal" + '\t' + file.getName();
+        assertEquals(line + ":::: is not equal to " + label + "::::", label, line);
+      } finally {
+        reader.close();
+      }
+    }
+  }
+
+  /** collapse() is expected to produce a single file with one labelled line per input file. */
+  public void testCollapse() throws Exception {
+    Analyzer analyzer = new WhitespaceAnalyzer();
+    File[] files = out.listFiles();
+    assertEquals("files Size: " + files.length + " is not: " + 0, 0, files.length);
+    Charset charset = Charset.forName("UTF-8");
+    BayesFileFormatter.collapse("animal", analyzer, input, charset, new File(out, "animal"));
+
+    files = out.listFiles();
+    assertEquals("files Size: " + files.length + " is not: " + 1, 1, files.length);
+
+    int count = 0;
+    BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(files[0]), charset));
+    try {
+      String line;
+      while ((line = reader.readLine()) != null) {
+        assertTrue("line does not start with label", line.startsWith("animal"));
+        System.out.println("Line: " + line);
+        count++;
+      }
+    } finally {
+      reader.close();
+    }
+    assertEquals(count + " does not equal: " + words.length, words.length, count);
+  }
+}
\ No newline at end of file
Propchange: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFileFormatterTest.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,112 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Copyright 2004 The Apache Software Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import junit.framework.TestCase;
+
+import java.util.Collection;
+
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.cbayes.CBayesModel;
+import org.apache.mahout.common.Model;
+
+/**
+ * Classifies a few tiny documents against a hand-built {@link CBayesModel}
+ * holding five labels ("a".."e") and five features ("aa".."ee").
+ *
+ * NOTE(review): the model is a CBayesModel but classification goes through
+ * {@link BayesClassifier}; confirm whether CBayesClassifier was intended before
+ * changing it, since the expected labels below were written against this setup.
+ */
+public class CBayesClassifierTest extends TestCase {
+
+  /** Default category passed to every classify() call in this test. */
+  private static final String DEFAULT_CATEGORY = "unknown";
+
+  /** Model populated in {@link #setUp()} and shared by the test methods. */
+  protected CBayesModel model;
+
+  public CBayesClassifierTest(String s) {
+    super(s);
+  }
+
+  /** Builds the fixture model: global sums and normalizers first, then the feature weights. */
+  @Override
+  protected void setUp() {
+    String[] labels = {"a", "b", "c", "d", "e"};
+
+    model = new CBayesModel();
+    model.setSigma_jSigma_k(500f);
+
+    model.setSumFeatureWeight("aa", 80);
+    model.setSumFeatureWeight("bb", 21);
+    model.setSumFeatureWeight("cc", 60);
+    model.setSumFeatureWeight("dd", 115);
+    model.setSumFeatureWeight("ee", 100);
+
+    for (String label : labels) {
+      model.setSumLabelWeight(label, 100);
+    }
+    for (String label : labels) {
+      model.setThetaNormalizer(label, -100);
+    }
+
+    model.InitializeNormalizer();
+    model.initializeWeightMatrix();
+
+    model.loadFeatureWeight("a", "aa", 5);
+    model.loadFeatureWeight("a", "bb", 1);
+
+    model.loadFeatureWeight("b", "bb", 20);
+
+    model.loadFeatureWeight("c", "cc", 30);
+    model.loadFeatureWeight("c", "aa", 25);
+    model.loadFeatureWeight("c", "dd", 5);
+
+    model.loadFeatureWeight("d", "dd", 60);
+    model.loadFeatureWeight("d", "cc", 40);
+
+    model.loadFeatureWeight("e", "ee", 100);
+    model.loadFeatureWeight("e", "aa", 50);
+    model.loadFeatureWeight("e", "dd", 50);
+  }
+
+  @Override
+  protected void tearDown() {
+    // nothing to release: the fixture lives entirely in memory
+  }
+
+  /** Checks the winning label for three documents, including one with only an unseen feature. */
+  public void test() {
+    BayesClassifier classifier = new BayesClassifier();
+
+    assertClassifiedAs(classifier, "e", "aa", "ff");
+    // "ff" was never loaded into the model, so the default category is expected
+    assertClassifiedAs(classifier, DEFAULT_CATEGORY, "ff");
+    assertClassifiedAs(classifier, "d", "cc");
+  }
+
+  /** Smoke test: classification yields a non-null result, which is printed for inspection. */
+  public void testResults() throws Exception {
+    BayesClassifier classifier = new BayesClassifier();
+    String[] document = {"aa", "ff"};
+    ClassifierResult result = classifier.classify(model, document, DEFAULT_CATEGORY);
+    assertNotNull("category is null and it shouldn't be", result);
+    System.out.println("Result: " + result);
+  }
+
+  /** Classifies {@code document} and asserts that the resulting label is {@code expectedLabel}. */
+  private void assertClassifiedAs(BayesClassifier classifier, String expectedLabel, String... document) {
+    ClassifierResult result = classifier.classify(model, document, DEFAULT_CATEGORY);
+    assertNotNull("category is null and it shouldn't be", result);
+    assertEquals(result + " is not equal to " + expectedLabel, expectedLabel, result.getLabel());
+  }
+}
\ No newline at end of file
Propchange: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/PrepareTwentyNewsgroups.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/PrepareTwentyNewsgroups.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/PrepareTwentyNewsgroups.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/PrepareTwentyNewsgroups.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,79 @@
+package org.apache.mahout.classifier.bayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.mahout.classifier.BayesFileFormatter;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.PosixParser;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.OptionBuilder;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.Charset;
+
+
+/**
+ * Prepare the 20 Newsgroups files for training using the {@link org.apache.mahout.classifier.BayesFileFormatter}.
+ *
+ * This class takes the directory containing the extracted newsgroups and collapses them into a single file per category, with
+ * one document per line (first token on each line is the label)
+ *
+ **/
+public class PrepareTwentyNewsgroups {
+ private transient static Log log = LogFactory.getLog(PrepareTwentyNewsgroups.class);
+ @SuppressWarnings("static-access")
+ public static void main(String[] args) throws ClassNotFoundException, IllegalAccessException, InstantiationException, IOException {
+ CommandLine cmdLine = null;
+ Options options = new Options();
+ Option parentOpt = OptionBuilder.withLongOpt("parent").isRequired().hasArg().withDescription("Parent dir containing the newsgroups").create("p");
+ options.addOption(parentOpt);
+ Option outputDirOpt = OptionBuilder.withLongOpt("outputDir").isRequired().hasArg().withDescription("The output directory").create("o");
+ options.addOption(outputDirOpt);
+ Option analyzerNameOpt = OptionBuilder.withLongOpt("analyzerName").isRequired().hasArg().withDescription("The class name of the analyzer").create("a");
+ options.addOption(analyzerNameOpt);
+ Option charsetOpt = OptionBuilder.withLongOpt("charset").hasArg().isRequired().withDescription("The name of the character encoding of the input files").create("c");
+ options.addOption(charsetOpt);
+
+ try {
+
+ PosixParser parser = new PosixParser();
+ cmdLine = parser.parse(options, args);
+ }
+ catch (ParseException exp) {
+ log.error("Cmd line Syntax Error: " + exp.getMessage(), exp);
+ }
+ File parentDir = new File(cmdLine.getOptionValue(parentOpt.getOpt()));
+ File outputDir = new File(cmdLine.getOptionValue(outputDirOpt.getOpt()));
+ String analyzerName = cmdLine.getOptionValue(analyzerNameOpt.getOpt());
+ Charset charset = Charset.forName(cmdLine.getOptionValue(charsetOpt.getOpt()));
+ Analyzer analyzer = (Analyzer) Class.forName(analyzerName).newInstance();
+ //parent dir contains dir by category
+ File [] categoryDirs = parentDir.listFiles();
+ for (File dir : categoryDirs) {
+ if (dir.isDirectory()){
+ File outputFile = new File(outputDir, dir.getName() + ".txt");
+ BayesFileFormatter.collapse(dir.getName(), analyzer, dir, charset, outputFile);
+ }
+ }
+ }
+}
\ No newline at end of file
Propchange: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/PrepareTwentyNewsgroups.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,184 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.PosixParser;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.OptionBuilder;
+import org.apache.commons.cli.Option;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.bayes.BayesClassifier;
+import org.apache.mahout.classifier.bayes.BayesModel;
+import org.apache.mahout.common.Classifier;
+import org.apache.mahout.common.Model;
+import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader;
+import org.apache.mahout.classifier.cbayes.CBayesClassifier;
+import org.apache.mahout.classifier.cbayes.CBayesModel;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.analysis.Analyzer;
+
+import java.io.*;
+
+import java.util.*;
+
+/**
+ * Command-line driver that loads a trained Bayes or Complementary Bayes model and
+ * classifies every line of every file found in a test directory. For each file it
+ * prints the accuracy and the correct/total counts, followed by the summary
+ * produced by the {@link ResultAnalyzer}.
+ */
+public class TestClassifier {
+
+  /**
+   * Options: -p model base path (required), -t test directory (required),
+   * -type classifier type "bayes" or "cbayes" (required), -e encoding (default UTF-8),
+   * -a analyzer class, -d default category (default "unknown"), -ng n-gram size (default 1).
+   *
+   * Any exception raised while parsing, loading or classifying is printed to
+   * standard error rather than propagated.
+   */
+  @SuppressWarnings({ "static-access", "unchecked" })
+  public static void main(String[] args) throws IOException,
+      ClassNotFoundException, IllegalAccessException, InstantiationException {
+    Options options = new Options();
+    Option pathOpt = OptionBuilder.withLongOpt("path").isRequired().hasArg()
+        .withDescription("The local file system path").create("p");
+    options.addOption(pathOpt);
+    Option dirOpt = OptionBuilder.withLongOpt("testDir").isRequired().hasArg()
+        .withDescription("The directory where test documents resides in").create("t");
+    options.addOption(dirOpt);
+    Option encodingOpt = OptionBuilder.withLongOpt("encoding").hasArg()
+        .withDescription("The file encoding. defaults to UTF-8").create("e");
+    options.addOption(encodingOpt);
+    Option analyzerOpt = OptionBuilder.withLongOpt("analyzer").hasArg()
+        .withDescription("The Analyzer to use").create("a");
+    options.addOption(analyzerOpt);
+    Option defaultCatOpt = OptionBuilder.withLongOpt("defaultCat").hasArg()
+        .withDescription("The default category").create("d");
+    options.addOption(defaultCatOpt);
+    Option gramSizeOpt = OptionBuilder.withLongOpt("gramSize").hasArg()
+        .withDescription("Size of the n-gram").create("ng");
+    options.addOption(gramSizeOpt);
+    Option typeOpt = OptionBuilder.withLongOpt("classifierType").isRequired()
+        .hasArg().withDescription("Type of classifier").create("type");
+    options.addOption(typeOpt);
+
+    try {
+      PosixParser parser = new PosixParser();
+      CommandLine cmdLine = parser.parse(options, args);
+      SequenceFileModelReader reader = new SequenceFileModelReader();
+      JobConf conf = new JobConf(TestClassifier.class);
+
+      Map<String, Path> modelPaths = new HashMap<String, Path>();
+      String modelBasePath = cmdLine.getOptionValue(pathOpt.getOpt());
+      modelPaths.put("sigma_j", new Path(modelBasePath + "/trainer-weights/Sigma_j/part-*"));
+      modelPaths.put("sigma_k", new Path(modelBasePath + "/trainer-weights/Sigma_k/part-*"));
+      modelPaths.put("sigma_kSigma_j", new Path(modelBasePath + "/trainer-weights/Sigma_kSigma_j/part-*"));
+      modelPaths.put("thetaNormalizer", new Path(modelBasePath + "/trainer-thetaNormalizer/part-*"));
+      modelPaths.put("weight", new Path(modelBasePath + "/trainer-tfIdf/trainer-tfIdf/part-*"));
+
+      FileSystem fs = FileSystem.get(conf);
+
+      System.out.println("Loading model from: " + modelPaths);
+
+      Model model;
+      Classifier classifier;
+
+      String classifierType = cmdLine.getOptionValue(typeOpt.getOpt());
+
+      if (classifierType.equalsIgnoreCase("bayes")) {
+        System.out.println("Testing Bayes Classifier");
+        model = new BayesModel();
+        classifier = new BayesClassifier();
+      } else if (classifierType.equalsIgnoreCase("cbayes")) {
+        System.out.println("Testing Complementary Bayes Classifier");
+        model = new CBayesModel();
+        classifier = new CBayesClassifier();
+      } else {
+        // Without this guard an unrecognised type leaves model/classifier null
+        // and surfaces later as an unexplained NullPointerException.
+        throw new IllegalArgumentException("Unknown classifier type: " + classifierType
+            + " (expected bayes or cbayes)");
+      }
+
+      model = reader.loadModel(model, fs, modelPaths, conf);
+
+      System.out.println("Done loading model: # labels: "
+          + model.getLabels().size());
+
+      System.out.println("Done generating Model ");
+
+      String defaultCat = "unknown";
+      if (cmdLine.hasOption(defaultCatOpt.getOpt())) {
+        defaultCat = cmdLine.getOptionValue(defaultCatOpt.getOpt());
+      }
+
+      String encoding = "UTF-8";
+      if (cmdLine.hasOption(encodingOpt.getOpt())) {
+        encoding = cmdLine.getOptionValue(encodingOpt.getOpt());
+      }
+
+      // NOTE(review): the analyzer is instantiated but not used anywhere below in this method.
+      Analyzer analyzer = null;
+      if (cmdLine.hasOption(analyzerOpt.getOpt())) {
+        String className = cmdLine.getOptionValue(analyzerOpt.getOpt());
+        Class clazz = Class.forName(className);
+        analyzer = (Analyzer) clazz.newInstance();
+      }
+      if (analyzer == null) {
+        analyzer = new StandardAnalyzer();
+      }
+
+      int gramSize = 1;
+      if (cmdLine.hasOption(gramSizeOpt.getOpt())) {
+        gramSize = Integer.parseInt(cmdLine
+            .getOptionValue(gramSizeOpt.getOpt()));
+      }
+
+      String testDirPath = cmdLine.getOptionValue(dirOpt.getOpt());
+      File dir = new File(testDirPath);
+      File[] subdirs = dir.listFiles();
+
+      ResultAnalyzer resultAnalyzer = new ResultAnalyzer(model.getLabels());
+
+      if (subdirs != null) {
+        for (File testFile : subdirs) {
+
+          // The correct label is the file name up to the first literal ".txt".
+          // (String.split(".txt") would treat '.' as a regex wildcard.)
+          String fileName = testFile.getName();
+          int suffixStart = fileName.indexOf(".txt");
+          String correctLabel = suffixStart > 0 ? fileName.substring(0, suffixStart) : fileName;
+          System.out.print(correctLabel);
+
+          BufferedReader fileReader = new BufferedReader(new InputStreamReader(
+              new FileInputStream(testFile.getPath()), encoding));
+          try {
+            String line;
+            while ((line = fileReader.readLine()) != null) {
+              // generateNGrams reads the first token as the label; a blank line has none.
+              if (line.trim().length() == 0) {
+                continue;
+              }
+
+              Map<String, List<String>> document = Model.generateNGrams(line, gramSize);
+              for (List<String> tokens : document.values()) {
+                ClassifierResult classifiedLabel = classifier.classify(model,
+                    tokens.toArray(new String[tokens.size()]),
+                    defaultCat);
+                resultAnalyzer.addInstance(correctLabel, classifiedLabel);
+              }
+            }
+          } finally {
+            // always release the file handle, even when classification fails
+            fileReader.close();
+          }
+
+          System.out.println("\t"
+              + resultAnalyzer.getConfusionMatrix().getAccuracy(correctLabel)
+              + "\t"
+              + resultAnalyzer.getConfusionMatrix().getCorrect(correctLabel)
+              + "/"
+              + resultAnalyzer.getConfusionMatrix().getTotal(correctLabel));
+
+        }
+
+      }
+      System.out.println(resultAnalyzer.summarize());
+
+    } catch (Exception exp) {
+      exp.printStackTrace(System.err);
+    }
+  }
+}
Propchange: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,104 @@
+package org.apache.mahout.classifier.bayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.PosixParser;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.OptionBuilder;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+//import org.apache.mahout.classifier.cbayes.CModelTrainerDriver;
+import org.apache.mahout.classifier.bayes.BayesDriver;
+import org.apache.mahout.classifier.cbayes.CBayesDriver;
+
+/**
+ * Train the Naive Bayes Complement classifier with improved weighting on the Twenty Newsgroups data (http://people.csail.mit.edu/jrennie/20Newsgroups/20news-18828.tar.gz)
+ *
+ * To run:
+ * Assume MAHOUT_HOME refers to the location where you checked out/installed Mahout
+ * <ol>
+ * <li>From the main dir: ant extract-20news-18828</li>
+ * <li>ant examples-job</li>
+ * <li>Start up Hadoop and copy the files to the system. See http://hadoop.apache.org/core/docs/r0.16.2/quickstart.html</li>
+ * <li>From the Hadoop dir (where Hadoop is installed):
+ * <ol>
+ * <li>emacs conf/hadoop-site.xml (add in local settings per quickstart)</li>
+ * <li>bin/hadoop namenode -format //Format the HDFS</li>
+ * <li>bin/start-all.sh //Start Hadoop</li>
+ * <li>bin/hadoop dfs -put <MAHOUT_HOME>/work/20news-18828-collapse 20newsInput //Copies the extracted text to HDFS</li>
+ * <li>bin/hadoop jar <MAHOUT_HOME>/build/apache-mahout-0.1-dev-ex.jar org.apache.mahout.examples.classifiers.cbayes.TwentyNewsgroups -t -i 20newsInput -o 20newsOutput</li>
+ * </ol>
+ * </li>
+ * </ol>
+ *
+ **/
+public class TrainClassifier {
+
+ private transient static Log log = LogFactory.getLog(TrainClassifier.class);
+
+ public void trainNaiveBayes(String dir, String outputDir, int gramSize){
+ BayesDriver.runJob(dir, outputDir, gramSize);
+ }
+
+ public void trainCNaiveBayes(String dir, String outputDir, int gramSize){
+ CBayesDriver.runJob(dir, outputDir, gramSize);
+ }
+
+ @SuppressWarnings("static-access")
+ public static void main(String[] args) {
+ Options options = new Options();
+ Option trainOpt = OptionBuilder.withLongOpt("train").withDescription("Train the classifier").create("t");
+ options.addOption(trainOpt);
+ Option inputDirOpt = OptionBuilder.withLongOpt("inputDir").hasArg().withDescription("The Directory on HDFS containing the collapsed, properly formatted files").create("i");
+ options.addOption(inputDirOpt);
+ Option outputOpt = OptionBuilder.withLongOpt("output").isRequired().hasArg().withDescription("The location of the model on the HDFS").create("o");
+ options.addOption(outputOpt);
+ Option gramSizeOpt = OptionBuilder.withLongOpt("gramSize").hasArg().withDescription("Size of the n-gram").create("ng");
+ options.addOption(gramSizeOpt);
+ Option typeOpt = OptionBuilder.withLongOpt("classifierType").isRequired().hasArg().withDescription("Type of classifier").create("type");
+ options.addOption(typeOpt);
+
+ CommandLine cmdLine = null;
+ try {
+ PosixParser parser = new PosixParser();
+ cmdLine = parser.parse(options, args);
+
+ boolean train = cmdLine.hasOption(trainOpt.getOpt());
+ TrainClassifier tn = new TrainClassifier();
+ if (train == true){;
+ String classifierType = cmdLine.getOptionValue(typeOpt.getOpt());
+ if(classifierType.equalsIgnoreCase("bayes")){
+ System.out.println("Training Bayes Classifier");
+ tn.trainNaiveBayes(cmdLine.getOptionValue(inputDirOpt.getOpt()), cmdLine.getOptionValue(outputOpt.getOpt()), Integer.parseInt(cmdLine.getOptionValue(gramSizeOpt.getOpt())));
+
+ }
+ else if(classifierType.equalsIgnoreCase("cbayes"))
+ {
+ System.out.println("Training Complementary Bayes Classifier");
+ //setup the HDFS and copy the files there, then run the trainer
+ tn.trainCNaiveBayes(cmdLine.getOptionValue(inputDirOpt.getOpt()), cmdLine.getOptionValue(outputOpt.getOpt()), Integer.parseInt(cmdLine.getOptionValue(gramSizeOpt.getOpt())));
+ }
+ }
+ }
+ catch (ParseException exp) {
+ log.error("Cmd line Syntax Error: " + exp.getMessage(), exp);
+ }
+ }
+}
Propchange: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreator.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreator.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreator.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreator.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,62 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.PosixParser;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.OptionBuilder;
+import org.apache.commons.cli.Option;
+
+import java.io.IOException;
+
+/**
+ * Command-line front end for {@link WikipediaDatasetCreatorDriver}.
+ *
+ * Reads three required options, <code>-i</code> (input directory), <code>-o</code> (output
+ * directory) and <code>-c</code> (countries file), and hands them to
+ * {@link WikipediaDatasetCreatorDriver#runJob(String, String, String)}.
+ */
+public class WikipediaDatasetCreator {
+
+  /**
+   * Parses the command line and launches the dataset creation job. Any failure, including a
+   * command-line syntax error, is reported as a stack trace on standard error.
+   *
+   * @param args the command-line arguments (<code>-i</code>, <code>-o</code>, <code>-c</code>)
+   */
+  @SuppressWarnings("static-access")
+  public static void main(String[] args) throws IOException,
+      ClassNotFoundException, IllegalAccessException, InstantiationException {
+    Options cliOptions = new Options();
+
+    Option inputOption = OptionBuilder.withLongOpt("dirInputPath").isRequired().hasArg()
+        .withDescription("The input Directory Path").create("i");
+    cliOptions.addOption(inputOption);
+
+    // The long option name below is spelled exactly as it has always been exposed.
+    Option outputOption = OptionBuilder.withLongOpt("dirOuputPath").isRequired().hasArg()
+        .withDescription("The output Directory Path").create("o");
+    cliOptions.addOption(outputOption);
+
+    Option countriesOption = OptionBuilder.withLongOpt("countriesFile").isRequired().hasArg()
+        .withDescription("Location of the Countries File").create("c");
+    cliOptions.addOption(countriesOption);
+
+    try {
+      CommandLine parsed = new PosixParser().parse(cliOptions, args);
+
+      String inputDir = parsed.getOptionValue(inputOption.getOpt());
+      String outputDir = parsed.getOptionValue(outputOption.getOpt());
+      String countries = parsed.getOptionValue(countriesOption.getOpt());
+
+      WikipediaDatasetCreatorDriver.runJob(inputDir, outputDir, countries);
+    } catch (Exception failure) {
+      failure.printStackTrace(System.err);
+    }
+  }
+}
Propchange: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreator.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorDriver.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorDriver.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorDriver.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,114 @@
+package org.apache.mahout.classifier.bayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.BufferedReader;
+import java.io.FileInputStream;
+import java.io.InputStreamReader;
+import java.util.*;
+import org.apache.hadoop.mapred.*;
+import org.apache.hadoop.util.GenericsUtil;
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileSystem;
+
+
+
+/**
+ * Configures and runs the map/reduce job that creates the Wikipedia dataset, using
+ * {@link WikipediaDatasetCreatorMapper}, {@link WikipediaDatasetCreatorReducer} and
+ * {@link WikipediaDatasetCreatorOutputFormat}.
+ */
+public class WikipediaDatasetCreatorDriver {
+  /**
+   * Takes in three arguments:
+   * <ol>
+   * <li>The input {@link org.apache.hadoop.fs.Path} where the input documents live</li>
+   * <li>The output {@link org.apache.hadoop.fs.Path} where the job output is written</li>
+   * <li>The local file listing the countries, one per line</li>
+   * </ol>
+   * @param args The args
+   */
+  public static void main(String[] args) {
+    String input = args[0];
+    String output = args[1];
+    String countriesFile = args[2];
+
+    runJob(input, output, countriesFile);
+  }
+
+  /**
+   * Run the job. An existing output path is deleted first.
+   *
+   * @param input the input pathname String
+   * @param output the output pathname String
+   * @param countriesFile the pathname of a local UTF-8 file with one country per line;
+   *        blank lines are ignored
+   * @throws RuntimeException wrapping any failure while preparing or running the job
+   */
+  @SuppressWarnings({ "deprecation" })
+  public static void runJob(String input, String output, String countriesFile) {
+    JobClient client = new JobClient();
+    JobConf conf = new JobConf(WikipediaDatasetCreatorDriver.class);
+
+    conf.set("key.value.separator.in.input.line", " ");
+    conf.set("xmlinput.start", "<text xml:space=\"preserve\">");
+    conf.set("xmlinput.end", "</text>");
+    conf.setOutputKeyClass(Text.class);
+    conf.setOutputValueClass(Text.class);
+
+    conf.setInputPath(new Path(input));
+    Path outPath = new Path(output);
+    conf.setOutputPath(outPath);
+
+    conf.setMapperClass(WikipediaDatasetCreatorMapper.class);
+    conf.setNumMapTasks(100);
+    conf.setInputFormat(XmlInputFormat.class);
+    conf.setReducerClass(WikipediaDatasetCreatorReducer.class);
+    conf.setOutputFormat(WikipediaDatasetCreatorOutputFormat.class);
+    // JavaSerialization is registered next to WritableSerialization because the country set
+    // stored below is a java.util.HashSet, which is Serializable but not a Writable
+    // (presumably what DefaultStringifier needs to encode it; verify before removing).
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,org.apache.hadoop.io.serializer.WritableSerialization");
+
+    try {
+      // Read the countries first so that an unreadable file does not cost us the old output.
+      HashSet<String> countries = readCountries(countriesFile);
+
+      FileSystem dfs = FileSystem.get(conf);
+      if (dfs.exists(outPath)) {
+        dfs.delete(outPath, true);
+      }
+
+      DefaultStringifier<HashSet<String>> setStringifier =
+          new DefaultStringifier<HashSet<String>>(conf, GenericsUtil.getClass(countries));
+      conf.set("wikipedia.countries", setStringifier.toString(countries));
+
+      client.setConf(conf);
+      JobClient.runJob(conf);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  /**
+   * Reads the country names, one per line, from a local UTF-8 file. Lines are trimmed and
+   * blank lines are skipped: an empty entry would be a substring of every category and so
+   * match every document in the mapper.
+   */
+  private static HashSet<String> readCountries(String countriesFile) throws java.io.IOException {
+    HashSet<String> countries = new HashSet<String>();
+    BufferedReader reader = new BufferedReader(new InputStreamReader(
+        new FileInputStream(countriesFile), "UTF-8"));
+    try {
+      String line;
+      while ((line = reader.readLine()) != null) {
+        String country = line.trim();
+        if (country.length() > 0) {
+          countries.add(country);
+        }
+      }
+    } finally {
+      // Close even when reading fails so the file handle is not leaked.
+      reader.close();
+    }
+    return countries;
+  }
+}
Propchange: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorDriver.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorMapper.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorMapper.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorMapper.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,124 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.util.GenericsUtil;
+import org.apache.lucene.analysis.*;
+import org.apache.lucene.analysis.standard.*;
+import org.apache.commons.lang.StringEscapeUtils;
+import java.io.*;
+import java.util.*;
+
+
+/**
+ *
+ *
+ */
+public class WikipediaDatasetCreatorMapper extends MapReduceBase implements
+ Mapper<Text, Text, Text, Text> {
+
+ static HashSet<String> countries = null;
+
+
+ @SuppressWarnings("deprecation")
+ public void map(Text key, Text value,
+ OutputCollector<Text, Text> output, Reporter reporter)
+ throws IOException {
+ String document = value.toString();
+ Analyzer analyzer = new StandardAnalyzer();
+ StringBuilder contents = new StringBuilder();
+
+
+ HashSet<String> categories = new HashSet<String>(findAllCategories(document));
+
+ String country = getCountry(categories);
+
+ if(country != "Unknown"){
+ document = StringEscapeUtils.unescapeHtml(document.replaceFirst("<text xml:space=\"preserve\">", "").replaceAll("</text>", ""));
+ TokenStream stream = analyzer.tokenStream(country, new StringReader(document));
+ while(true){
+ Token token = stream.next();
+ if(token==null) break;
+ contents.append(token.termText()).append(" ");
+ }
+ //System.err.println(country+"\t"+contents.toString());
+ output.collect(new Text(country.replace(" ","_")), new Text(contents.toString()));
+ }
+ }
+
+ public String getCountry(HashSet<String> categories)
+ {
+ for(String category : categories)
+ {
+ for(String country: countries){
+ if(category.indexOf(country)!=-1){
+ return country;
+
+ }
+ }
+ }
+ return "Unknown";
+ }
+
+ public List<String> findAllCategories(String document){
+ List<String> categories = new ArrayList<String>();
+ int startIndex = 0;
+ int categoryIndex = -1;
+
+ while((categoryIndex = document.indexOf("[[Category:", startIndex))!=-1)
+ {
+ categoryIndex+=11;
+ int endIndex = document.indexOf("]]", categoryIndex);
+ if(endIndex>=document.length() || endIndex < 0) break;
+ String category = document.substring(categoryIndex, endIndex);
+ categories.add(category);
+ startIndex = endIndex;
+ }
+
+ return categories;
+ }
+
+ @Override
+ public void configure(JobConf job) {
+ try
+ {
+ if(countries ==null){
+ countries = new HashSet<String>();
+
+ DefaultStringifier<HashSet<String>> setStringifier = new DefaultStringifier<HashSet<String>>(job,GenericsUtil.getClass(countries));
+
+ String countriesString = setStringifier.toString(countries);
+ countriesString = job.get("wikipedia.countries", countriesString);
+
+ countries = setStringifier.fromString(countriesString);
+
+ }
+ }
+ catch(IOException ex){
+
+ ex.printStackTrace();
+ }
+ }
+}
Propchange: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorMapper.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorOutputFormat.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorOutputFormat.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorOutputFormat.java (added)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorOutputFormat.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,31 @@
+package org.apache.mahout.classifier.bayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.lib.MultipleTextOutputFormat;
+
+/**
+ * This class extends the MultipleTextOutputFormat so that the output data is written as text
+ * to a different output file per key.
+ */
+public class WikipediaDatasetCreatorOutputFormat extends MultipleTextOutputFormat<Text, Text> {
+  /**
+   * Names the output file after the record key.
+   *
+   * @param key  the record key
+   * @param v    the record value, unused
+   * @param name the file name proposed by the caller, unused
+   * @return the key followed by <code>.txt</code>
+   */
+  @Override
+  protected String generateFileNameForKeyValue(Text key, Text v, String name) {
+    return key.toString() + ".txt";
+  }
+}
+
Propchange: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorOutputFormat.java
------------------------------------------------------------------------------
svn:eol-style = native