You are viewing a plain-text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/03/04 06:40:03 UTC

svn commit: r918860 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/clustering/lda/ main/java/org/apache/mahout/common/ test/java/org/apache/mahout/clustering/lda/ test/java/org/apache/mahout/common/

Author: robinanil
Date: Thu Mar  4 05:40:03 2010
New Revision: 918860

URL: http://svn.apache.org/viewvc?rev=918860&view=rev
Log:
MAHOUT-320 Improvements in LDA

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java
      - copied, changed from r918394, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
Removed:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Thu Mar  4 05:40:03 2010
@@ -41,6 +41,7 @@
 import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.IntPairWritable;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.DenseMatrix;
 import org.slf4j.Logger;
@@ -52,9 +53,9 @@
  */
 public final class LDADriver {
   
-  static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";  
+  static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
   static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
-  static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";  
+  static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
   static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
   
   static final int LOG_LIKELIHOOD_KEY = -2;
@@ -63,7 +64,7 @@
   
   private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
   
-  private LDADriver() { }
+  private LDADriver() {}
   
   public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
     
@@ -196,8 +197,7 @@
       log.info("Iteration {}", iteration);
       // point the output to a new directory per iteration
       String stateOut = output + "/state-" + (iteration + 1);
-      double ll = runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing,
-        numReducers);
+      double ll = runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing, numReducers);
       double relChange = (oldLL - ll) / oldLL;
       
       // now point the input to the old output directory
@@ -216,7 +216,6 @@
     Configuration job = new Configuration();
     FileSystem fs = dir.getFileSystem(job);
     
-    IntPairWritable kw = new IntPairWritable();
     DoubleWritable v = new DoubleWritable();
     
     Random random = RandomUtils.getRandom();
@@ -226,20 +225,18 @@
       SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class,
           DoubleWritable.class);
       
-      kw.setX(k);
       double total = 0.0; // total number of pseudo counts we made
       for (int w = 0; w < numWords; ++w) {
-        kw.setY(w);
+        IntPairWritable kw = new IntPairWritable(k, w);
         // A small amount of random noise, minimized by having a floor.
         double pseudocount = random.nextDouble() + 1.0E-8;
         total += pseudocount;
         v.set(Math.log(pseudocount));
         writer.append(kw, v);
       }
-      
-      kw.setY(TOPIC_SUM_KEY);
+      IntPairWritable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
       v.set(Math.log(total));
-      writer.append(kw, v);
+      writer.append(kTsk, v);
       
       writer.close();
     }
