Posted to commits@mahout.apache.org by ro...@apache.org on 2010/03/04 06:40:03 UTC
svn commit: r918860 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/clustering/lda/ main/java/org/apache/mahout/common/ test/java/org/apache/mahout/clustering/lda/ test/java/org/apache/mahout/common/
Author: robinanil
Date: Thu Mar 4 05:40:03 2010
New Revision: 918860
URL: http://svn.apache.org/viewvc?rev=918860&view=rev
Log:
MAHOUT-320 Improvements in LDA
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java
- copied, changed from r918394, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
Removed:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDADriver.java Thu Mar 4 05:40:03 2010
@@ -41,6 +41,7 @@
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.slf4j.Logger;
@@ -52,9 +53,9 @@
*/
public final class LDADriver {
- static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
+ static final String STATE_IN_KEY = "org.apache.mahout.clustering.lda.stateIn";
static final String NUM_TOPICS_KEY = "org.apache.mahout.clustering.lda.numTopics";
- static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
+ static final String NUM_WORDS_KEY = "org.apache.mahout.clustering.lda.numWords";
static final String TOPIC_SMOOTHING_KEY = "org.apache.mahout.clustering.lda.topicSmoothing";
static final int LOG_LIKELIHOOD_KEY = -2;
@@ -63,7 +64,7 @@
private static final Logger log = LoggerFactory.getLogger(LDADriver.class);
- private LDADriver() { }
+ private LDADriver() {}
public static void main(String[] args) throws ClassNotFoundException, IOException, InterruptedException {
@@ -196,8 +197,7 @@
log.info("Iteration {}", iteration);
// point the output to a new directory per iteration
String stateOut = output + "/state-" + (iteration + 1);
- double ll = runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing,
- numReducers);
+ double ll = runIteration(input, stateIn, stateOut, numTopics, numWords, topicSmoothing, numReducers);
double relChange = (oldLL - ll) / oldLL;
// now point the input to the old output directory
@@ -216,7 +216,6 @@
Configuration job = new Configuration();
FileSystem fs = dir.getFileSystem(job);
- IntPairWritable kw = new IntPairWritable();
DoubleWritable v = new DoubleWritable();
Random random = RandomUtils.getRandom();
@@ -226,20 +225,18 @@
SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path, IntPairWritable.class,
DoubleWritable.class);
- kw.setX(k);
double total = 0.0; // total number of pseudo counts we made
for (int w = 0; w < numWords; ++w) {
- kw.setY(w);
+ IntPairWritable kw = new IntPairWritable(k, w);
// A small amount of random noise, minimized by having a floor.
double pseudocount = random.nextDouble() + 1.0E-8;
total += pseudocount;
v.set(Math.log(pseudocount));
writer.append(kw, v);
}
-
- kw.setY(TOPIC_SUM_KEY);
+ IntPairWritable kTsk = new IntPairWritable(k, TOPIC_SUM_KEY);
v.set(Math.log(total));
- writer.append(kw, v);
+ writer.append(kTsk, v);
writer.close();
}
@@ -257,7 +254,7 @@
Path path = status.getPath();
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
while (reader.next(key, value)) {
- if (key.getX() == LOG_LIKELIHOOD_KEY) {
+ if (key.getFirst() == LOG_LIKELIHOOD_KEY) {
ll = value.get();
break;
}
@@ -336,8 +333,8 @@
Path path = status.getPath();
SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, job);
while (reader.next(key, value)) {
- int topic = key.getX();
- int word = key.getY();
+ int topic = key.getFirst();
+ int word = key.getSecond();
if (word == TOPIC_SUM_KEY) {
logTotals[topic] = value.get();
if (Double.isInfinite(value.get())) {
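Note: the convergence test above is a relative change in log-likelihood. Log-likelihoods are negative, so (oldLL - ll) / oldLL is positive while the model is still improving and shrinks toward zero as the iterations converge. A minimal sketch of the test (the method and threshold names here are illustrative, not from the patch):

    // Sketch only: the relative-change convergence test used by the driver loop.
    static boolean converged(double oldLL, double ll, double threshold) {
      // e.g. oldLL = -1000, ll = -995  ->  (-1000 - (-995)) / -1000 = 0.005
      double relChange = (oldLL - ll) / oldLL;
      return relChange < threshold;
    }
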
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAInference.java Thu Mar 4 05:40:03 2010
@@ -25,7 +25,6 @@
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.BinaryFunction;
-import org.apache.mahout.math.map.OpenIntIntHashMap;
/**
* Class for performing inference on a document, which involves computing (an approximation to)
@@ -49,14 +48,14 @@
private final Vector wordCounts;
private final Vector gamma; // p(topic)
private final Matrix mphi; // log p(columnMap(w)|t)
- private final OpenIntIntHashMap columnMap; // maps words into the matrix's column map
+ private final int[] columnMap; // maps words into the matrix's column map
public final double logLikelihood;
public double phi(int k, int w) {
- return mphi.getQuick(k, columnMap.get(w));
+ return mphi.getQuick(k, columnMap[w]);
}
- InferredDocument(Vector wordCounts, Vector gamma, OpenIntIntHashMap columnMap, Matrix phi, double ll) {
+ InferredDocument(Vector wordCounts, Vector gamma, int[] columnMap, Matrix phi, double ll) {
this.wordCounts = wordCounts;
this.gamma = gamma;
this.mphi = phi;
@@ -78,7 +77,7 @@
*/
public InferredDocument infer(Vector wordCounts) {
double docTotal = wordCounts.zSum();
- int docLength = wordCounts.size();
+ int docLength = wordCounts.size(); // cardinality of document vectors
// initialize variational approximation to p(z|doc)
Vector gamma = new DenseVector(state.numTopics);
@@ -86,13 +85,9 @@
Vector nextGamma = new DenseVector(state.numTopics);
createPhiMatrix(docLength);
- // digamma is expensive, precompute
- Vector digammaGamma = digamma(gamma);
- // and log normalize:
- double digammaSumGamma = digamma(gamma.zSum());
- digammaGamma = digammaGamma.plus(-digammaSumGamma);
+ Vector digammaGamma = digammaGamma(gamma);
- OpenIntIntHashMap columnMap = new OpenIntIntHashMap();
+ int[] map = new int[docLength];
int iteration = 0;
@@ -108,12 +103,12 @@
Vector phiW = eStepForWord(word, digammaGamma);
phi.assignColumn(mapping, phiW);
if (iteration == 0) { // first iteration
- columnMap.put(word, mapping);
+ map[word] = mapping;
}
for (int k = 0; k < nextGamma.size(); ++k) {
double g = nextGamma.getQuick(k);
- nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.get(k)));
+ nextGamma.setQuick(k, g + e.get() * Math.exp(phiW.getQuick(k)));
}
mapping++;
@@ -123,31 +118,36 @@
gamma = nextGamma;
nextGamma = tempG;
- // digamma is expensive, precompute
- digammaGamma = digamma(gamma);
- // and log normalize:
- digammaSumGamma = digamma(gamma.zSum());
- digammaGamma = digammaGamma.plus(-digammaSumGamma);
+ digammaGamma = digammaGamma(gamma);
- double ll = computeLikelihood(wordCounts, columnMap, phi, gamma, digammaGamma);
- assert !Double.isNaN(ll);
+ double ll = computeLikelihood(wordCounts, map, phi, gamma, digammaGamma);
+ // isNotNaNAssertion(ll);
converged = (oldLL < 0) && ((oldLL - ll) / oldLL < E_STEP_CONVERGENCE);
oldLL = ll;
iteration++;
}
- return new InferredDocument(wordCounts, gamma, columnMap, phi, oldLL);
+ return new InferredDocument(wordCounts, gamma, map, phi, oldLL);
+ }
+
+ private Vector digammaGamma(Vector gamma) {
+ // digamma is expensive, precompute
+ Vector digammaGamma = digamma(gamma);
+ // and log normalize:
+ double digammaSumGamma = digamma(gamma.zSum());
+ for (int i = 0; i < state.numTopics; i++) {
+ digammaGamma.setQuick(i, digammaGamma.getQuick(i) - digammaSumGamma);
+ }
+ return digammaGamma;
}
private void createPhiMatrix(int docLength) {
- if (phi == null){
+ if (phi == null) {
phi = new DenseMatrix(state.numTopics, docLength);
- }
- else if (phi.getRow(0).size() != docLength){
+ } else if (phi.getRow(0).size() != docLength) {
phi = new DenseMatrix(state.numTopics, docLength);
- }
- else {
+ } else {
phi.assign(0);
}
}
@@ -155,46 +155,43 @@
private DenseMatrix phi;
private final LDAState state;
- private double computeLikelihood(Vector wordCounts,
- OpenIntIntHashMap columnMap,
- Matrix phi,
- Vector gamma,
- Vector digammaGamma) {
+ private double computeLikelihood(Vector wordCounts, int[] map, Matrix phi, Vector gamma, Vector digammaGamma) {
double ll = 0.0;
// log normalizer for q(gamma);
ll += Gamma.logGamma(state.topicSmoothing * state.numTopics);
ll -= state.numTopics * Gamma.logGamma(state.topicSmoothing);
- assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
+ // isNotNaNAssertion(ll);
// now for the rest of q(gamma);
for (int k = 0; k < state.numTopics; ++k) {
- ll += (state.topicSmoothing - gamma.get(k)) * digammaGamma.get(k);
- ll += Gamma.logGamma(gamma.get(k));
+ double gammaK = gamma.get(k);
+ ll += (state.topicSmoothing - gammaK) * digammaGamma.getQuick(k);
+ ll += Gamma.logGamma(gammaK);
}
ll -= Gamma.logGamma(gamma.zSum());
- assert !Double.isNaN(ll);
+ // isNotNaNAssertion(ll);
// for each word
for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
Vector.Element e = iter.next();
int w = e.index();
double n = e.get();
- int mapping = columnMap.get(w);
+ int mapping = map[w];
// now for each topic:
for (int k = 0; k < state.numTopics; k++) {
double llPart = 0.0;
- llPart += Math.exp(phi.getQuick(k, mapping))
- * (digammaGamma.get(k) - phi.getQuick(k, mapping) + state.logProbWordGivenTopic(w, k));
+ double phiKMapping = phi.getQuick(k, mapping);
+ llPart += Math.exp(phiKMapping)
+ * (digammaGamma.getQuick(k) - phiKMapping + state.logProbWordGivenTopic(w, k));
ll += llPart * n;
- assert state.logProbWordGivenTopic(w, k) < 0;
- assert !Double.isNaN(llPart);
+ // likelihoodAssertion(w, k, llPart);
}
}
- assert ll <= 0;
+ // isLessThanOrEqualsZero(ll);
return ll;
}
@@ -205,13 +202,10 @@
Vector phi = new DenseVector(state.numTopics); // log q(k|w), for each w
double phiTotal = Double.NEGATIVE_INFINITY; // log Normalizer
for (int k = 0; k < state.numTopics; ++k) { // update q(k|w)'s param phi
- phi.set(k, state.logProbWordGivenTopic(word, k) + digammaGamma.get(k));
- phiTotal = LDAUtil.logSum(phiTotal, phi.get(k));
+ phi.setQuick(k, state.logProbWordGivenTopic(word, k) + digammaGamma.getQuick(k));
+ phiTotal = LDAUtil.logSum(phiTotal, phi.getQuick(k));
- assert !Double.isNaN(phiTotal);
- assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
- assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
- assert !Double.isNaN(digammaGamma.get(k));
+ // assertions(word, digammaGamma, phiTotal, k);
}
for (int i = 0; i < state.numTopics; i++) {
phi.setQuick(i, phi.getQuick(i) - phiTotal);// log normalize
@@ -229,7 +223,7 @@
});
return digammaGamma;
}
-
+
/**
* Approximation to the digamma function, from Radford Neal.
*
@@ -260,4 +254,25 @@
return r + Math.log(x) - 0.5 / x + t;
}
+ /*
+ private void assertions(int word, Vector digammaGamma, double phiTotal, int k) {
+ assert !Double.isNaN(phiTotal);
+ assert !Double.isNaN(state.logProbWordGivenTopic(word, k));
+ assert !Double.isInfinite(state.logProbWordGivenTopic(word, k));
+ assert !Double.isNaN(digammaGamma.getQuick(k));
+ }
+
+ private void likelihoodAssertion(int w, int k, double llPart) {
+ assert state.logProbWordGivenTopic(w, k) < 0;
+ assert !Double.isNaN(llPart);
+ }
+
+ private void isLessThanOrEqualsZero(double ll) {
+ assert ll <= 0;
+ }
+
+ private void isNotNaNAssertion(double ll) {
+ assert !Double.isNaN(ll) : state.topicSmoothing + " " + state.numTopics;
+ }
+ */
}
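Note: the digammaGamma helper factored out above computes, for the variational Dirichlet q(theta | gamma), the vector of expectations E[log theta_k] = psi(gamma_k) - psi(sum_j gamma_j), where psi is the digamma function. The diff shows only the final line of the digamma(double) method itself; a sketch of the standard recurrence-plus-asymptotic-series form credited to Radford Neal (the cutoff and series coefficients below are the conventional ones and are assumed, not copied from the file):

    // Sketch, assuming the conventional form of Neal's approximation.
    private static double digamma(double x) {
      double r = 0.0;
      // the recurrence psi(x) = psi(x + 1) - 1/x pushes x into the asymptotic range
      while (x <= 5.0) {
        r -= 1.0 / x;
        x += 1.0;
      }
      // asymptotic series in 1/x^2
      double f = 1.0 / (x * x);
      double t = f * (-1.0 / 12.0 + f * (1.0 / 120.0 + f * (-1.0 / 252.0
          + f * (1.0 / 240.0 + f * (-1.0 / 132.0 + f * (691.0 / 32760.0
          + f * (-1.0 / 12.0 + f * (3617.0 / 8160.0))))))));
      return r + Math.log(x) - 0.5 / x + t;
    }
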
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAMapper.java Thu Mar 4 05:40:03 2010
@@ -25,6 +25,7 @@
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -47,16 +48,15 @@
Arrays.fill(logTotals, Double.NEGATIVE_INFINITY);
// Output sufficient statistics for each word. == pseudo-log counts.
- IntPairWritable kw = new IntPairWritable();
DoubleWritable v = new DoubleWritable();
for (Iterator<Vector.Element> iter = wordCounts.iterateNonZero(); iter.hasNext();) {
Vector.Element e = iter.next();
int w = e.index();
- kw.setY(w);
+
for (int k = 0; k < state.numTopics; ++k) {
v.set(doc.phi(k, w) + Math.log(e.get()));
- kw.setX(k);
+ IntPairWritable kw = new IntPairWritable(k, w);
// output (topic, word)'s logProb contribution
context.write(kw, v);
@@ -66,19 +66,16 @@
// Output the totals for the statistics. This is to make
// normalizing a lot easier.
- kw.setY(LDADriver.TOPIC_SUM_KEY);
for (int k = 0; k < state.numTopics; ++k) {
- kw.setX(k);
+ IntPairWritable kw = new IntPairWritable(k, LDADriver.TOPIC_SUM_KEY);
v.set(logTotals[k]);
assert !Double.isNaN(v.get());
context.write(kw, v);
}
-
+ IntPairWritable llk = new IntPairWritable(LDADriver.LOG_LIKELIHOOD_KEY, LDADriver.LOG_LIKELIHOOD_KEY);
// Output log-likelihoods.
- kw.setX(LDADriver.LOG_LIKELIHOOD_KEY);
- kw.setY(LDADriver.LOG_LIKELIHOOD_KEY);
v.set(doc.logLikelihood);
- context.write(kw, v);
+ context.write(llk, v);
}
public void configure(LDAState myState) {
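Note: with IntPairWritable now in org.apache.mahout.common, any job consuming the mapper's (topic, word) keys wires the class in as the map output key. A hypothetical wiring sketch (standard Hadoop 0.20 mapreduce API; this setup is illustrative and not part of the patch):

    import java.io.IOException;

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.io.DoubleWritable;
    import org.apache.hadoop.mapreduce.Job;
    import org.apache.mahout.common.IntPairWritable;

    public final class LdaJobWiringSketch {
      public static Job configure(Configuration conf) throws IOException {
        Job job = new Job(conf, "lda-iteration");
        job.setMapOutputKeyClass(IntPairWritable.class);
        job.setMapOutputValueClass(DoubleWritable.class);
        job.setOutputKeyClass(IntPairWritable.class);
        job.setOutputValueClass(DoubleWritable.class);
        // optional: group all (topic, *) keys into one reduce call per topic,
        // using the FirstGroupingComparator added later in this commit
        job.setGroupingComparatorClass(IntPairWritable.FirstGroupingComparator.class);
        return job;
      }
    }
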
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/LDAReducer.java Thu Mar 4 05:40:03 2010
@@ -18,6 +18,7 @@
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.IntPairWritable;
/**
* A very simple reducer which simply logSums the input doubles and outputs a new double for sufficient
@@ -31,12 +32,12 @@
Context context) throws java.io.IOException, InterruptedException {
// sum likelihoods
- if (topicWord.getY() == LDADriver.LOG_LIKELIHOOD_KEY) {
+ if (topicWord.getSecond() == LDADriver.LOG_LIKELIHOOD_KEY) {
double accum = 0.0;
for (DoubleWritable vw : values) {
double v = vw.get();
if (Double.isNaN(v)) {
- throw new IllegalArgumentException(topicWord.getX() + " " + topicWord.getY());
+ throw new IllegalArgumentException(topicWord.getFirst() + " " + topicWord.getSecond());
}
accum += v;
}
@@ -46,11 +47,11 @@
for (DoubleWritable vw : values) {
double v = vw.get();
if (Double.isNaN(v)) {
- throw new IllegalArgumentException(topicWord.getX() + " " + topicWord.getY());
+ throw new IllegalArgumentException(topicWord.getFirst() + " " + topicWord.getSecond());
}
accum = LDAUtil.logSum(accum, v);
if (Double.isNaN(accum)) {
- throw new IllegalArgumentException(topicWord.getX() + " " + topicWord.getY());
+ throw new IllegalArgumentException(topicWord.getFirst() + " " + topicWord.getSecond());
}
}
context.write(topicWord, new DoubleWritable(accum));
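Note: LDAUtil.logSum accumulates log-space values, i.e. it returns log(exp(a) + exp(b)) without leaving log space; the accumulator starts at Double.NEGATIVE_INFINITY, which plays the role of log(0). LDAUtil itself is not part of this diff, so the textbook numerically stable form is stated here only as an assumption about its behavior:

    // Sketch of a stable log-sum; assumed, not copied, since LDAUtil is not
    // shown in this diff. NEGATIVE_INFINITY represents log(0).
    static double logSum(double a, double b) {
      if (a == Double.NEGATIVE_INFINITY) {
        return b;
      }
      if (b == Double.NEGATIVE_INFINITY) {
        return a;
      }
      double max = Math.max(a, b);
      // log(exp(a) + exp(b)) = max + log(1 + exp(min - max))
      return max + Math.log1p(Math.exp(Math.min(a, b) - max));
    }
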
Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java (from r918394, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java&r1=918394&r2=918860&rev=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/lda/IntPairWritable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java Thu Mar 4 05:40:03 2010
@@ -15,117 +15,215 @@
* limitations under the License.
*/
-package org.apache.mahout.clustering.lda;
+package org.apache.mahout.common;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.io.Serializable;
+import java.util.Arrays;
+import org.apache.hadoop.io.BinaryComparable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.io.WritableComparator;
/**
* Saves two ints, x and y.
*/
-public class IntPairWritable implements WritableComparable<IntPairWritable> {
+public final class IntPairWritable extends BinaryComparable implements WritableComparable<BinaryComparable> {
- private int x;
- private int y;
+ private static final int INT_PAIR_BYTE_LENGTH = 8;
+ private byte[] b = new byte[INT_PAIR_BYTE_LENGTH];
- /** For serialization purposes only */
- public IntPairWritable() { }
+ public IntPairWritable() {
+ setFirst(0);
+ setSecond(0);
+ }
+
+ public IntPairWritable(IntPairWritable pair) {
+ b = Arrays.copyOf(pair.getBytes(), INT_PAIR_BYTE_LENGTH);
+ }
public IntPairWritable(int x, int y) {
- this.x = x;
- this.y = y;
+ putInt(x, b, 0);
+ putInt(y, b, 4);
}
- public void setX(int x) {
- this.x = x;
+ public void set(int x, int y) {
+ putInt(x, b, 0);
+ putInt(y, b, 4);
}
- public int getX() {
- return x;
+ public void setFirst(int x) {
+ putInt(x, b, 0);
}
- public void setY(int y) {
- this.y = y;
+ public int getFirst() {
+ return getInt(b, 0);
}
- public int getY() {
- return y;
+ public void setSecond(int y) {
+ putInt(y, b, 4);
}
- @Override
- public void write(DataOutput dataOutput) throws IOException {
- dataOutput.writeInt(x);
- dataOutput.writeInt(y);
+ public int getSecond() {
+ return getInt(b, 4);
}
@Override
- public void readFields(DataInput dataInput) throws IOException {
- x = dataInput.readInt();
- y = dataInput.readInt();
+ public void readFields(DataInput in) throws IOException {
+ in.readFully(b);
}
@Override
- public int compareTo(IntPairWritable that) {
- if (this.x < that.getX()) {
- return -1;
- } else if (this.x > that.getX()) {
- return 1;
- } else {
- return this.y < that.getY() ? -1 : this.y > that.getY() ? 1 : 0;
- }
+ public void write(DataOutput out) throws IOException {
+ out.write(b);
}
@Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- } else if (!(o instanceof IntPairWritable)) {
- return false;
- }
-
- IntPairWritable that = (IntPairWritable) o;
-
- return (that.getX() == this.x) && (this.y == that.getY());
+ public int hashCode() {
+ return 43 * Arrays.hashCode(b);
}
@Override
- public int hashCode() {
- return 43 * x + y;
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+ return Arrays.equals(b, ((IntPairWritable) obj).b);
}
@Override
public String toString() {
- return "(" + x + ", " + y + ')';
+ return "(" + getFirst() + ", " + getSecond() + ")";
+ }
+
+ @Override
+ public byte[] getBytes() {
+ return b;
+ }
+
+ @Override
+ public int getLength() {
+ return INT_PAIR_BYTE_LENGTH;
+ }
+
+ private static void putInt(int value, byte[] b, int offset) {
+ if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
+ throw new IllegalArgumentException("offset+4 exceeds byte array length");
+ }
+
+ // flip only the sign bit, so that unsigned lexicographic byte order
+ // agrees with signed int order
+ int flipped = value ^ Integer.MIN_VALUE;
+ for (int i = 0; i < 4; i++) {
+ b[offset + i] = (byte) (flipped >>> ((3 - i) * 8));
+ }
+ }
+
+ private static int getInt(byte[] b, int offset) {
+ if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
+ throw new IllegalArgumentException("offset+4 exceeds byte array length");
+ }
+
+ int value = 0;
+ for (int i = 0; i < 4; i++) {
+ value = (value << 8) | (b[offset + i] & 0xFF);
+ }
+ // undo the sign-bit flip applied by putInt
+ return value ^ Integer.MIN_VALUE;
+ }
static {
WritableComparator.define(IntPairWritable.class, new Comparator());
}
- public static class Comparator extends WritableComparator implements Serializable {
+ public static final class Comparator extends WritableComparator implements Serializable {
public Comparator() {
super(IntPairWritable.class);
}
@Override
public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
- if (l1 != 8) {
+ if (l1 != 8 || l2 != 8) {
throw new IllegalArgumentException();
}
- int int11 = WritableComparator.readInt(b1, s1);
- int int21 = WritableComparator.readInt(b2, s2);
- if (int11 != int21) {
- return int11 - int21;
+ return WritableComparator.compareBytes(b1, s1, l1, b2, s2, l2);
+ }
+ }
+
+ /**
+ * Compare only the first part of the pair, so that reduce is called once for each value of the first part.
+ */
+ public static class FirstGroupingComparator extends WritableComparator implements Serializable {
+
+ public FirstGroupingComparator() {
+ super(IntPairWritable.class);
+ }
+
+ @Override
+ public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
+ // the first int is encoded so that unsigned byte order matches signed
+ // int order, so comparing its four raw bytes compares the first values
+ return WritableComparator.compareBytes(b1, s1, 4, b2, s2, 4);
+ }
+
+ @Override
+ public int compare(Object o1, Object o2) {
+ if (o1 == null) {
+ return -1;
+ } else if (o2 == null) {
+ return 1;
+ } else {
+ int first1 = ((IntPairWritable) o1).getFirst();
+ int first2 = ((IntPairWritable) o2).getFirst();
+ // avoid subtraction, which can overflow for extreme values
+ return first1 < first2 ? -1 : first1 == first2 ? 0 : 1;
+ }
+ }
+
+ }
+
+ /** A wrapper class that associates a pair with its frequency (occurrences) */
+ public static class Frequency implements Comparable<Frequency> {
+
+ private IntPairWritable pair = new IntPairWritable();
+ private double frequency = 0.0;
+
+ public double getFrequency() {
+ return frequency;
+ }
+
+ public IntPairWritable getPair() {
+ return pair;
+ }
+
+ public Frequency(IntPairWritable bigram, double frequency) {
+ this.pair = new IntPairWritable(bigram);
+ this.frequency = frequency;
+ }
+
+ @Override
+ public int hashCode() {
+ return pair.hashCode() + (int) Math.abs(Math.round(frequency * 31));
+ }
+
+ @Override
+ public boolean equals(Object right) {
+ if (!(right instanceof Frequency)) {
+ return false;
+ }
-
- int int12 = WritableComparator.readInt(b1, s1 + 4);
- int int22 = WritableComparator.readInt(b2, s2 + 4);
- return int12 - int22;
+ Frequency that = (Frequency) right;
+ return this.pair.equals(that.pair) && this.frequency == that.frequency;
+ }
+
+ @Override
+ public int compareTo(Frequency that) {
+ return Double.compare(this.frequency, that.frequency);
+ }
+
+ @Override
+ public String toString() {
+ return pair + "\t" + frequency;
}
}
}
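Note: the point of the new byte-backed representation is that each int is stored big-endian with its sign bit flipped, so an unsigned lexicographic comparison of the raw bytes (which is exactly what BinaryComparable and the raw-byte Comparator perform) agrees with signed (first, second) ordering, and keys can be sorted during the shuffle without being deserialized. A small hypothetical check using only the public API above:

    import org.apache.hadoop.io.WritableComparator;
    import org.apache.mahout.common.IntPairWritable;

    public final class IntPairOrderCheck {
      public static void main(String[] args) {
        IntPairWritable a = new IntPairWritable(-1, 7);
        IntPairWritable b = new IntPairWritable(0, 7);
        // raw unsigned byte comparison, as used by the shuffle sort
        int cmp = WritableComparator.compareBytes(
            a.getBytes(), 0, a.getLength(), b.getBytes(), 0, b.getLength());
        System.out.println(cmp < 0); // true: (-1, 7) sorts before (0, 7)
      }
    }
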
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java?rev=918860&r1=918859&r2=918860&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/lda/TestMapReduce.java Thu Mar 4 05:40:03 2010
@@ -16,24 +16,29 @@
*/
package org.apache.mahout.clustering.lda;
+import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.isA;
+import static org.easymock.classextension.EasyMock.createMock;
+import static org.easymock.classextension.EasyMock.replay;
+import static org.easymock.classextension.EasyMock.verify;
+
import java.io.File;
import java.util.Iterator;
import java.util.Random;
+import org.apache.commons.math.MathException;
import org.apache.commons.math.distribution.PoissonDistribution;
import org.apache.commons.math.distribution.PoissonDistributionImpl;
-import org.apache.commons.math.MathException;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.Text;
+import org.apache.mahout.common.IntPairWritable;
import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
-import org.apache.mahout.common.RandomUtils;
-
-import static org.easymock.classextension.EasyMock.*;
public class TestMapReduce extends MahoutTestCase {
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java?rev=918860&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java Thu Mar 4 05:40:03 2010
@@ -0,0 +1,100 @@
+package org.apache.mahout.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Arrays;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.common.IntPairWritable;
+import org.junit.Test;
+
+
+public class IntPairWritableTest {
+
+ @Test
+ public void testGetSet() {
+ IntPairWritable n = new IntPairWritable();
+
+ Assert.assertEquals(0, n.getFirst());
+ Assert.assertEquals(0, n.getSecond());
+
+ n.setFirst(5);
+ n.setSecond(10);
+
+ Assert.assertEquals(5, n.getFirst());
+ Assert.assertEquals(10, n.getSecond());
+
+ n = new IntPairWritable(2,4);
+
+ Assert.assertEquals(2, n.getFirst());
+ Assert.assertEquals(4, n.getSecond());
+ }
+
+ @Test
+ public void testWritable() throws IOException {
+ IntPairWritable one = new IntPairWritable(1,2);
+ IntPairWritable two = new IntPairWritable(3,4);
+
+ Assert.assertEquals(1, one.getFirst());
+ Assert.assertEquals(2, one.getSecond());
+
+ Assert.assertEquals(3, two.getFirst());
+ Assert.assertEquals(4, two.getSecond());
+
+
+ ByteArrayOutputStream bout = new ByteArrayOutputStream();
+ DataOutputStream out = new DataOutputStream(bout);
+
+ two.write(out);
+
+ byte[] b = bout.toByteArray();
+
+ ByteArrayInputStream bin = new ByteArrayInputStream(b);
+ DataInputStream din = new DataInputStream(bin);
+
+ one.readFields(din);
+
+ Assert.assertEquals(two.getFirst(), one.getFirst());
+ Assert.assertEquals(two.getSecond(), one.getSecond());
+ }
+
+ @Test
+ public void testComparable() throws IOException {
+ IntPairWritable[] input = {
+ new IntPairWritable(2,3),
+ new IntPairWritable(2,2),
+ new IntPairWritable(1,3),
+ new IntPairWritable(1,2),
+ new IntPairWritable(2,1),
+ new IntPairWritable(2,2),
+ new IntPairWritable(1,-2),
+ new IntPairWritable(1,-1),
+ new IntPairWritable(-2,-2),
+ new IntPairWritable(-2,-1),
+ new IntPairWritable(-1,-1),
+ new IntPairWritable(-1,-2),
+ new IntPairWritable(Integer.MAX_VALUE,1),
+ new IntPairWritable(Integer.MAX_VALUE/2,1),
+ new IntPairWritable(Integer.MIN_VALUE,1),
+ new IntPairWritable(Integer.MIN_VALUE/2,1)
+
+ };
+
+ IntPairWritable[] sorted = new IntPairWritable[input.length];
+ System.arraycopy(input, 0, sorted, 0, input.length);
+ Arrays.sort(sorted);
+
+ int[] expected = {
+ 14, 15, 8, 9, 11, 10, 6, 7, 3, 2, 4, 1, 5, 0, 13, 12
+ };
+
+ for (int i=0; i < input.length; i++) {
+ Assert.assertSame(input[expected[i]], sorted[i]);
+ }
+
+ }
+}