You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/03/05 13:23:32 UTC

svn commit: r919388 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/common/IntPairWritable.java test/java/org/apache/mahout/common/IntPairWritableTest.java

Author: srowen
Date: Fri Mar  5 12:23:32 2010
New Revision: 919388

URL: http://svn.apache.org/viewvc?rev=919388&view=rev
Log:
IntPairWritable fixes, attached to MAHOUT-320

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java?rev=919388&r1=919387&r2=919388&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/common/IntPairWritable.java Fri Mar  5 12:23:32 2010
@@ -28,11 +28,14 @@
 import org.apache.hadoop.io.WritableComparator;
 
 /**
- * Saves two ints, x and y.
+ * A {@link WritableComparable} which encapsulates an ordered pair of signed integers.
  */
-public final class IntPairWritable extends BinaryComparable implements WritableComparable<BinaryComparable> {
-  
-  private static final int INT_PAIR_BYTE_LENGTH = 8;
+public final class IntPairWritable
+    extends BinaryComparable 
+    implements WritableComparable<BinaryComparable>, Serializable {
+
+  static final int INT_BYTE_LENGTH = 4;
+  static final int INT_PAIR_BYTE_LENGTH = 2 * INT_BYTE_LENGTH;
   private byte[] b = new byte[INT_PAIR_BYTE_LENGTH];
   
   public IntPairWritable() {
@@ -46,12 +49,12 @@
   
   public IntPairWritable(int x, int y) {
     putInt(x, b, 0);
-    putInt(y, b, 4);
+    putInt(y, b, INT_BYTE_LENGTH);
   }
   
   public void set(int x, int y) {
     putInt(x, b, 0);
-    putInt(y, b, 4);
+    putInt(y, b, INT_BYTE_LENGTH);
   }
   
   public void setFirst(int x) {
@@ -63,11 +66,11 @@
   }
   
   public void setSecond(int y) {
-    putInt(y, b, 4);
+    putInt(y, b, INT_BYTE_LENGTH);
   }
   
   public int getSecond() {
-    return getInt(b, 4);
+    return getInt(b, INT_BYTE_LENGTH);
   }
   
   @Override
@@ -82,22 +85,37 @@
   
   @Override
   public int hashCode() {
-    return 43 * Arrays.hashCode(b);
+    return Arrays.hashCode(b);
   }
   
   @Override
   public boolean equals(Object obj) {
-    if (this == obj) return true;
-    if (!super.equals(obj)) return false;
-    if (getClass() != obj.getClass()) return false;
+    if (this == obj) {
+      return true;
+    }
+    if (!super.equals(obj)) {
+      return false;
+    }
+    if (!(obj instanceof IntPairWritable)) {
+      return false;
+    }
     IntPairWritable other = (IntPairWritable) obj;
-    if (!Arrays.equals(b, other.b)) return false;
-    return true;
+    return Arrays.equals(b, other.b);
+  }
+
+  @Override
+  public int compareTo(BinaryComparable other) {
+    return Comparator.doCompare(b, 0, ((IntPairWritable) other).b, 0);
+  }
+
+  @Override
+  public Object clone() {
+    return new IntPairWritable(this);
   }
   
   @Override
   public String toString() {
-    return "(" + getFirst() + ", " + getSecond() + ")";
+    return "(" + getFirst() + ", " + getSecond() + ')';
   }
   
   @Override
@@ -111,31 +129,23 @@
   }
   
   private static void putInt(int value, byte[] b, int offset) {
-    if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
-      throw new IllegalArgumentException("offset+4 exceeds byte array length");
-    }
-    
-    for (int i = 0; i < 4; i++) {
-      b[offset + i] = (byte) (((value >>> ((3 - i) * 8)) & 0xFF) ^ 0x80);
+    for (int i = offset, j = 24; j >= 0; i++, j -= 8) {
+      b[i] = (byte) (value >>> j);
     }
   }
   
   private static int getInt(byte[] b, int offset) {
-    if (offset + 4 > INT_PAIR_BYTE_LENGTH) {
-      throw new IllegalArgumentException("offset+4 exceeds byte array length");
-    }
-    
     int value = 0;
-    for (int i = 0; i < 4; i++) {
-      value += ((b[i + offset] & 0xFF) ^ 0x80) << (3 - i) * 8;
+    for (int i = offset, j = 24; j >= 0; i++, j -= 8) {
+      value |= (b[i] & 0xFF) << j;
     }
     return value;
   }
-  
+
   static {
     WritableComparator.define(IntPairWritable.class, new Comparator());
   }
-  
+
   public static final class Comparator extends WritableComparator implements Serializable {
     public Comparator() {
       super(IntPairWritable.class);
@@ -143,10 +153,32 @@
     
     @Override
     public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
-      if (l1 != 8 || l2 != 8) {
-        throw new IllegalArgumentException();
+      return doCompare(b1, s1, b2, s2);
+    }
+
+    static int doCompare(byte[] b1, int s1, byte[] b2, int s2) {
+      int compare1 = compareInts(b1, s1, b2, s2);
+      if (compare1 != 0) {
+        return compare1;
+      }
+      return compareInts(b1, s1 + INT_BYTE_LENGTH, b2, s2 + INT_BYTE_LENGTH);
+    }
+
+    private static int compareInts(byte[] b1, int s1, byte[] b2, int s2) {
+      // Like WritableComparator.compareBytes(), but treats first byte as signed value
+      int end1 = s1 + INT_BYTE_LENGTH;
+      for (int i = s1, j = s2; i < end1; i++, j++) {
+        int a = b1[i];
+        int b = b2[j];
+        if (i > s1) {
+          a &= 0xff;
+          b &= 0xff;
+        }
+        if (a != b) {
+          return a - b;
+        }
       }
-      return WritableComparator.compareBytes(b1, s1, l1, b2, s2, l2);
+      return 0;
     }
   }
   
@@ -161,42 +193,45 @@
     
     @Override
     public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
-      int ret;
       int firstb1 = WritableComparator.readInt(b1, s1);
       int firstb2 = WritableComparator.readInt(b2, s2);
-      ret = firstb1 - firstb2;
-      return ret;
+      if (firstb1 < firstb2) {
+        return -1;
+      } else if (firstb1 > firstb2) {
+        return 1;
+      } else {
+        return 0;
+      }
     }
     
     @Override
     public int compare(Object o1, Object o2) {
-      if (o1 == null) {
+      int firstb1 = ((IntPairWritable) o1).getFirst();
+      int firstb2 = ((IntPairWritable) o2).getFirst();
+      if (firstb1 < firstb2) {
         return -1;
-      } else if (o2 == null) {
+      } else if (firstb1 > firstb2) {
         return 1;
-      } else {
-        int firstb1 = ((IntPairWritable) o1).getFirst();
-        int firstb2 = ((IntPairWritable) o2).getFirst();
-        return firstb1 - firstb2;
       }
+      return 0;
     }
     
   }
   
   /** A wrapper class that associates pairs with frequency (Occurences) */
-  public static class Frequency implements Comparable<Frequency> {
-    
-    private IntPairWritable pair = new IntPairWritable();
-    private double frequency = 0.0;
+  public static class Frequency implements Comparable<Frequency>, Serializable {
     
+    private final IntPairWritable pair;
+    private final double frequency;
+
     public double getFrequency() {
       return frequency;
     }
-    
+
     public IntPairWritable getPair() {
       return pair;
     }
-    
+
     public Frequency(IntPairWritable bigram, double frequency) {
       this.pair = new IntPairWritable(bigram);
       this.frequency = frequency;
@@ -204,21 +239,26 @@
     
     @Override
     public int hashCode() {
-      return pair.hashCode() + (int) Math.abs(Math.round(frequency * 31));
+      return pair.hashCode() + RandomUtils.hashDouble(frequency);
     }
     
     @Override
     public boolean equals(Object right) {
-      if ((right == null) || !(right instanceof Frequency)) {
+      if (!(right instanceof Frequency)) {
         return false;
       }
       Frequency that = (Frequency) right;
-      return this.compareTo(that) == 0;
+      return pair.equals(that.pair) && frequency == that.frequency;
     }
     
     @Override
     public int compareTo(Frequency that) {
-      return this.frequency > that.frequency ? 1 : -1;
+      if (frequency < that.frequency) {
+        return -1;
+      } else if (frequency > that.frequency) {
+        return 1;
+      }
+      return 0;
     }
     
     @Override

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java?rev=919388&r1=919387&r2=919388&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/common/IntPairWritableTest.java Fri Mar  5 12:23:32 2010
@@ -1,3 +1,20 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 package org.apache.mahout.common;
 
 import java.io.ByteArrayInputStream;
@@ -9,10 +26,8 @@
 
 import junit.framework.Assert;
 
-import org.apache.mahout.common.IntPairWritable;
 import org.junit.Test;
 
-
 public class IntPairWritableTest {
 
   @Test