@@ -257,7 +254,7 @@
       Path path = status.getPath();
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
       while (reader.next(key, value)) {
-        if (key.getX() == LOG_LIKELIHOOD_KEY) {
+        if (key.getFirst() == LOG_LIKELIHOOD_KEY) {
           ll = value.get();
           break;
         }
@@ -336,8 +333,8 @@
       Path path = status.getPath();
       SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
       while (reader.next(key, value)) {
-        int topic = key.getX();
-        int word = key.getY();
+        int topic = key.getFirst();
+        int word = key.getSecond();
         if (word == TOPIC_SUM_KEY) {
           logTotals[topic] = value.get();
           if (Double.isInfinite(value.get())) {

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Thu Mar  4 05:40:03 2010
@@ -25,7 +25,6 @@
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.BinaryFunction;
-import org.apache.mahout.math.map.OpenIntIntHashMap;
 
 /**
  * Class for performing infererence on a document, which involves computing (an approximation to)
@@ -49,14 +48,14 @@
     private final Vector wordCounts;
     private final Vector gamma; // p(topic)
     private final Matrix mphi; // log p(columnMap(w)|t)
-    private final OpenIntIntHashMap columnMap; // maps words into the matrix's column map
+    private final int[] columnMap; // maps words into the matrix's column map
     public final double logLikelihood;
     
     public double phi(int k, int w) {
-      return mphi.getQuick(k, columnMap.get(w));
+      return mphi.getQuick(k, columnMap[w]);
     }
     
-    InferredDocument(Vector wordCounts, Vector gamma, OpenIntIntHashMap columnMap, Matrix phi, double ll) {
+    InferredDocument(Vector wordCounts, Vector gamma, int[] columnMap, Matrix phi, double ll) {
       this.wordCounts = wordCounts;
       this.gamma = gamma;
       this.mphi = phi;
@@ -78,7 +77,7 @@
    */
   public InferredDocument infer(Vector wordCounts) {
     double docTotal = wordCounts.zSum();
-    int docLength = wordCounts.size();
+    int docLength = wordCounts.size(); // cardinality of document vectors
     
     // initialize variational approximation to p(z|doc)
     Vector gamma = new DenseVector(state.numTopics);
@@ -86,13 +85,9 @@
     Vector nextGamma = new DenseVector(state.numTopics);
     createPhiMatrix(docLength);
     
-    // digamma is expensive, precompute
-    Vector digammaGamma = digamma(gamma);
-    // and log normalize:
-    double digammaSumGamma = digamma(gamma.zSum());
-    digammaGamma = digammaGamma.plus(-digammaSumGamma);
+    Vector digammaGamma = digammaGamma(gamma);
     
-    OpenIntIntHashMap columnMap = new OpenIntIntHashMap();
+    int[] map = new int[docLength];
     
     int iteration = 0;
     
@@ -108,12 +103,12 @@
         Vector phiW = eStepForWord(word, digammaGamma);
         phi.assignColumn(mapping, phiW);
         if (iteration == 0) { // first iteration
-          columnMap.put(word, mapping);
+          map[word] = mapping;
         }
         
         for (int k = 0; k < nextGamma.size(); ++k) {
           double g = nextGamma.getQuick(k);
-          nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.get(k)));
+          nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.getQuick(k)));
         }
         
         mapping++;
@@ -123,31 +118,36 @@
       gamma = nextGamma;
       nextGamma = tempG;
       
-      // digamma is expensive, precompute
-      digammaGamma = digamma(gamma);
-      // and log normalize:
-      digammaSumGamma = digamma(gamma.zSum());
-      digammaGamma = digammaGamma.plus(-digammaSumGamma);
+      digammaGamma = digammaGamma(gamma);
       
-      double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
-      assert !Double.isNaN(ll);
+      double ll = computeLikelihood(wordCounts, map, phi, gamma, digammaGamma);
+      // isNotNaNAssertion(ll);
       converged = (oldLL < 0) && ((oldLL - ll) / oldLL < E_STEP_CONVERGENCE);
       
       oldLL = ll;
       iteration++;
     }
     
-    return new InferredDocument(wordCounts, gamma, columnMap, phi, oldLL);
+    return new InferredDocument(wordCounts, gamma, map, phi, oldLL);
+  }
+  
+  private Vector digammaGamma(Vector gamma) {
+    // digamma is expensive, precompute
+    Vector digammaGamma = digamma(gamma);
+    // and log normalize:
+    double digammaSumGamma = digamma(gamma.zSum());
+    for (int i = 0; i < state.numTopics; i++) {
+      digammaGamma.setQuick(i, digammaGamma.getQuick(i) - digammaSumGamma);
+    }
+    return digammaGamma;
   }
   
   private void createPhiMatrix(int docLength) {
-    if (phi == null){
+    if (phi == null) {
       phi = new DenseMatrix(state.numTopics, docLength);
-    }
-    else if (phi.getRow(0).size() != docLength){
+    } else if (phi.getRow(0).size() != docLength) {
       phi = new DenseMatrix(state.numTopics, docLength);
-    }
-    else {
+    } else {
       phi.assign(0);
     }
   }
@@ -155,46 +155,43 @@
   private DenseMatrix phi;
   private final LDAState state;
   
-  private double computeLikelihood(Vector wordCounts,
-                                   OpenIntIntHashMap columnMap,
-                                   Matrix phi,
-                                   Vector gamma,
-                                   Vector digammaGamma) {
+  private double computeLikelihood(Vector wordCounts, int[] map, Matrix phi, Vector gamma, Vector digammaGamma) {
     double ll = 0.0;
     
     // log normalizer for q(gamma);
     ll += Gamma.logGamma(state.topicSmoothing * state.numTopics);
     ll -= state.numTopics * Gamma.logGamma(state.topicSmoothing);
-    assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
+    // isNotNaNAssertion(ll);
     
     // now for the the rest of q(gamma);
     for (int k = 0; k < state.numTopics; ++k) {
-      ll += (state.topicSmoothing - gamma.get(k)) * digammaGamma.get(k);
-      ll += Gamma.logGamma(gamma.get(k));
+      double gammaK = gamma.get(k);
+      ll += (state.topicSmoothing - gammaK) * digammaGamma.getQuick(k);
+      ll += Gamma.logGamma(gammaK);
       
     }
     ll -= Gamma.logGamma(gamma.zSum());
-    assert !Double.isNaN(ll);
+    // isNotNaNAssertion(ll);
     
     // for each word
     for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
       Vector.Element e = iter.next();
       int w = e.index();
       double n = e.get();
-      int mapping = columnMap.get(w);
+      int mapping = map[w];
       // now for each topic:
       for (int k = 0; k < state.numTopics; k++) {
         double llPart = 0.0;
-        llPart += Math.exp(phi.getQuick(k, mapping))
-                  * (digammaGamma.get(k) - phi.getQuick(k, mapping) + state.logProbWordGivenTopic(w, k));
+        double phiKMapping = phi.getQuick(k, mapping);
+        llPart += Math.exp(phiKMapping)
+                  * (digammaGamma.getQuick(k) - phiKMapping + state.logProbWordGivenTopic(w, k));
         
         ll += llPart * n;
         
-        assert state.logProbWordGivenTopic(w, k) < 0;
-        assert !Double.isNaN(llPart);
+        // likelihoodAssertion(w, k, llPart);
       }
     }
-    assert ll <= 0;
+    // isLessThanOrEqualsZero(ll);
     return ll;
   }
   
@@ -205,13 +202,10 @@
     Vector phi = new DenseVector(state.numTopics); // log q(k|w), for each w
     double phiTotal = Double.NEGATIVE_INFINITY; // log Normalizer
     for (int k = 0; k < state.numTopics; ++k) { // update q(k|w)'s param phi
-      phi.set(k, state.logProbWordGivenTopic(word, k) + digammaGamma.get(k));
-      phiTotal = LDAUtil.logSum(phiTotal, phi.get(k));
+      phi.setQuick(k, state.logProbWordGivenTopic(word, k) + digammaGamma.getQuick(k));
+      phiTotal = LDAUtil.logSum(phiTotal, phi.getQuick(k));
       
-      assert !Double.isNaN(phiTotal);
-      assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
-      assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
-      assert !Double.isNaN(digammaGamma.get(k));
+      // assertions(word, digammaGamma, phiTotal, k);
     }
     for (int i = 0; i < state.numTopics; i++) {
       phi.setQuick(i, phi.getQuick(i) - phiTotal);// log normalize
@@ -229,7 +223,7 @@
     });
     return digammaGamma;
   }
