You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2008/08/19 14:55:48 UTC
svn commit: r687042 [1/4] - in /lucene/mahout/trunk: core/
core/src/main/java/org/apache/mahout/classifier/
core/src/main/java/org/apache/mahout/classifier/bayes/
core/src/main/java/org/apache/mahout/classifier/bayes/common/
core/src/main/java/org/apac...
Author: gsingers
Date: Tue Aug 19 05:55:45 2008
New Revision: 687042
URL: http://svn.apache.org/viewvc?rev=687042&view=rev
Log:
MAHOUT-60: First implementation of Naive Bayes and Complementary Naive Bayes
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/BayesFileFormatter.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/Classify.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesClassifier.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesModel.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureOutputFormat.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesTfIdfDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesTfIdfMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesTfIdfOutputFormat.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesTfIdfReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesWeightSummerDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesWeightSummerMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesWeightSummerOutputFormat.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesWeightSummerReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/io/
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/io/SequenceFileModelReader.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/package.html (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesClassifier.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesModel.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesNormalizedWeightDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesNormalizedWeightMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesNormalizedWeightReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerDriver.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerMapper.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaNormalizerReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/cbayes/CBayesThetaReducer.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Classifier.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Model.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/Summarizable.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableFloat.java (with props)
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/utils/UpdatableLong.java (with props)
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesClassifierTest.java (with props)
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/BayesFileFormatterTest.java (with props)
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/bayes/CBayesClassifierTest.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/PrepareTwentyNewsgroups.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TestClassifier.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/TrainClassifier.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreator.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorDriver.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorMapper.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorOutputFormat.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaDatasetCreatorReducer.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/WikipediaXmlSplitter.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/XmlInputFormat.java (with props)
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/package.html (with props)
Modified:
lucene/mahout/trunk/core/pom.xml
Modified: lucene/mahout/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/pom.xml?rev=687042&r1=687041&r2=687042&view=diff
==============================================================================
--- lucene/mahout/trunk/core/pom.xml (original)
+++ lucene/mahout/trunk/core/pom.xml Tue Aug 19 05:55:45 2008
@@ -301,8 +301,21 @@
<artifactId>xstream</artifactId>
<version>1.2.1</version>
</dependency>
-
-
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-analyzers</artifactId>
+ <version>2.3.2</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-core</artifactId>
+ <version>2.3.2</version>
+ </dependency>
+ <dependency>
+ <groupId>commons-cli</groupId>
+ <artifactId>commons-cli</artifactId>
+ <version>2.0-SNAPSHOT</version>
+ </dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/BayesFileFormatter.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/BayesFileFormatter.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/BayesFileFormatter.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/BayesFileFormatter.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,297 @@
+package org.apache.mahout.classifier;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.CharArraySet;
+import org.apache.lucene.analysis.Token;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.cli.PosixParser;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.OptionBuilder;
+
+import java.io.*;
+import java.nio.charset.Charset;
+import java.util.*;
+
+/**
+ * Flatten a file into a format that can be read by the Bayes M/R job. <p/> One
+ * document per line: the first token is the label, followed by a tab; the rest
+ * of the line is the terms.
+ */
+public class BayesFileFormatter {
+ private static String LINE_SEP = System.getProperty("line.separator");
+
+ /**
+ * Collapse all the files in the inputDir into a single file in the proper
+ * Bayes format, 1 document per line
+ *
+ * @param label The label
+ * @param analyzer The analyzer to use
+ * @param inputDir The input Directory
+ * @param charset The charset of the input files
+ * @param outputFile The file to collapse to
+ * @throws java.io.IOException
+ */
+ public static void collapse(String label, Analyzer analyzer, File inputDir,
+ Charset charset, File outputFile) throws IOException {
+ Writer writer = new OutputStreamWriter(new FileOutputStream(outputFile),
+ charset);
+ inputDir.listFiles(new FileProcessor(label, analyzer, charset, writer));
+ writer.close();
+
+ }
+
+ /**
+ * Write the input files to the outdir, one output file per input file
+ *
+ * @param label The label of the file
+ * @param analyzer The analyzer to use
+ * @param input The input file or directory. May not be null
+ * @param charset The Character set of the input files
+ * @param outDir The output directory. Files will be written there with the
+ * same name as the input file
+ * @throws IOException
+ */
+ public static void format(String label, Analyzer analyzer, File input,
+ Charset charset, File outDir) throws IOException {
+ if (input.isDirectory() == false) {
+ Writer writer = new OutputStreamWriter(new FileOutputStream(new File(
+ outDir, input.getName())), charset);
+ writeFile(label, analyzer, new InputStreamReader(new FileInputStream(
+ input), charset), writer);
+ writer.close();
+ } else {
+ input.listFiles(new FileProcessor(label, analyzer, charset, outDir));
+ }
+ }
+
+ /**
+ * Hack the FileFilter mechanism so that we don't get stuck on large
+ * directories and don't have to loop the list twice
+ */
+ private static class FileProcessor implements FileFilter {
+ private String label;
+
+ private Analyzer analyzer;
+
+ private File outputDir;
+
+ private Charset charset;
+
+ private Writer writer;
+
+ /**
+ * Use this when you want to collapse all files to a single file
+ *
+ * @param label The label
+ * @param analyzer
+ * @param charset
+ * @param writer must not be null and will not be closed
+ */
+ private FileProcessor(String label, Analyzer analyzer, Charset charset,
+ Writer writer) {
+ this.label = label;
+ this.analyzer = analyzer;
+ this.charset = charset;
+ this.writer = writer;
+ }
+
+ /**
+ * Use this when you want a writer per file
+ *
+ * @param label
+ * @param analyzer
+ * @param charset
+ * @param outputDir must not be null.
+ */
+ private FileProcessor(String label, Analyzer analyzer, Charset charset,
+ File outputDir) {
+ this.label = label;
+ this.analyzer = analyzer;
+ this.charset = charset;
+ this.outputDir = outputDir;
+ }
+
+ public boolean accept(File file) {
+ if (file.isFile() == true) {
+ try {
+ Writer theWriter;
+ if (writer == null) {
+ theWriter = new OutputStreamWriter(new FileOutputStream(new File(
+ outputDir, file.getName())), charset);
+ } else {
+ theWriter = writer;
+ }
+ writeFile(label, analyzer, new InputStreamReader(new FileInputStream(
+ file), charset), theWriter);
+ if (writer == null) {
+ theWriter.close();// we are writing a single file per input file
+ } else {
+ // just write a new line
+ theWriter.write(LINE_SEP);
+
+ }
+
+ } catch (IOException e) {
+ // TODO: report failed files instead of throwing exception
+ throw new RuntimeException(e);
+ }
+ } else {
+ file.listFiles(this);
+ }
+ return false;
+ }
+ }
+
+ /**
+ * Write the tokens and the label from the Reader to the writer
+ *
+ * @param label The label
+ * @param analyzer The analyzer to use
+ * @param reader The reader to pass to the Analyzer
+ * @param writer The Writer, is not closed by this method
+ * @throws java.io.IOException if there was a problem w/ the reader
+ */
+ public static void writeFile(String label, Analyzer analyzer, Reader reader,
+ Writer writer) throws IOException {
+ TokenStream ts = analyzer.tokenStream(label, reader);
+ writer.write(label);
+ writer.write('\t'); // edit: Inorder to match Hadoop standard
+ // TextInputFormat
+ Token token = new Token();
+ CharArraySet seen = new CharArraySet(256, false);
+ long numTokens = 0;
+ while ((token = ts.next(token)) != null) {
+ char[] termBuffer = token.termBuffer();
+ int termLen = token.termLength();
+
+ writer.write(termBuffer, 0, termLen);
+ writer.write(' ');
+ char[] tmp = new char[termLen];
+ System.arraycopy(termBuffer, 0, tmp, 0, termLen);
+ seen.add(tmp);// do this b/c CharArraySet doesn't allow offsets
+ }
+ numTokens++;
+
+ }
+
+ /**
+ * Convert a Reader to a vector
+ *
+ * @param analyzer The Analyzer to use
+ * @param reader The reader to feed to the Analyzer
+ * @return An array of unique tokens
+ * @throws IOException
+ */
+ public static String[] readerToDocument(Analyzer analyzer, Reader reader)
+ throws IOException {
+ TokenStream ts = analyzer.tokenStream("", reader);
+
+ Token token = null;
+ List<String> coll = new ArrayList<String>();
+ while ((token = ts.next()) != null) {
+ char[] termBuffer = token.termBuffer();
+ int termLen = token.termLength();
+ String val = new String(termBuffer, 0, termLen);
+ coll.add(val);
+ }
+ return (String[]) coll.toArray(new String[coll.size()]);
+ }
+
+ /**
+ * Run the FileFormatter
+ *
+ * @param args The input args. Run with -h to see the help
+ * @throws ClassNotFoundException if the Analyzer can't be found
+ * @throws IllegalAccessException if the Analyzer can't be constructed
+ * @throws InstantiationException if the Analyzer can't be constructed
+ * @throws IOException if the files can't be dealt with properly
+ */
+ @SuppressWarnings("static-access")
+ public static void main(String[] args) throws ClassNotFoundException,
+ IllegalAccessException, InstantiationException, IOException {
+ Options options = new Options();
+ Option inputOpt = OptionBuilder.withLongOpt("input").isRequired().hasArg()
+ .withDescription("The input file").create("i");
+ options.addOption(inputOpt);
+ Option outputOpt = OptionBuilder.withLongOpt("output").isRequired()
+ .hasArg().withDescription("The output file").create("o");
+ options.addOption(outputOpt);
+ Option labelOpt = OptionBuilder.withLongOpt("label").isRequired().hasArg()
+ .withDescription("The label of the file").create("l");
+ options.addOption(labelOpt);
+ Option analyzerOpt = OptionBuilder
+ .withLongOpt("analyzer")
+ .hasArg()
+ .withDescription(
+ "The fully qualified class name of the analyzer to use. Must have a no-arg constructor. Default is the StandardAnalyzer")
+ .create("a");
+ options.addOption(analyzerOpt);
+ Option charsetOpt = OptionBuilder.withLongOpt("charset").hasArg()
+ .withDescription("The character encoding of the input file")
+ .create("c");
+ options.addOption(charsetOpt);
+ Option collapseOpt = OptionBuilder.withLongOpt("collapse").hasArg()
+ .withDescription(
+ "Collapse a whole directory to a single file, one doc per line")
+ .create("p");
+ options.addOption(collapseOpt);
+ Option helpOpt = OptionBuilder.withLongOpt("help").withDescription(
+ "Print out help info").create("h");
+ options.addOption(helpOpt);
+ CommandLine cmdLine = null;
+ try {
+ PosixParser parser = new PosixParser();
+ cmdLine = parser.parse(options, args);
+ if (cmdLine.hasOption(helpOpt.getOpt())) {
+ System.out.println("Options: " + options);
+ System.exit(0);
+ }
+ File input = new File(cmdLine.getOptionValue(inputOpt.getOpt()));
+ File output = new File(cmdLine.getOptionValue(outputOpt.getOpt()));
+ String label = cmdLine.getOptionValue(labelOpt.getOpt());
+ Analyzer analyzer = null;
+ if (cmdLine.hasOption(analyzerOpt.getOpt())) {
+ analyzer = (Analyzer) Class.forName(
+ cmdLine.getOptionValue(analyzerOpt.getOpt())).newInstance();
+ } else {
+ analyzer = new StandardAnalyzer();
+ }
+ Charset charset = Charset.forName("UTF-8");
+ if (cmdLine.hasOption(charsetOpt.getOpt())) {
+ charset = Charset.forName(cmdLine.getOptionValue(charsetOpt.getOpt()));
+ }
+ boolean collapse = cmdLine.hasOption(collapseOpt.getOpt());
+
+ if (collapse == true) {
+ collapse(label, analyzer, input, charset, output);
+ } else {
+ format(label, analyzer, input, charset, output);
+ }
+
+ } catch (ParseException exp) {
+ exp.printStackTrace();
+ System.out.println("Options: " + options);
+ }
+ }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/BayesFileFormatter.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,61 @@
+package org.apache.mahout.classifier;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Result of a document classification: the label and the associated score (usually a probability).
+ */
public class ClassifierResult {

  // The label (category) assigned to the document.
  private String label;

  // The score associated with the label; usually a probability.
  private float score;

  /** Creates an empty result; populate via the setters. */
  public ClassifierResult() {
  }

  /**
   * @param label the assigned label
   * @param score the score associated with the label
   */
  public ClassifierResult(String label, float score) {
    this.label = label;
    this.score = score;
  }

  /**
   * @param label the assigned label; the score defaults to 0
   */
  public ClassifierResult(String label) {
    this.label = label;
  }

  public String getLabel() {
    return label;
  }

  public float getScore() {
    return score;
  }

  public void setLabel(String label) {
    this.label = label;
  }

  public void setScore(float score) {
    this.score = score;
  }

  @Override // fix: annotation was missing; guards against signature drift
  public String toString() {
    // The label is printed under the key "category"; kept as-is so any
    // existing output consumers are unaffected.
    return "ClassifierResult{" +
        "category='" + label + '\'' +
        ", score=" + score +
        '}';
  }
}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/Classify.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/Classify.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/Classify.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/Classify.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,156 @@
+package org.apache.mahout.classifier;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.PosixParser;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.cli.OptionBuilder;
+import org.apache.commons.cli.Option;
+import org.apache.mahout.classifier.bayes.BayesClassifier;
+import org.apache.mahout.classifier.bayes.BayesModel;
+import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader;
+import org.apache.mahout.classifier.cbayes.CBayesClassifier;
+import org.apache.mahout.classifier.cbayes.CBayesModel;
+import org.apache.mahout.common.Classifier;
+import org.apache.mahout.common.Model;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.analysis.Analyzer;
+
+import java.io.IOException;
+import java.io.File;
+import java.io.InputStreamReader;
+import java.io.FileInputStream;
+import java.util.*;
+
+
+/**
+ * Command-line program that loads a trained Bayes or Complementary Bayes
+ * model and classifies a single document, printing the resulting category.
+ */
+public class Classify {
+
+ @SuppressWarnings({ "static-access", "unchecked" })
+ public static void main(String[] args) throws IOException, ClassNotFoundException, IllegalAccessException, InstantiationException {
+ Options options = new Options();
+ Option pathOpt = OptionBuilder.withLongOpt("path").isRequired().hasArg().withDescription("The local file system path").create("p");
+ options.addOption(pathOpt);
+ Option classifyOpt = OptionBuilder.withLongOpt("classify").isRequired().hasArg().withDescription("The document to classify").create("c");
+ options.addOption(classifyOpt);
+ Option encodingOpt = OptionBuilder.withLongOpt("encoding").hasArg().withDescription("The file encoding. defaults to UTF-8").create("e");
+ options.addOption(encodingOpt);
+ Option analyzerOpt = OptionBuilder.withLongOpt("analyzer").hasArg().withDescription("The Analyzer to use").create("a");
+ options.addOption(analyzerOpt);
+ Option defaultCatOpt = OptionBuilder.withLongOpt("defaultCat").hasArg().withDescription("The default category").create("d");
+ options.addOption(defaultCatOpt);
+ Option gramSizeOpt = OptionBuilder.withLongOpt("gramSize").hasArg().withDescription("Size of the n-gram").create("ng");
+ options.addOption(gramSizeOpt);
+ Option typeOpt = OptionBuilder.withLongOpt("classifierType").isRequired().hasArg().withDescription("Type of classifier").create("type");
+ options.addOption(typeOpt);
+
+
+ CommandLine cmdLine = null;
+ try {
+ PosixParser parser = new PosixParser();
+ cmdLine = parser.parse(options, args);
+ SequenceFileModelReader reader = new SequenceFileModelReader();
+ JobConf conf = new JobConf(Classify.class);
+
+
+ Map<String, Path> modelPaths = new HashMap<String, Path>();
+ String modelBasePath = cmdLine.getOptionValue(pathOpt.getOpt());
+ modelPaths.put("sigma_j", new Path(modelBasePath + "/trainer-weights/Sigma_j/part-*"));
+ modelPaths.put("sigma_k", new Path(modelBasePath + "/trainer-weights/Sigma_k/part-*"));
+ modelPaths.put("sigma_kSigma_j", new Path(modelBasePath + "/trainer-weights/Sigma_kSigma_j/part-*"));
+ modelPaths.put("thetaNormalizer", new Path(modelBasePath + "/trainer-thetaNormalizer/part-*"));
+ modelPaths.put("weight", new Path(modelBasePath + "/trainer-tfIdf/trainer-tfIdf/part-*"));
+
+ FileSystem fs = FileSystem.get(conf);
+
+ System.out.println("Loading model from: " + modelPaths);
+
+ Model model = null;
+ Classifier classifier = null;
+
+ String classifierType = cmdLine.getOptionValue(typeOpt.getOpt());
+
+ if (classifierType.equalsIgnoreCase("bayes")) {
+ System.out.println("Testing Bayes Classifier");
+ model = new BayesModel();
+ classifier = new BayesClassifier();
+ } else if (classifierType.equalsIgnoreCase("cbayes")) {
+ System.out.println("Testing Complementary Bayes Classifier");
+ model = new CBayesModel();
+ classifier = new CBayesClassifier();
+ }
+
+ model = reader.loadModel(model, fs, modelPaths, conf);
+
+ System.out.println("Done loading model: # labels: "
+ + model.getLabels().size());
+
+ System.out.println("Done generating Model ");
+
+
+ String defaultCat = "unknown";
+ if (cmdLine.hasOption(defaultCatOpt.getOpt())) {
+ defaultCat = cmdLine.getOptionValue(defaultCatOpt.getOpt());
+ }
+ File docPath = new File(cmdLine.getOptionValue(classifyOpt.getOpt()));
+ String encoding = "UTF-8";
+ if (cmdLine.hasOption(encodingOpt.getOpt())) {
+ encoding = cmdLine.getOptionValue(encodingOpt.getOpt());
+ }
+ Analyzer analyzer = null;
+ if (cmdLine.hasOption(analyzerOpt.getOpt())) {
+ String className = cmdLine.getOptionValue(analyzerOpt.getOpt());
+ Class clazz = Class.forName(className);
+ analyzer = (Analyzer) clazz.newInstance();
+ }
+ if (analyzer == null) {
+ analyzer = new StandardAnalyzer();
+ }
+
+ int gramSize = 1;
+ if (cmdLine.hasOption(gramSizeOpt.getOpt())) {
+ gramSize = Integer.parseInt(cmdLine
+ .getOptionValue(gramSizeOpt.getOpt()));
+
+ }
+
+ System.out.println("Converting input document to proper format");
+ String [] document = BayesFileFormatter.readerToDocument(analyzer, new InputStreamReader(new FileInputStream(docPath), encoding));
+ StringBuilder line = new StringBuilder();
+ for(String token : document)
+ {
+ line.append(token).append(' ');
+ }
+ List<String> doc = Model.generateNGramsWithoutLabel(line.toString(), gramSize) ;
+ System.out.println("Done converting");
+ System.out.println("Classifying document: " + docPath);
+ ClassifierResult category = classifier.classify(model, doc.toArray(new String[]{}), defaultCat);
+ System.out.println("Category for " + docPath + " is " + category);
+ }
+ catch (ParseException exp) {
+ exp.printStackTrace(System.err);
+ }
+ }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/Classify.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,183 @@
+package org.apache.mahout.classifier;
+
+import java.util.*;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.mahout.common.Summarizable;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+/**
+ * The ConfusionMatrix Class stores the result of Classification of a Test
+ * Dataset: a square matrix of counts indexed by (correct label, classified
+ * label).
+ */
+public class ConfusionMatrix implements Summarizable {
+
+  /** All known labels; iteration order defines the matrix row/column order. */
+  Collection<String> labels = new ArrayList<String>();
+
+  /** Maps each label to its row/column index in {@link #confusionMatrix}. */
+  Map<String, Integer> labelMap = new HashMap<String, Integer>();
+
+  /** confusionMatrix[correctId][classifiedId] = number of instances. */
+  int[][] confusionMatrix = null;
+
+  public ConfusionMatrix(Collection<String> labels) {
+    this.labels = labels;
+    confusionMatrix = new int[labels.size()][labels.size()];
+    for (String label : labels) {
+      labelMap.put(label, labelMap.size());
+    }
+  }
+
+  public int[][] getConfusionMatrix() {
+    return confusionMatrix;
+  }
+
+  public Collection<String> getLabels() {
+    return labels;
+  }
+
+  /**
+   * Percentage of instances whose correct label is {@code label} that were
+   * also classified as {@code label}. Returns 0 for a label with no
+   * instances (the original returned NaN via 0/0).
+   *
+   * @param label the correct label whose accuracy is wanted
+   * @return accuracy in the range [0, 100]
+   */
+  public float getAccuracy(String label) {
+    int labelId = labelMap.get(label);
+    int labelTotal = 0;
+    int correct = 0;
+    for (int i = 0; i < labels.size(); i++) {
+      labelTotal += confusionMatrix[labelId][i];
+      if (i == labelId) {
+        correct = confusionMatrix[labelId][i];
+      }
+    }
+    // Guard against divide-by-zero for labels absent from the test set.
+    return labelTotal == 0 ? 0.0f : (float) 100 * correct / labelTotal;
+  }
+
+  /** @return the number of instances of {@code label} classified correctly */
+  public int getCorrect(String label) {
+    int labelId = labelMap.get(label);
+    return confusionMatrix[labelId][labelId];
+  }
+
+  /** @return the total number of test instances whose correct label is {@code label} */
+  public float getTotal(String label) {
+    int labelId = labelMap.get(label);
+    int labelTotal = 0;
+    for (int i = 0; i < labels.size(); i++) {
+      labelTotal += confusionMatrix[labelId][i];
+    }
+    return labelTotal;
+  }
+
+  public void addInstance(String correctLabel, ClassifierResult classifiedResult) throws Exception {
+    incrementCount(correctLabel, classifiedResult.getLabel());
+  }
+
+  public void addInstance(String correctLabel, String classifiedLabel) throws Exception {
+    incrementCount(correctLabel, classifiedLabel);
+  }
+
+  public int getCount(String correctLabel, String classifiedLabel)
+      throws Exception {
+    // Fix: the original condition (a && b == false) only fired when
+    // correctLabel was known but classifiedLabel was not; an unknown
+    // correctLabel slipped through and caused an NPE on labelMap.get().
+    if (!this.getLabels().contains(correctLabel)
+        || !this.getLabels().contains(classifiedLabel)) {
+      throw new Exception("Label not found " + correctLabel + " " + classifiedLabel);
+    }
+    int correctId = labelMap.get(correctLabel);
+    int classifiedId = labelMap.get(classifiedLabel);
+    return confusionMatrix[correctId][classifiedId];
+  }
+
+  public void putCount(String correctLabel, String classifiedLabel, int count)
+      throws Exception {
+    // Same validation fix as getCount: reject if EITHER label is unknown.
+    if (!this.getLabels().contains(correctLabel)
+        || !this.getLabels().contains(classifiedLabel)) {
+      throw new Exception("Label not found " + correctLabel + " " + classifiedLabel);
+    }
+    int correctId = labelMap.get(correctLabel);
+    int classifiedId = labelMap.get(classifiedLabel);
+    confusionMatrix[correctId][classifiedId] = count;
+  }
+
+  public void incrementCount(String correctLabel, String classifiedLabel,
+      int count) throws Exception {
+    putCount(correctLabel, classifiedLabel, count
+        + getCount(correctLabel, classifiedLabel));
+  }
+
+  public void incrementCount(String correctLabel, String classifiedLabel)
+      throws Exception {
+    incrementCount(correctLabel, classifiedLabel, 1);
+  }
+
+  /**
+   * Adds all counts of {@code b} into this matrix. Both matrices must be
+   * built over the same label set.
+   *
+   * @return this matrix, after merging
+   * @throws Exception if the label sets differ
+   */
+  public ConfusionMatrix Merge(ConfusionMatrix b) throws Exception {
+    if (this.getLabels().size() != b.getLabels().size())
+      throw new Exception("The Labels do not Match");
+
+    // Fix: the original had a dangling ';' after this check, so two
+    // equal-sized but different label sets were silently accepted.
+    if (!this.getLabels().containsAll(b.getLabels()))
+      throw new Exception("The Labels do not Match");
+
+    for (String correctLabel : this.labels) {
+      for (String classifiedLabel : this.labels) {
+        incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel,
+            classifiedLabel));
+      }
+    }
+    return this;
+  }
+
+  /**
+   * Renders the matrix as text: a header row of short label aliases, one row
+   * of counts per correct label, each followed by its total and full name.
+   */
+  public String summarize() throws Exception {
+    StringBuilder returnString = new StringBuilder();
+    returnString.append("=======================================================\n");
+    returnString.append("Confusion Matrix\n");
+    returnString.append("-------------------------------------------------------\n");
+
+    // Header: one short alias per label, in index order.
+    for (String correctLabel : this.labels) {
+      returnString.append(StringUtils.rightPad(getSmallLabel(labelMap.get(correctLabel)), 5)).append('\t');
+    }
+
+    returnString.append("<--Classified as\n");
+
+    for (String correctLabel : this.labels) {
+      int labelTotal = 0;
+      for (String classifiedLabel : this.labels) {
+        returnString.append(StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
+        labelTotal += getCount(correctLabel, classifiedLabel);
+      }
+      returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6))
+          .append('\t')
+          .append(StringUtils.rightPad(getSmallLabel(labelMap.get(correctLabel)), 5))
+          .append(" = ").append(correctLabel).append('\n');
+    }
+    returnString.append('\n');
+    return returnString.toString();
+  }
+
+  /**
+   * Maps a label index to a short alphabetic alias used as a column header
+   * (0 -> "a", 1 -> "b", ..., 26 -> "ba", ...).
+   */
+  String getSmallLabel(int i) {
+    int val = i;
+    StringBuilder returnString = new StringBuilder();
+    do {
+      int n = val % 26;
+      returnString.insert(0, (char) ('a' + n));
+      val /= 26;
+    } while (val > 0);
+    return returnString.toString();
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,101 @@
+package org.apache.mahout.classifier;
+
+import java.text.*;
+import java.util.*;
+
+import org.apache.commons.lang.StringUtils;
+import org.apache.mahout.common.Summarizable;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+/**
+ * Accumulates per-instance classification results and produces a textual
+ * summary (correct/incorrect counts, percentages, and a confusion matrix).
+ */
+public class ResultAnalyzer implements Summarizable {
+  ConfusionMatrix confusionMatrix = null;
+
+  /*
+   * === Summary ===
+   *
+   * Correctly Classified Instances 635 92.9722 % Incorrectly Classified
+   * Instances 48 7.0278 % Kappa statistic 0.923 Mean absolute error 0.0096 Root
+   * mean squared error 0.0817 Relative absolute error 9.9344 % Root relative
+   * squared error 37.2742 % Total Number of Instances 683
+   */
+  int correctlyClassified = 0;
+
+  int incorrectlyClassified = 0;
+
+  public ResultAnalyzer(Collection<String> labelSet) {
+    confusionMatrix = new ConfusionMatrix(labelSet);
+  }
+
+  public ConfusionMatrix getConfusionMatrix() {
+    return this.confusionMatrix;
+  }
+
+  /**
+   * Records one classified instance, updating the correct/incorrect tallies
+   * and the confusion matrix.
+   *
+   * @param correctLabel the true label of the instance
+   * @param classifiedResult the classifier's output for the instance
+   */
+  public void addInstance(String correctLabel, ClassifierResult classifiedResult)
+      throws Exception {
+    if (correctLabel.equals(classifiedResult.getLabel()))
+      correctlyClassified++;
+    else
+      incorrectlyClassified++;
+    confusionMatrix.addInstance(correctLabel, classifiedResult);
+  }
+
+  public String toString() {
+    return "";
+  }
+
+  public String summarize() throws Exception {
+    StringBuilder returnString = new StringBuilder();
+
+    returnString.append("=======================================================\n");
+    returnString.append("Summary\n");
+    returnString.append("-------------------------------------------------------\n");
+    int totalClassified = correctlyClassified + incorrectlyClassified;
+    // Guard against divide-by-zero (the original printed NaN when summarize()
+    // was called before any instance had been added).
+    double percentageCorrect = totalClassified == 0 ? 0.0
+        : (double) 100 * correctlyClassified / totalClassified;
+    double percentageIncorrect = totalClassified == 0 ? 0.0
+        : (double) 100 * incorrectlyClassified / totalClassified;
+    DecimalFormat decimalFormatter = new DecimalFormat("0.####");
+
+    // Integer.toString replaces the deprecated new Integer(...) boxing.
+    returnString.append(StringUtils.rightPad("Correctly Classified Instances", 40)
+        + ": "
+        + StringUtils.leftPad(Integer.toString(correctlyClassified), 10)
+        + "\t"
+        + StringUtils.leftPad(decimalFormatter.format(percentageCorrect), 10)
+        + "%\n");
+    returnString.append(StringUtils.rightPad("Incorrectly Classified Instances", 40)
+        + ": "
+        + StringUtils.leftPad(Integer.toString(incorrectlyClassified), 10)
+        + "\t"
+        + StringUtils.leftPad(decimalFormatter.format(percentageIncorrect), 10)
+        + "%\n");
+    returnString.append(StringUtils.rightPad("Total Classified Instances", 40)
+        + ": "
+        + StringUtils.leftPad(Integer.toString(totalClassified), 10)
+        + "\n");
+    returnString.append("\n");
+
+    returnString.append(confusionMatrix.summarize());
+
+    return returnString.toString();
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesClassifier.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesClassifier.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesClassifier.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesClassifier.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,137 @@
+package org.apache.mahout.classifier.bayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.util.PriorityQueue;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.common.Classifier;
+import org.apache.mahout.common.Model;
+
+import java.util.Collection;
+import java.util.Enumeration;
+import java.util.Hashtable;
+import java.util.LinkedList;
+
+
+/**
+ * Classifies documents based on a {@link BayesModel}}.
+ */
+public class BayesClassifier implements Classifier {
+
+ /**
+ * Classify the document and return the top <code>numResults</code>
+ *
+ * @param model The model
+ * @param document The document to classify
+ * @param defaultCategory The default category to assign
+ * @param numResults The maximum number of results to return, ranked by score. Ties are broken by comparing the category
+ * @return A Collection of {@link org.apache.mahout.classifier.ClassifierResult}s.
+ */
+ public Collection<ClassifierResult> classify(Model model, String[] document, String defaultCategory, int numResults) {
+ Collection<String> categories = model.getLabels();
+
+ PriorityQueue pq = new ClassifierResultPriorityQueue(numResults);
+ ClassifierResult tmp = null;
+ for (String category : categories){
+ float prob = documentProbability(model, category, document);
+ if (prob < 0) {
+ tmp = new ClassifierResult(category, prob);
+ pq.insert(tmp);
+ }
+ }
+
+ LinkedList<ClassifierResult> result = new LinkedList<ClassifierResult>();
+ while ((tmp = (ClassifierResult) pq.pop()) != null) {
+ result.addLast(tmp);
+ }
+ if (result.isEmpty()){
+ result.add(new ClassifierResult(defaultCategory, 0));
+ }
+ return result;
+ }
+
+ /**
+ * Classify the document according to the {@link org.apache.mahout.common.Model}
+ *
+ * @param model The trained {@link org.apache.mahout.common.Model}
+ * @param document The document to classify
+ * @param defaultCategory The default category to assign if one cannot be determined
+ * @return The single best category
+ */
+ public ClassifierResult classify(Model model, String[] document, String defaultCategory) {
+ ClassifierResult result = new ClassifierResult(defaultCategory);
+ float min = 0.0f;
+ Collection<String> categories = model.getLabels();
+
+ for (String category : categories) {
+ float prob = documentProbability(model, category, document);
+
+ if (prob < min) {
+ min = prob;
+ result.setLabel(category);
+ }
+ }
+ result.setScore(min);
+ return result;
+ }
+
+ /**
+ * Calculate the document probability as the multiplication of the {@link org.apache.mahout.common.Model#FeatureWeight(String, String)} for each word given the label
+ *
+ * @param model The {@link org.apache.mahout.common.Model}
+ * @param label The label to calculate the probability of
+ * @param document The document
+ * @return The probability
+ * @see Model#FeatureWeight(String, String)
+ */
+ public float documentProbability(Model model, String label, String[] document) {
+ float result = 0f;
+ Hashtable<String, Integer> wordList = new Hashtable<String, Integer>(1000);
+ for (String word : document) {
+ if (wordList.containsKey(word)) {
+ Integer count = wordList.get(word);
+ count++;
+ wordList.put(word, count);
+ } else {
+ wordList.put(word, 1);
+ }
+ }
+ for (Enumeration<String> e = wordList.keys(); e.hasMoreElements();) {
+ String word = e.nextElement();
+ Integer count = wordList.get(word);
+ result += count * model.FeatureWeight(label, word);
+ }
+ return result;
+ }
+
+
+ private static class ClassifierResultPriorityQueue extends PriorityQueue {
+
+ private ClassifierResultPriorityQueue(int numResults) {
+ initialize(numResults);
+ }
+
+ protected boolean lessThan(Object a, Object b) {
+ ClassifierResult cr1 = (ClassifierResult) a;
+ ClassifierResult cr2 = (ClassifierResult) b;
+
+ float score1 = cr1.getScore();
+ float score2 = cr2.getScore();
+ return score1 == score2 ? cr1.getLabel().compareTo(cr2.getLabel()) < 0 : score1 < score2;
+ }
+ }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesClassifier.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesDriver.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesDriver.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesDriver.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,126 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.mahout.classifier.bayes.common.BayesFeatureDriver;
+import org.apache.mahout.classifier.bayes.common.BayesTfIdfDriver;
+import org.apache.mahout.classifier.bayes.common.BayesWeightSummerDriver;
+
+/**
+ * Create and run the Bayes Trainer: chains the feature-extraction, tf-idf,
+ * weight-summing and theta-normalization map/reduce jobs, then deletes the
+ * intermediate outputs it no longer needs.
+ */
+public class BayesDriver {
+  /**
+   * Takes in two arguments:
+   * <ol>
+   * <li>The input {@link org.apache.hadoop.fs.Path} where the input documents
+   * live</li>
+   * <li>The output {@link org.apache.hadoop.fs.Path} where to write the
+   * {@link org.apache.mahout.common.Model} as a
+   * {@link org.apache.hadoop.io.SequenceFile}</li>
+   * </ol>
+   *
+   * @param args The args -- NOTE(review): not validated; fewer than two
+   *             arguments throws ArrayIndexOutOfBoundsException
+   */
+  public static void main(String[] args) {
+    String input = args[0];
+    String output = args[1];
+
+    // gramSize defaults to 1 (unigrams) when run from the command line.
+    runJob(input, output, 1);
+  }
+
+  /**
+   * Run the full training pipeline. Each step below is itself a complete
+   * map/reduce job that reads from and writes under {@code output}.
+   *
+   * @param input the input pathname String
+   * @param output the output pathname String
+   * @param gramSize the n-gram size passed to the feature-extraction step
+   */
+  @SuppressWarnings("deprecation")
+  public static void runJob(String input, String output, int gramSize) {
+    JobConf conf = new JobConf(BayesDriver.class);
+    try {
+      FileSystem dfs = FileSystem.get(conf);
+      Path outPath = new Path(output);
+      // Start from a clean output tree; uses the deprecated single-arg
+      // delete (hence the @SuppressWarnings above).
+      if (dfs.exists(outPath))
+        dfs.delete(outPath);
+
+      System.out.println("Reading features...");
+      //Read the features in each document normalized by length of each document
+      BayesFeatureDriver.runJob(input, output, gramSize);
+
+      System.out.println("Calculating Tf-Idf...");
+      //Calculate the TfIdf for each word in each label
+      BayesTfIdfDriver.runJob(input, output);
+
+      System.out.println("Calculating weight sums for labels and features...");
+      //Calculate the Sums of weights for each label, for each feature and for each feature and for each label
+      BayesWeightSummerDriver.runJob(input, output);
+
+      // The CBayes (complementary) steps below are disabled in this driver;
+      // they are kept as a record of the complementary-NB variant.
+      //System.out.println("Calculating the weight of the features of each label in the complement class...");
+      //Calculate the W_ij = log(Theta) for each label, feature. This step actually generates the complement class
+      //CBayesThetaDriver.runJob(input, output);
+
+      System.out.println("Calculating the weight Normalisation factor for each class...");
+      //Calculate the normalization factor Sigma_W_ij for each complement class.
+      BayesThetaNormalizerDriver.runJob(input, output);
+
+      //System.out.println("Calculating the final Weight Normalized Complementary Naive Bayes Model...");
+      //Calculate the normalization factor Sigma_W_ij for each complement class.
+      //CBayesNormalizedWeightDriver.runJob(input, output);
+
+      // Clean up intermediate job outputs. The commented-out deletions keep
+      // trainer-tfIdf, trainer-weights and trainer-theta* on disk -- those
+      // directories are read later (e.g. by BayesThetaNormalizerDriver and
+      // the model reader), so they must survive training.
+      Path docCountOutPath = new Path(output + "/trainer-docCount");
+      if (dfs.exists(docCountOutPath))
+        dfs.delete(docCountOutPath, true);
+      Path termDocCountOutPath = new Path(output + "/trainer-termDocCount");
+      if (dfs.exists(termDocCountOutPath))
+        dfs.delete(termDocCountOutPath, true);
+      Path featureCountOutPath = new Path(output + "/trainer-featureCount");
+      if (dfs.exists(featureCountOutPath))
+        dfs.delete(featureCountOutPath, true);
+      Path wordFreqOutPath = new Path(output + "/trainer-wordFreq");
+      if (dfs.exists(wordFreqOutPath))
+        dfs.delete(wordFreqOutPath, true);
+      Path vocabCountPath = new Path(output + "/trainer-tfIdf/trainer-vocabCount");
+      if (dfs.exists(vocabCountPath))
+        dfs.delete(vocabCountPath, true);
+      /*Path tfIdfOutPath = new Path(output+ "/trainer-tfIdf");
+      if (dfs.exists(tfIdfOutPath))
+        dfs.delete(tfIdfOutPath, true);*/
+      Path vocabCountOutPath = new Path(output + "/trainer-vocabCount");
+      if (dfs.exists(vocabCountOutPath))
+        dfs.delete(vocabCountOutPath, true);
+      /* Path weightsOutPath = new Path(output+ "/trainer-weights");
+      if (dfs.exists(weightsOutPath))
+        dfs.delete(weightsOutPath, true);*/
+      /*Path thetaOutPath = new Path(output+ "/trainer-theta");
+      if (dfs.exists(thetaOutPath))
+        dfs.delete(thetaOutPath, true);*/
+      /*Path thetaNormalizerOutPath = new Path(output+ "/trainer-thetaNormalizer");
+      if (dfs.exists(thetaNormalizerOutPath))
+        dfs.delete(thetaNormalizerOutPath, true);*/
+
+    } catch (Exception e) {
+      // Any failure in a sub-job aborts the whole pipeline.
+      throw new RuntimeException(e);
+    }
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesDriver.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesModel.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesModel.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesModel.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,160 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import org.apache.mahout.common.Model;
+
+import java.util.Map;
+
+
+/**
+ * The Naive Bayes model: weights are Laplace-smoothed negative
+ * log-likelihoods of a feature given a label.
+ */
+public class BayesModel extends Model {
+
+  /**
+   * Smoothed weight for (label, feature):
+   * -log((rawWeight + alpha_i) / (sumLabelWeight + vocabCount)).
+   */
+  @Override
+  protected float getWeight(Integer label, Integer feature) {
+    float result = 0.0f;
+    Map<Integer, Float> featureWeights = featureLabelWeights.get(feature);
+
+    if (featureWeights.containsKey(label)) {
+      result = featureWeights.get(label);
+    }
+
+    float vocabCount = featureList.size();
+    float sumLabelWeight = getSumLabelWeight(label);
+
+    float numerator = result + alpha_i;
+    float denominator = sumLabelWeight + vocabCount;
+
+    // Plain cast replaces the original new Double(...).floatValue() boxing.
+    float weight = (float) Math.log(numerator / denominator);
+    return -1.0f * weight;
+  }
+
+  /**
+   * Raw (unsmoothed) weight for (label, feature); 0 when the pair was never
+   * observed.
+   */
+  @Override
+  protected float getWeightUnprocessed(Integer label, Integer feature) {
+    Map<Integer, Float> featureWeights = featureLabelWeights.get(feature);
+    Float weight = featureWeights.get(label);
+    return weight == null ? 0.0f : weight;
+  }
+
+  /**
+   * Rescales every per-label theta normalizer by the smallest absolute
+   * normalizer, so the minimum becomes 1.
+   */
+  @Override
+  public void InitializeNormalizer() {
+    float perLabelWeightSumNormalisationFactor = Float.MAX_VALUE;
+
+    System.out.println(thetaNormalizer);
+    for (Integer label : thetaNormalizer.keySet()) {
+      float Sigma_W_ij = thetaNormalizer.get(label);
+      if (perLabelWeightSumNormalisationFactor > Math.abs(Sigma_W_ij)) {
+        perLabelWeightSumNormalisationFactor = Math.abs(Sigma_W_ij);
+      }
+    }
+
+    for (Integer label : thetaNormalizer.keySet()) {
+      float Sigma_W_ij = thetaNormalizer.get(label);
+      thetaNormalizer.put(label, Sigma_W_ij
+          / perLabelWeightSumNormalisationFactor);
+    }
+    System.out.println(thetaNormalizer);
+  }
+
+  /**
+   * Computes smoothed log weights for every (label, feature) pair, then
+   * normalizes them by each label's theta normalizer.
+   */
+  @Override
+  public void GenerateModel() {
+    try {
+      float vocabCount = featureList.size();
+
+      float[] perLabelThetaNormalizer = new float[labelList.size()];
+
+      float perLabelWeightSumNormalisationFactor = Float.MAX_VALUE;
+
+      for (int feature = 0, maxFeatures = featureList.size(); feature < maxFeatures; feature++) {
+        for (int label = 0, maxLabels = labelList.size(); label < maxLabels; label++) {
+
+          float D_ij = getWeightUnprocessed(label, feature);
+          float sumLabelWeight = getSumLabelWeight(label);
+
+          float numerator = D_ij + alpha_i;
+          float denominator = sumLabelWeight + vocabCount;
+
+          float weight = (float) Math.log(numerator / denominator);
+
+          if (D_ij != 0)
+            setWeight(label, feature, weight);
+
+          // NOTE(review): the normalizer accumulates the smoothed weight even
+          // for unseen (D_ij == 0) pairs whose weight is never stored --
+          // presumably intentional smoothing of the normalizer; confirm.
+          perLabelThetaNormalizer[label] += weight;
+
+        }
+      }
+      System.out.println("Normalizing Weights");
+      // Scale normalizers so the smallest absolute value becomes 1.
+      for (int label = 0, maxLabels = labelList.size(); label < maxLabels; label++) {
+        float Sigma_W_ij = perLabelThetaNormalizer[label];
+        if (perLabelWeightSumNormalisationFactor > Math.abs(Sigma_W_ij)) {
+          perLabelWeightSumNormalisationFactor = Math.abs(Sigma_W_ij);
+        }
+      }
+
+      for (int label = 0, maxLabels = labelList.size(); label < maxLabels; label++) {
+        float Sigma_W_ij = perLabelThetaNormalizer[label];
+        perLabelThetaNormalizer[label] = Sigma_W_ij
+            / perLabelWeightSumNormalisationFactor;
+      }
+
+      // Replace each stored weight with its theta-normalized, negated value.
+      for (int feature = 0, maxFeatures = featureList.size(); feature < maxFeatures; feature++) {
+        for (int label = 0, maxLabels = labelList.size(); label < maxLabels; label++) {
+          float W_ij = getWeightUnprocessed(label, feature);
+          if (W_ij == 0)
+            continue;
+          float Sigma_W_ij = perLabelThetaNormalizer[label];
+          float normalizedWeight = -1.0f * (W_ij / Sigma_W_ij);
+          setWeight(label, feature, normalizedWeight);
+        }
+      }
+    } catch (Exception e) {
+      // Fix: the original swallowed the exception with printStackTrace(),
+      // leaving a half-built model in use. Rethrow instead, matching the
+      // RuntimeException wrapping used by the driver classes.
+      throw new RuntimeException(e);
+    }
+  }
+
+
+  /**
+   * Get the weighted probability of the feature.
+   *
+   * @param label The label of the feature
+   * @param feature The feature to calc. the prob. for
+   * @return The weighted probability
+   */
+  @Override
+  public float FeatureWeight(Integer label, Integer feature) {
+    return getWeight(label, feature);
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesModel.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerDriver.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerDriver.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerDriver.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,126 @@
+package org.apache.mahout.classifier.bayes;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.HashMap;
+
+import org.apache.hadoop.mapred.JobClient;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.util.GenericsUtil;
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.mahout.classifier.bayes.io.SequenceFileModelReader;
+
+
+/**
+ * Create and run the Bayes Theta Normalization Step: a map/reduce job over
+ * the tf-idf output that computes the per-label theta normalizer, with the
+ * label sums, sigma_jSigma_k and vocabulary count passed in via the job conf.
+ **/
+public class BayesThetaNormalizerDriver {
+  /**
+   * Takes in two arguments:
+   * <ol>
+   * <li>The input {@link org.apache.hadoop.fs.Path} where the input documents live</li>
+   * <li>The output {@link org.apache.hadoop.fs.Path} where to write the interim files as a {@link org.apache.hadoop.io.SequenceFile}</li>
+   * </ol>
+   * @param args The args -- NOTE(review): not validated; fewer than two
+   *             arguments throws ArrayIndexOutOfBoundsException
+   */
+  public static void main(String[] args) {
+    String input = args[0];
+    String output = args[1];
+
+    runJob(input, output);
+  }
+
+  /**
+   * Run the job.
+   *
+   * @param input the input pathname String (unused by the job itself; the
+   *              map input is read from {@code output}/trainer-tfIdf)
+   * @param output the output pathname String
+   */
+  public static void runJob(String input, String output) {
+    JobClient client = new JobClient();
+    JobConf conf = new JobConf(BayesThetaNormalizerDriver.class);
+
+
+    // Job wiring: SequenceFile in/out, Text keys, FloatWritable values,
+    // with the reducer doubling as the combiner.
+    conf.setOutputKeyClass(Text.class);
+    conf.setOutputValueClass(FloatWritable.class);
+    SequenceFileInputFormat.addInputPath(conf, new Path(output + "/trainer-tfIdf/trainer-tfIdf"));
+    Path outPath = new Path(output + "/trainer-thetaNormalizer");
+    SequenceFileOutputFormat.setOutputPath(conf, outPath);
+    conf.setNumMapTasks(100);  // a hint to the framework, not a hard limit
+    //conf.setNumReduceTasks(1);
+    conf.setMapperClass(BayesThetaNormalizerMapper.class);
+    conf.setInputFormat(SequenceFileInputFormat.class);
+    conf.setCombinerClass(BayesThetaNormalizerReducer.class);
+    conf.setReducerClass(BayesThetaNormalizerReducer.class);
+    conf.setOutputFormat(SequenceFileOutputFormat.class);
+    // Register both Java and Writable serialization: the DefaultStringifier
+    // calls below depend on this setting to (de)serialize the conf values.
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,org.apache.hadoop.io.serializer.WritableSerialization");
+
+    try {
+      FileSystem dfs = FileSystem.get(conf);
+      if (dfs.exists(outPath))
+        dfs.delete(outPath, true);
+
+      SequenceFileModelReader reader = new SequenceFileModelReader();
+
+      // Sigma_k: total weight per label, stringified into the job conf so
+      // every mapper can read it back.
+      Path Sigma_kFiles = new Path(output + "/trainer-weights/Sigma_k/*");
+      HashMap<String, Float> labelWeightSum = reader.readLabelSums(dfs, Sigma_kFiles, conf);
+      DefaultStringifier<HashMap<String, Float>> mapStringifier = new DefaultStringifier<HashMap<String, Float>>(conf, GenericsUtil.getClass(labelWeightSum));
+      String labelWeightSumString = mapStringifier.toString(labelWeightSum);
+
+      // Round-trip check: deserialize and print what the mappers will see.
+      System.out.println("Sigma_k for Each Label");
+      HashMap<String, Float> c = mapStringifier.fromString(labelWeightSumString);
+      System.out.println(c);
+      conf.set("cnaivebayes.sigma_k", labelWeightSumString);
+
+
+      // sigma_jSigma_k: grand total of weights over all labels and features.
+      Path sigma_kSigma_jFile = new Path(output + "/trainer-weights/Sigma_kSigma_j/*");
+      Float sigma_jSigma_k = reader.readSigma_jSigma_k(dfs, sigma_kSigma_jFile, conf);
+      DefaultStringifier<Float> floatStringifier = new DefaultStringifier<Float>(conf, Float.class);
+      String sigma_jSigma_kString = floatStringifier.toString(sigma_jSigma_k);
+
+      System.out.println("Sigma_kSigma_j for each Label and for each Features");
+      Float retSigma_jSigma_k = floatStringifier.fromString(sigma_jSigma_kString);
+      System.out.println(retSigma_jSigma_k);
+      conf.set("cnaivebayes.sigma_jSigma_k", sigma_jSigma_kString);
+
+      // Vocabulary size, needed by the mappers for Laplace smoothing.
+      Path vocabCountFile = new Path(output + "/trainer-tfIdf/trainer-vocabCount/*");
+      Float vocabCount = reader.readVocabCount(dfs, vocabCountFile, conf);
+      String vocabCountString = floatStringifier.toString(vocabCount);
+
+      System.out.println("Vocabulary Count");
+      conf.set("cnaivebayes.vocabCount", vocabCountString);
+      Float retvocabCount = floatStringifier.fromString(vocabCountString);
+      System.out.println(retvocabCount);
+
+      client.setConf(conf);
+
+      // Blocks until the job completes; throws on job failure.
+      JobClient.runJob(conf);
+
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerDriver.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerMapper.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerMapper.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,105 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.util.GenericsUtil;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * Computes the per-label theta-normalisation contributions for the
+ * complementary naive Bayes trainer. For every (label,feature) tfIdf entry it
+ * emits "_label" -> log((tfIdf + alpha_i) / (sigma_k[label] + vocabCount)),
+ * which {@link BayesThetaNormalizerReducer} then sums per label.
+ */
+public class BayesThetaNormalizerMapper extends MapReduceBase implements
+    Mapper<Text, FloatWritable, Text, FloatWritable> {
+
+  // Per-label sum of feature weights (sigma_k), deserialised from the
+  // "cnaivebayes.sigma_k" job-conf entry in configure().
+  public HashMap<String, Float> labelWeightSum = null;
+
+  String labelWeightSumString = " ";
+
+  // Total weight over all labels and features, read from
+  // "cnaivebayes.sigma_jSigma_k". NOTE(review): deserialised in configure()
+  // but never used in map().
+  Float sigma_jSigma_k = 0f;
+
+  String sigma_jSigma_kString = " ";
+
+  // Vocabulary size, read from "cnaivebayes.vocabCount"; used as the smoothing
+  // term in the denominator in map().
+  Float vocabCount = 0f;
+
+  String vocabCountString = " ";
+
+  /**
+   * We need to calculate the thetaNormalization factor of each label.
+   *
+   * @param key The comma-separated label,feature pair
+   * @param value The tfIdf of the pair
+   * @param output collects "_label" -> per-feature weight contributions
+   * @param reporter unused
+   * @throws IOException if the collector fails
+   */
+  public void map(Text key, FloatWritable value,
+      OutputCollector<Text, FloatWritable> output, Reporter reporter)
+      throws IOException {
+
+    String labelFeaturePair = key.toString();
+    // Laplace-style smoothing parameter added to each tfIdf value.
+    Float alpha_i = 1.0f;
+
+    String label = labelFeaturePair.split(",")[0];
+    float weight = (float) Math.log((value.get() + alpha_i) / (labelWeightSum.get(label) + vocabCount));
+    // The "_" prefix groups all of a label's contributions under one key so
+    // the reducer produces a single per-label sum.
+    output.collect(new Text(("_" + label).trim()), new FloatWritable(weight));
+  }
+
+  /**
+   * Deserialises sigma_k, sigma_jSigma_k and vocabCount from the job conf.
+   * The toString/get/fromString sequence round-trips each field's default
+   * value through the stringifier so job.get() can fall back to it when the
+   * conf entry is absent.
+   */
+  @Override
+  public void configure(JobConf job) {
+    try {
+      if (labelWeightSum == null) {
+        labelWeightSum = new HashMap<String, Float>();
+
+        DefaultStringifier<HashMap<String, Float>> mapStringifier = new DefaultStringifier<HashMap<String, Float>>(
+            job, GenericsUtil.getClass(labelWeightSum));
+
+        labelWeightSumString = mapStringifier.toString(labelWeightSum);
+        labelWeightSumString = job.get("cnaivebayes.sigma_k",
+            labelWeightSumString);
+        labelWeightSum = mapStringifier.fromString(labelWeightSumString);
+
+        DefaultStringifier<Float> floatStringifier = new DefaultStringifier<Float>(
+            job, GenericsUtil.getClass(sigma_jSigma_k));
+        sigma_jSigma_kString = floatStringifier.toString(sigma_jSigma_k);
+        sigma_jSigma_kString = job.get("cnaivebayes.sigma_jSigma_k",
+            sigma_jSigma_kString);
+        sigma_jSigma_k = floatStringifier.fromString(sigma_jSigma_kString);
+
+        vocabCountString = floatStringifier.toString(vocabCount);
+        vocabCountString = job.get("cnaivebayes.vocabCount", vocabCountString);
+        vocabCount = floatStringifier.fromString(vocabCountString);
+
+      }
+    } catch (IOException ex) {
+      // NOTE(review): swallowing this leaves the mapper with empty/zero
+      // parameters, and map() would then compute wrong weights (or NPE on an
+      // unknown label) -- consider rethrowing as a RuntimeException instead.
+      ex.printStackTrace();
+    }
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerMapper.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerReducer.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerReducer.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerReducer.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,71 @@
+package org.apache.mahout.classifier.bayes;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+
+/**
+ * Sums, per key, the theta-normalizer contributions emitted by
+ * {@link BayesThetaNormalizerMapper}. Can also be used as a local combiner,
+ * because the per-key summation is associative.
+ */
+
+public class BayesThetaNormalizerReducer extends MapReduceBase implements
+    Reducer<Text, FloatWritable, Text, FloatWritable> {
+
+  // NOTE(review): none of the fields below are read or written anywhere in
+  // this class -- they appear to be copied from the mapper and could be
+  // removed.
+  public HashMap<String, Float> labelWeightSum = null;
+
+  String labelWeightSumString = " ";
+
+  Float sigma_jSigma_k = 0f;
+
+  String sigma_jSigma_kString = " ";
+
+  Float vocabCount = 0f;
+
+  String vocabCountString = " ";
+
+  /**
+   * Sums all values for a key and emits key -> sum.
+   *
+   * @param key "_label" key produced by the mapper
+   * @param values per-(label,feature) log-weight contributions for that label
+   * @param output collects key -> summed weight
+   * @param reporter unused
+   * @throws IOException if the collector fails
+   */
+  @SuppressWarnings("unused")
+  public void reduce(Text key, Iterator<FloatWritable> values,
+      OutputCollector<Text, FloatWritable> output, Reporter reporter)
+      throws IOException {
+
+    String token = key.toString(); // NOTE(review): unused local
+
+    float weightSumPerLabel = 0.0f;
+
+    while (values.hasNext()) {
+      weightSumPerLabel += values.next().get();
+    }
+    output.collect(key, new FloatWritable(weightSumPerLabel));
+
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/BayesThetaNormalizerReducer.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureDriver.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureDriver.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureDriver.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,100 @@
+package org.apache.mahout.classifier.bayes.common;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.mapred.JobClient;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.KeyValueTextInputFormat;
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileSystem;
+
+
+/**
+ * Create and run the Bayes feature-extraction (reading) step.
+ *
+ **/
+public class BayesFeatureDriver {
+  /**
+   * Takes in two arguments:
+   * <ol>
+   * <li>The input {@link org.apache.hadoop.fs.Path} where the input documents live</li>
+   * <li>The output {@link org.apache.hadoop.fs.Path} where to write the interim files as a {@link org.apache.hadoop.io.SequenceFile}</li>
+   * </ol>
+   * Runs the job with a fixed gram size of 1.
+   * @param args The args
+   */
+  public static void main(String[] args) {
+    String input = args[0];
+    String output = args[1];
+
+    runJob(input, output, 1);
+  }
+
+  /**
+   * Run the feature-extraction job.
+   *
+   * @param input the input pathname String
+   * @param output the output pathname String; deleted first if it already exists
+   * @param gramSize the n-gram size, passed to the mapper via "bayes.gramSize"
+   */
+  @SuppressWarnings("deprecation")
+  public static void runJob(String input, String output, int gramSize) {
+    JobClient client = new JobClient();
+    JobConf conf = new JobConf(BayesFeatureDriver.class);
+
+    conf.setOutputKeyClass(Text.class);
+    conf.setOutputValueClass(FloatWritable.class);
+
+    conf.setInputPath(new Path(input));
+    Path outPath = new Path(output);
+    conf.setOutputPath(outPath);
+    conf.setNumMapTasks(100);
+    //conf.setNumReduceTasks(1);
+    conf.setMapperClass(BayesFeatureMapper.class);
+
+    conf.setInputFormat(KeyValueTextInputFormat.class);
+    conf.setCombinerClass(BayesFeatureReducer.class);
+    conf.setReducerClass(BayesFeatureReducer.class);
+    conf.setOutputFormat(BayesFeatureOutputFormat.class);
+
+    // JavaSerialization must be registered here so DefaultStringifier can
+    // serialize plain (non-Writable) types such as Integer into the job conf.
+    conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,org.apache.hadoop.io.serializer.WritableSerialization");
+
+    try {
+      FileSystem dfs = FileSystem.get(conf);
+      // Overwrite any previous run's output.
+      if (dfs.exists(outPath))
+        dfs.delete(outPath, true);
+
+      DefaultStringifier<Integer> intStringifier = new DefaultStringifier<Integer>(conf, Integer.class);
+      String gramSizeString = intStringifier.toString(new Integer(gramSize));
+
+      // Round-trip as a sanity check that the stringifier is configured.
+      Integer retGramSize = intStringifier.fromString(gramSizeString);
+      System.out.println(retGramSize);
+      conf.set("bayes.gramSize", gramSizeString);
+
+      client.setConf(conf);
+      JobClient.runJob(conf);
+
+
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureDriver.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureMapper.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureMapper.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureMapper.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,134 @@
+package org.apache.mahout.classifier.bayes.common;
+
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.mahout.classifier.BayesFileFormatter;
+import org.apache.mahout.common.Model;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.StringTokenizer;
+import java.util.Hashtable;
+import java.util.Enumeration;
+
+/**
+ * Reads the input train set (preprocessed using the {@link BayesFileFormatter})
+ * and, for every document, emits:
+ * <ol>
+ * <li>"label,term" -> length-normalised, log-transformed term frequency</li>
+ * <li>"-label,term" -> 1, the per-label term document frequency</li>
+ * <li>",term" -> 1, the corpus-wide term document frequency</li>
+ * <li>"_label" -> 1, the per-label document count</li>
+ * </ol>
+ * The last three feed the idf computation downstream.
+ */
+public class BayesFeatureMapper extends MapReduceBase implements
+    Mapper<Text, Text, Text, FloatWritable> {
+  private final static FloatWritable one = new FloatWritable(1.00f);
+
+  // Reusable output key buffer.
+  private Text labelWord = new Text();
+
+  // n-gram size, read from "bayes.gramSize" in configure(); defaults to 1.
+  private int gramSize = 1;
+
+  /**
+   * We need to count the number of times we've seen a term with a given label
+   * and we need to output that. But this Mapper does more than just output the
+   * count: it first does weight normalisation; secondly, it outputs for each
+   * unique word in a document the value 1, for summing up as the term document
+   * frequency (later used to calculate the idf); thirdly, it outputs for each
+   * label the number of times a document was seen (also used in the idf
+   * calculation).
+   *
+   * @param key The label
+   * @param value the features (all unique) associated w/ this label
+   * @param output collects the four kinds of records listed on the class
+   * @param reporter unused
+   * @throws IOException if the collector fails
+   */
+  public void map(Text key, Text value,
+      OutputCollector<Text, FloatWritable> output, Reporter reporter)
+      throws IOException {
+    String line = value.toString();
+    String label = key.toString();
+    int keyLen = label.length();
+
+    Hashtable<String, Integer> wordList = new Hashtable<String, Integer>(1000);
+
+    StringBuilder builder = new StringBuilder(label);
+    builder.ensureCapacity(32); // make sure we have a reasonably sized buffer to begin with
+
+    // NOTE(review): the second argument looks like it should be gramSize
+    // (configured below but otherwise unused) rather than the label length --
+    // confirm against Model.generateNGramsWithoutLabel.
+    List<String> previousN_1Grams = Model.generateNGramsWithoutLabel(line, keyLen);
+
+    // BUG FIX: wordList was never populated, so the loops below iterated an
+    // empty table and only the trailing "_label" record was ever emitted.
+    // Count the occurrences of each gram in this document.
+    for (String gram : previousN_1Grams) {
+      Integer count = wordList.get(gram);
+      wordList.put(gram, count == null ? Integer.valueOf(1) : Integer.valueOf(count.intValue() + 1));
+    }
+
+    // L2 norm of the raw term counts, used for length normalisation below.
+    double lengthNormalisation = 0.0d;
+    for (Enumeration<String> e = wordList.keys(); e.hasMoreElements();) {
+      double D_kj = wordList.get(e.nextElement()).doubleValue();
+      lengthNormalisation += D_kj * D_kj;
+    }
+    lengthNormalisation = Math.sqrt(lengthNormalisation);
+
+    // Output length-normalised + TF-transformed frequency per word per class:
+    // log(1 + D_ij) / sqrt( SIGMA(k, D_kj^2) )
+    for (Enumeration<String> e = wordList.keys(); e.hasMoreElements();) {
+      // key is label,word
+      String token = e.nextElement();
+      builder.append(",").append(token);
+      labelWord.set(builder.toString());
+      FloatWritable f = new FloatWritable((float) (Math
+          .log(1.0 + wordList.get(token)) / lengthNormalisation));
+      output.collect(labelWord, f);
+      builder.setLength(keyLen); // truncate back to just the label
+    }
+
+    // Output document frequency per word per class ("-label,term" -> 1) and
+    // the corpus-wide document frequency (",term" -> 1).
+    String dflabel = "-" + label;
+    int dfKeyLen = dflabel.length();
+    builder = new StringBuilder(dflabel);
+    for (Enumeration<String> e = wordList.keys(); e.hasMoreElements();) {
+      // key is label,word
+      String token = e.nextElement();
+      builder.append(",").append(token);
+      labelWord.set(builder.toString());
+      output.collect(labelWord, one);
+      output.collect(new Text("," + token), one);
+      builder.setLength(dfKeyLen); // truncate back
+
+    }
+
+    // Output that we have seen the label, to calculate the count of documents
+    // per class.
+    output.collect(new Text("_" + label), one);
+  }
+
+  /**
+   * Reads the n-gram size from "bayes.gramSize", falling back to the field's
+   * default (1) via the stringifier round-trip.
+   */
+  @Override
+  public void configure(JobConf job) {
+    try {
+
+      DefaultStringifier<Integer> intStringifier = new DefaultStringifier<Integer>(job, Integer.class);
+
+      String gramSizeString = intStringifier.toString(gramSize);
+      gramSizeString = job.get("bayes.gramSize", gramSizeString);
+      gramSize = intStringifier.fromString(gramSizeString);
+
+    } catch (IOException ex) {
+      // NOTE(review): swallowing this silently leaves gramSize at its default;
+      // consider rethrowing as a RuntimeException.
+      ex.printStackTrace();
+    }
+  }
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureMapper.java
------------------------------------------------------------------------------
svn:eol-style = native
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureOutputFormat.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureOutputFormat.java?rev=687042&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureOutputFormat.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureOutputFormat.java Tue Aug 19 05:55:45 2008
@@ -0,0 +1,64 @@
+package org.apache.mahout.classifier.bayes.common;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+
+import org.apache.hadoop.fs.FileSystem;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.RecordWriter;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapred.lib.MultipleOutputFormat;
+import org.apache.hadoop.util.Progressable;
+
+/**
+ * This class extends the MultipleOutputFormat, allowing to write the output
+ * data to different output files in sequence file output format. Records are
+ * routed to subdirectories by the leading character of their key (see
+ * {@link #generateFileNameForKeyValue}).
+ */
+public class BayesFeatureOutputFormat extends
+    MultipleOutputFormat<WritableComparable, Writable> {
+
+  // Lazily created delegate that does the actual sequence-file writing.
+  private SequenceFileOutputFormat theSequenceFileOutputFormat = null;
+
+  @Override
+  protected RecordWriter<WritableComparable, Writable> getBaseRecordWriter(
+      FileSystem fs, JobConf job, String name, Progressable arg3)
+      throws IOException {
+    if (theSequenceFileOutputFormat == null) {
+      theSequenceFileOutputFormat = new SequenceFileOutputFormat();
+    }
+    return theSequenceFileOutputFormat.getRecordWriter(fs, job, name, arg3);
+  }
+
+  /**
+   * Routes each record by the key prefix emitted by {@link BayesFeatureMapper}:
+   * "_label" -> trainer-docCount, "-label,term" -> trainer-termDocCount,
+   * ",term" -> trainer-featureCount, and un-prefixed "label,term" ->
+   * trainer-wordFreq.
+   */
+  @Override
+  protected String generateFileNameForKeyValue(WritableComparable k, Writable v,
+      String name) {
+    Text key = (Text)k;
+
+    if(key.toString().startsWith("_"))
+      return "trainer-docCount/"+name;
+    else if(key.toString().startsWith("-"))
+      return "trainer-termDocCount/"+name;
+    else if(key.toString().startsWith(","))
+      return "trainer-featureCount/"+name;
+    else
+      return "trainer-wordFreq/"+name;
+  }
+
+}
Propchange: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/bayes/common/BayesFeatureOutputFormat.java
------------------------------------------------------------------------------
svn:eol-style = native