You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/03/05 13:23:32 UTC
svn commit: r919388 - in /lucene/mahout/trunk/core/src:
main/java/org/apache/mahout/common/IntPairWritable.java
test/java/org/apache/mahout/common/IntPairWritableTest.java
Author: srowen
Date: Fri Mar 5 12:23:32 2010
New Revision: 919388
URL: http://svn.apache.org/viewvc?rev=919388&view=rev
Log:
IntPairWritable fixes, attached to MAHOUT-320
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java?rev=919388&r1=919387&r2=919388&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java Fri Mar 5 12:23:32 2010
@@ -28,11 +28,14 @@
import org.apache.hadoop.io.WritableComparator;
/**
- * Saves two ints, x and y.
+ * A {@link WritableComparable} which encapsulates an ordered pair of signed integers.
*/
-public final class IntPairWritable extends BinaryComparable implements WritableComparable<BinaryComparable> {
-
- private static final int INT_PAIR_BYTE_LENGTH = 8;
+public final class IntPairWritable
+ extends BinaryComparable
+ implements WritableComparable<BinaryComparable>, Serializable {
+
+ static final int INT_BYTE_LENGTH = 4;
+ static final int INT_PAIR_BYTE_LENGTH = 2 * INT_BYTE_LENGTH;
private byte[] b = new byte[INT_PAIR_BYTE_LENGTH];
public IntPairWritable() {
@@ -46,12 +49,12 @@
public IntPairWritable(int x, int y) {
putInt(x, b, 0);
- putInt(y, b, 4);
+ putInt(y, b, INT_BYTE_LENGTH);
}
public void set(int x, int y) {
putInt(x, b, 0);
- putInt(y, b, 4);
+ putInt(y, b, INT_BYTE_LENGTH);
}
public void setFirst(int x) {
@@ -63,11 +66,11 @@
}
public void setSecond(int y) {
- putInt(y, b, 4);
+ putInt(y, b, INT_BYTE_LENGTH);
}
public int getSecond() {
- return getInt(b, 4);
+ return getInt(b, INT_BYTE_LENGTH);
}
@Override
@@ -82,22 +85,37 @@
@Override
public int hashCode() {
- return 43 * Arrays.hashCode(b);
+ return Arrays.hashCode(b);
}
@Override
public boolean equals(Object obj) {
- if (this == obj) return true;
- if (!super.equals(obj)) return false;
- if (getClass() != obj.getClass()) return false;
+ if (this == obj) {
+ return true;
+ }
+ if (!super.equals(obj)) {
+ return false;
+ }
+ if (!(obj instanceof IntPairWritable)) {
+ return false;
+ }
IntPairWritable other = (IntPairWritable) obj;
- if (!Arrays.equals(b, other.b)) return false;
- return true;
+ return Arrays.equals(b, other.b);
+ }
+
+ @Override
+ public int compareTo(BinaryComparable other) {
+ return Comparator.doCompare(b, 0, ((IntPairWritable) other).b, 0);
+ }
+
+ @Override
+ public Object clone() {
+ return new IntPairWritable(this);
}
@Override
public String toString() {
- return "(" + getFirst() + ", " + getSecond() + ")";
+ return "(" + getFirst() + ", " + getSecond() + ')';
}
@Override
@@ -111,31 +129,23 @@
}
private static void putInt(int value, byte[] b, int offset) {
- if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
- throw new IllegalArgumentException("offset+4 exceeds byte array length");
- }
-
- for (int i = 0; i < 4; i++) {
- b[offset + i] = (byte) (((value >>> ((3 - i) * 8)) & 0xFF) ^ 0x80);
+ for (int i = offset, j = 24; j >= 0; i++, j -= 8) {
+ b[i] = (byte) (value >>> j);
}
}
private static int getInt(byte[] b, int offset) {
- if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
- throw new IllegalArgumentException("offset+4 exceeds byte array length");
- }
-
int value = 0;
- for (int i = 0; i < 4; i++) {
- value += ((b[i + offset] & 0xFF) ^ 0x80) << (3 - i) * 8;
+ for (int i = offset, j = 24; j >= 0; i++, j -= 8) {
+ value |= (b[i] & 0xFF) << j;
}
return value;
}
-
+
static {
WritableComparator.define(IntPairWritable.class, new Comparator());
}
-
+
public static final class Comparator extends WritableComparator implements Serializable {
public Comparator() {
super(IntPairWritable.class);
@@ -143,10 +153,32 @@
@Override
public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
- if (l1 != 8 || l2 != 8) {
- throw new IllegalArgumentException();
+ return doCompare(b1, s1, b2, s2);
+ }
+
+ static int doCompare(byte[] b1, int s1, byte[] b2, int s2) {
+ int compare1 = compareInts(b1, s1, b2, s2);
+ if (compare1 != 0) {
+ return compare1;
+ }
+ return compareInts(b1, s1 + INT_BYTE_LENGTH, b2, s2 + INT_BYTE_LENGTH);
+ }
+
+ private static int compareInts(byte[] b1, int s1, byte[] b2, int s2) {
+ // Like WritableComparator.compareBytes(), but treats first byte as signed value
+ int end1 = s1 + INT_BYTE_LENGTH;
+ for (int i = s1, j = s2; i < end1; i++, j++) {
+ int a = b1[i];
+ int b = b2[j];
+ if (i > s1) {
+ a &= 0xff;
+ b &= 0xff;
+ }
+ if (a != b) {
+ return a - b;
+ }
}
- return WritableComparator.compareBytes(b1, s1, l1, b2, s2, l2);
+ return 0;
}
}
@@ -161,42 +193,45 @@
@Override
public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
- int ret;
int firstb1 = WritableComparator.readInt(b1, s1);
int firstb2 = WritableComparator.readInt(b2, s2);
- ret = firstb1 - firstb2;
- return ret;
+ if (firstb1 < firstb2) {
+ return -1;
+ } else if (firstb1 > firstb2) {
+ return 1;
+ } else {
+ return 0;
+ }
}
@Override
public int compare(Object o1, Object o2) {
- if (o1 == null) {
+ int firstb1 = ((IntPairWritable) o1).getFirst();
+ int firstb2 = ((IntPairWritable) o2).getFirst();
+ if (firstb1 < firstb2) {
return -1;
- } else if (o2 == null) {
+ } else if (firstb1 > firstb2) {
return 1;
- } else {
- int firstb1 = ((IntPairWritable) o1).getFirst();
- int firstb2 = ((IntPairWritable) o2).getFirst();
- return firstb1 - firstb2;
}
+ return 0;
}
}
/** A wrapper class that associates pairs with frequency (Occurrences) */
- public static class Frequency implements Comparable<Frequency> {
-
- private IntPairWritable pair = new IntPairWritable();
- private double frequency = 0.0;
+ public static class Frequency implements Comparable<Frequency>, Serializable {
+ private final IntPairWritable pair;
+ private final double frequency;
+
public double getFrequency() {
return frequency;
}
-
+
public IntPairWritable getPair() {
return pair;
}
-
+
public Frequency(IntPairWritable bigram, double frequency) {
this.pair = new IntPairWritable(bigram);
this.frequency = frequency;
@@ -204,21 +239,26 @@
@Override
public int hashCode() {
- return pair.hashCode() + (int) Math.abs(Math.round(frequency * 31));
+ return pair.hashCode() + RandomUtils.hashDouble(frequency);
}
@Override
public boolean equals(Object right) {
- if ((right == null) || !(right instanceof Frequency)) {
+ if (!(right instanceof Frequency)) {
return false;
}
Frequency that = (Frequency) right;
- return this.compareTo(that) == 0;
+ return pair.equals(that.pair) && frequency == that.frequency;
}
@Override
public int compareTo(Frequency that) {
- return this.frequency > that.frequency ? 1 : -1;
+ if (frequency < that.frequency) {
+ return -1;
+ } else if (frequency > that.frequency) {
+ return 1;
+ }
+ return 0;
}
@Override
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java?rev=919388&r1=919387&r2=919388&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java Fri Mar 5 12:23:32 2010
@@ -1,3 +1,20 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.mahout.common;
import java.io.ByteArrayInputStream;
@@ -9,10 +26,8 @@
import junit.framework.Assert;
-import org.apache.mahout.common.IntPairWritable;
import org.junit.Test;
-
public class IntPairWritableTest {
@Test