-  
+
   /**
    * Approximation to the digamma function, from Radford Neal.
    * 
@@ -260,4 +254,25 @@
     return r + Math.log(x) - 0.5 / x + t;
   }
   
+  /*
+  private void assertions(int word, Vector digammaGamma, double phiTotal, int k) {
+    assert !Double.isNaN(phiTotal);
+    assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
+    assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
+    assert !Double.isNaN(digammaGamma.getQuick(k));
+  }
+  
+  private void likelihoodAssertion(int w, int k, double llPart) {
+    assert state.logProbWordGivenTopic(w, k) < 0;
+    assert !Double.isNaN(llPart);
+  }
+
+  private void isLessThanOrEqualsZero(double ll) {
+    assert ll <= 0;
+  }
+
+  private void isNotNaNAssertion(double ll) {
+    assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
+  }
+  */
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java Thu Mar  4 05:40:03 2010
@@ -25,6 +25,7 @@
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.IntPairWritable;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
@@ -47,16 +48,15 @@
     Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
     
     // Output sufficient statistics for each word. == pseudo-log counts.
-    IntPairWritable kw = new IntPairWritable();
     DoubleWritable v = new DoubleWritable();
     for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
       Vector.Element e = iter.next();
       int w = e.index();
-      kw.setY(w);
+      
       for (int k = 0; k < state.numTopics; ++k) {
         v.set(doc.phi(k, w) + Math.log(e.get()));
         
-        kw.setX(k);
+        IntPairWritable kw = new IntPairWritable(k, w);
         
         // ouput (topic, word)'s logProb contribution
         context.write(kw, v);
@@ -66,19 +66,16 @@
     
     // Output the totals for the statistics. This is to make
     // normalizing a lot easier.
-    kw.setY(LDADriver.TOPIC_SUM_KEY);
     for (int k = 0; k < state.numTopics; ++k) {
-      kw.setX(k);
+      IntPairWritable kw = new IntPairWritable(k, LDADriver.TOPIC_SUM_KEY);
       v.set(logTotals[k]);
       assert !Double.isNaN(v.get());
       context.write(kw, v);
     }
-    
+    IntPairWritable llk = new IntPairWritable(LDADriver.LOG_LIKELIHOOD_KEY, LDADriver.LOG_LIKELIHOOD_KEY);
     // Output log-likelihoods.
-    kw.setX(LDADriver.LOG_LIKELIHOOD_KEY);
-    kw.setY(LDADriver.LOG_LIKELIHOOD_KEY);
     v.set(doc.logLikelihood);
-    context.write(kw, v);
+    context.write(llk, v);
   }
   
   public void configure(LDAState myState) {

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java Thu Mar  4 05:40:03 2010
@@ -18,6 +18,7 @@
 
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.IntPairWritable;
 
 /**
  * A very simple reducer which simply logSums the input doubles and outputs a new double for sufficient
@@ -31,12 +32,12 @@
                      Context context) throws java.io.IOException, InterruptedException {
     
     // sum likelihoods
-    if (topicWord.getY() == LDADriver.LOG_LIKELIHOOD_KEY) {
+    if (topicWord.getSecond() == LDADriver.LOG_LIKELIHOOD_KEY) {
       double accum = 0.0;
       for (DoubleWritable vw : values) {
         double v = vw.get();
         if (Double.isNaN(v)) {
-          throw new IllegalArgumentException(topicWord.getX() + " " + topicWord.getY());
+          throw new IllegalArgumentException(topicWord.getFirst() + " " + topicWord.getSecond());
         }
         accum += v;
       }
@@ -46,11 +47,11 @@
       for (DoubleWritable vw : values) {
         double v = vw.get();
         if (Double.isNaN(v)) {
-          throw new IllegalArgumentException(topicWord.getX() + " " + topicWord.getY());
+          throw new IllegalArgumentException(topicWord.getFirst() + " " + topicWord.getSecond());
         }
         accum = LDAUtil.logSum(accum, v);
         if (Double.isNaN(accum)) {
-          throw new IllegalArgumentException(topicWord.getX() + " " + topicWord.getY());
+          throw new IllegalArgumentException(topicWord.getFirst() + " " + topicWord.getSecond());
         }
       }
       context.write(topicWord, new DoubleWritable(accum));

Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java (from r918394, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java&r1=918394&r2=918860&rev=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java Thu Mar  4 05:40:03 2010
@@ -15,117 +15,215 @@
  * limitations under the License.
  */
 
-package org.apache.mahout.clustering.lda;
+package org.apache.mahout.common;
 
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
 import java.io.Serializable;
+import java.util.Arrays;
 
+import org.apache.hadoop.io.BinaryComparable;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.io.WritableComparator;
 
 /**
  * Saves two ints, x and y.
  */
-public class IntPairWritable implements WritableComparable<IntPairWritable> {
+public final class IntPairWritable extends BinaryComparable implements WritableComparable<BinaryComparable> {
   
-  private int x;
-  private int y;
+  private static final int INT_PAIR_BYTE_LENGTH = 8;
+  private byte[] b = new byte[INT_PAIR_BYTE_LENGTH];
   
-  /** For serialization purposes only */
-  public IntPairWritable() { }
+  public IntPairWritable() {
+    setFirst(0);
+    setSecond(0);
+  }
+  
+  public IntPairWritable(IntPairWritable pair) {
+    b = Arrays.copyOf(pair.getBytes(), INT_PAIR_BYTE_LENGTH);
+  }
   
   public IntPairWritable(int x, int y) {
-    this.x = x;
-    this.y = y;
+    putInt(x, b, 0);
+    putInt(y, b, 4);
   }
   
-  public void setX(int x) {
-    this.x = x;
+  public void set(int x, int y) {
+    putInt(x, b, 0);
+    putInt(y, b, 4);
   }
   
-  public int getX() {
-    return x;
+  public void setFirst(int x) {
+    putInt(x, b, 0);
   }
   
-  public void setY(int y) {
-    this.y = y;
+  public int getFirst() {
+    return getInt(b, 0);
   }
   
-  public int getY() {
-    return y;
+  public void setSecond(int y) {
+    putInt(y, b, 4);
   }
   
-  @Override
-  public void write(DataOutput dataOutput) throws IOException {
-    dataOutput.writeInt(x);
-    dataOutput.writeInt(y);
+  public int getSecond() {
+    return getInt(b, 4);
   }
   
   @Override
-  public void readFields(DataInput dataInput) throws IOException {
-    x = dataInput.readInt();
-    y = dataInput.readInt();
+  public void readFields(DataInput in) throws IOException {
+    in.readFully(b);
   }
   
   @Override
-  public int compareTo(IntPairWritable that) {
-    if (this.x < that.getX()) {
-      return -1;
-    } else if (this.x > that.getX()) {
-      return 1;
-    } else {
-      return this.y < that.getY() ? -1 : this.y > that.getY() ? 1 : 0;
-    }
+  public void write(DataOutput out) throws IOException {
+    out.write(b);
   }
   
   @Override
-  public boolean equals(Object o) {
-    if (this == o) {
-      return true;
-    } else if (!(o instanceof IntPairWritable)) {
-      return false;
-    }
-    
-    IntPairWritable that = (IntPairWritable) o;
-    
-    return (that.getX() == this.x) && (this.y == that.getY());
+  public int hashCode() {
+    return 43 * Arrays.hashCode(b);
   }
   
   @Override
-  public int hashCode() {
-    return 43 * x + y;
+  public boolean equals(Object obj) {
+    if (this == obj) return true;
+    if (!super.equals(obj)) return false;
+    if (getClass() != obj.getClass()) return false;
+    IntPairWritable other = (IntPairWritable) obj;
+    if (!Arrays.equals(b, other.b)) return false;
+    return true;
   }
   
   @Override
   public String toString() {
-    return "(" + x + ", " + y + ')';
+    return "(" + getFirst() + ", " + getSecond() + ")";
+  }
+  
+  @Override
+  public byte[] getBytes() {
+    return b;
+  }
+  
+  @Override
+  public int getLength() {
+    return INT_PAIR_BYTE_LENGTH;
+  }
+  
+  private static void putInt(int value, byte[] b, int offset) {
+    if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
+      throw new IllegalArgumentException("offset+4 exceeds byte array length");
+    }
+    
+    for (int i = 0; i < 4; i++) {
+      b[offset + i] = (byte) (((value >>> ((3 - i) * 8)) & 0xFF) ^ 0x80);
+    }
+  }
+  
+  private static int getInt(byte[] b, int offset) {
+    if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
+      throw new IllegalArgumentException("offset+4 exceeds byte array length");
+    }
+    
+    int value = 0;
+    for (int i = 0; i < 4; i++) {
+      value += ((b[i + offset] & 0xFF) ^ 0x80) << (3 - i) * 8;
+    }
+    return value;
   }
   
   static {
     WritableComparator.define(IntPairWritable.class, new Comparator());
   }
   
-  public static class Comparator extends WritableComparator implements Serializable {
+  public static final class Comparator extends WritableComparator implements Serializable {
     public Comparator() {
       super(IntPairWritable.class);
     }
     
     @Override
     public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
-      if (l1 != 8) {
+      if (l1 != 8 || l2 != 8) {
         throw new IllegalArgumentException();
       }
-      int int11 = WritableComparator.readInt(b1, s1);
-      int int21 = WritableComparator.readInt(b2, s2);
-      if (int11 != int21) {
-        return int11 - int21;
+      return WritableComparator.compareBytes(b1, s1, l1, b2, s2, l2);
+    }
+  }
+  
+  /**
+   * Compare only the first part of the pair, so that reduce is called once for each value of the first part.
+   */
+  public static class FirstGroupingComparator extends WritableComparator implements Serializable {
+    
+    public FirstGroupingComparator() {
+      super(IntPairWritable.class);
+    }
+    
+    @Override
+    public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
+      int ret;
+      int firstb1 = WritableComparator.readInt(b1, s1);
+      int firstb2 = WritableComparator.readInt(b2, s2);
+      ret = firstb1 - firstb2;
+      return ret;
+    }
+    
+    @Override
+    public int compare(Object o1, Object o2) {
+      if (o1 == null) {
+        return -1;
+      } else if (o2 == null) {
+        return 1;
+      } else {
+        int firstb1 = ((IntPairWritable) o1).getFirst();
+        int firstb2 = ((IntPairWritable) o2).getFirst();
+        return firstb1 - firstb2;
+      }
+    }
+    
+  }
+  
+  /** A wrapper class that associates pairs with frequency (Occurrences) */
+  public static class Frequency implements Comparable<Frequency> {
+    
+    private IntPairWritable pair = new IntPairWritable();
+    private double frequency = 0.0;
+    
+    public double getFrequency() {
+      return frequency;
+    }
+    
+    public IntPairWritable getPair() {
+      return pair;
+    }
+    
+    public Frequency(IntPairWritable bigram, double frequency) {
+      this.pair = new IntPairWritable(bigram);
+      this.frequency = frequency;
+    }
+    
+    @Override
+    public int hashCode() {
+      return pair.hashCode() + (int) Math.abs(Math.round(frequency * 31));
+    }
+    
+    @Override
+    public boolean equals(Object right) {
+      if ((right == null) || !(right instanceof Frequency)) {
+        return false;
       }
-      
-      int int12 = WritableComparator.readInt(b1, s1 + 4);
-      int int22 = WritableComparator.readInt(b2, s2 + 4);
-      return int12 - int22;
+      Frequency that = (Frequency) right;
+      return this.compareTo(that) == 0;
+    }
+    
+    @Override
+    public int compareTo(Frequency that) {
+      return this.frequency > that.frequency ? 1 : -1;
+    }
+    
+    @Override
+    public String toString() {
+      return pair + "\t" + frequency;
     }
   }
 }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java Thu Mar  4 05:40:03 2010
@@ -16,24 +16,29 @@
  */
 package org.apache.mahout.clustering.lda;
 
+import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.isA;
+import static org.easymock.classextension.EasyMock.createMock;
+import static org.easymock.classextension.EasyMock.replay;
+import static org.easymock.classextension.EasyMock.verify;
+
 import java.io.File;
 import java.util.Iterator;
 import java.util.Random;
 
+import org.apache.commons.math.MathException;
 import org.apache.commons.math.distribution.PoissonDistribution;
 import org.apache.commons.math.distribution.PoissonDistributionImpl;
-import org.apache.commons.math.MathException;
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.io.Text;
+import org.apache.mahout.common.IntPairWritable;
 import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.common.RandomUtils;
-
-import static org.easymock.classextension.EasyMock.*;
 
 public class TestMapReduce extends MahoutTestCase {
 

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java?rev=918860&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java Thu Mar  4 05:40:03 2010
@@ -0,0 +1,100 @@
+package org.apache.mahout.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.common.IntPairWritable;
+import org.junit.Test;
+
+
+public class IntPairWritableTest {
+
+  @Test
+  public void testGetSet() {
+    IntPairWritable n = new IntPairWritable();
+    
+    Assert.assertEquals(0, n.getFirst());
+    Assert.assertEquals(0, n.getSecond());
+    
+    n.setFirst(5);
+    n.setSecond(10);
+    
+    Assert.assertEquals(5, n.getFirst());
+    Assert.assertEquals(10, n.getSecond());
+    
+    n = new IntPairWritable(2,4);
+    
+    Assert.assertEquals(2, n.getFirst());
+    Assert.assertEquals(4, n.getSecond());
+  }
+  
+  @Test
+  public void testWritable() throws IOException {
+    IntPairWritable one = new IntPairWritable(1,2);
+    IntPairWritable two = new IntPairWritable(3,4);
+    
+    Assert.assertEquals(1, one.getFirst());
+    Assert.assertEquals(2, one.getSecond());
+    
+    Assert.assertEquals(3, two.getFirst());
+    Assert.assertEquals(4, two.getSecond());
+    
+    
+    ByteArrayOutputStream bout = new ByteArrayOutputStream();
+    DataOutputStream out = new DataOutputStream(bout);
+    
+    two.write(out);
+    
+    byte[] b = bout.toByteArray();
+    
+    ByteArrayInputStream bin = new ByteArrayInputStream(b);
+    DataInputStream din = new DataInputStream(bin);
+    
+    one.readFields(din);
+    
+    Assert.assertEquals(two.getFirst(), one.getFirst());
+    Assert.assertEquals(two.getSecond(), one.getSecond());    
+  }
+  
+  @Test
+  public void testComparable() throws IOException {
+    IntPairWritable[] input = {
+        new IntPairWritable(2,3),
+        new IntPairWritable(2,2),
+        new IntPairWritable(1,3),
+        new IntPairWritable(1,2),
+        new IntPairWritable(2,1),
+        new IntPairWritable(2,2),
+        new IntPairWritable(1,-2),
+        new IntPairWritable(1,-1),
+        new IntPairWritable(-2,-2),
+        new IntPairWritable(-2,-1),
+        new IntPairWritable(-1,-1),
+        new IntPairWritable(-1,-2),
+        new IntPairWritable(Integer.MAX_VALUE,1),
+        new IntPairWritable(Integer.MAX_VALUE/2,1),
+        new IntPairWritable(Integer.MIN_VALUE,1),
+        new IntPairWritable(Integer.MIN_VALUE/2,1)
+        
+    };
+    
+    IntPairWritable[] sorted = new IntPairWritable[input.length];
+    System.arraycopy(input, 0, sorted, 0, input.length);
+    Arrays.sort(sorted);
+    
+    int[] expected = {
+        14, 15, 8, 9, 11, 10, 6, 7, 3, 2, 4, 1, 5, 0, 13, 12
+    };
+    
+    for (int i=0; i < input.length; i++) {
+      Assert.assertSame(input[expected[i]], sorted[i]);
+    }
+ 
+  }
+}