You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/05/04 20:31:22 UTC

svn commit: r940992 - /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java

Author: srowen
Date: Tue May  4 18:31:22 2010
New Revision: 940992

URL: http://svn.apache.org/viewvc?rev=940992&view=rev
Log:
Fix a possible loop logic problem here, and improve the cutoff computation by iterating only over the count values actually observed in the user vector.

Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java?rev=940992&r1=940991&r2=940992&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java Tue May  4 18:31:22 2010
@@ -18,6 +18,7 @@
 package org.apache.mahout.cf.taste.hadoop.item;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.Iterator;
 
 import org.apache.hadoop.io.IntWritable;
@@ -28,6 +29,7 @@ import org.apache.hadoop.mapred.OutputCo
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.list.IntArrayList;
 import org.apache.mahout.math.map.OpenIntIntHashMap;
 
 public final class UserVectorToCooccurrenceMapper extends MapReduceBase implements
@@ -79,21 +81,31 @@ public final class UserVectorToCooccurre
       countCounts.adjustOrPutValue(count, 1, 1);
     }
 
+    IntArrayList countsList = new IntArrayList(countCounts.size());
+    countCounts.keys(countsList);
+    int[] counts = countsList.elements();
+    Arrays.sort(counts);
+
     int resultingSizeAtCutoff = 0;
-    int cutoff = 0;
-    while (resultingSizeAtCutoff <= MAX_PREFS_CONSIDERED) {
-      cutoff++;
-      int count = indexCounts.get(cutoff);
+    int cutoffIndex = 0;
+    while (cutoffIndex < counts.length && resultingSizeAtCutoff <= MAX_PREFS_CONSIDERED) {
+      int cutoff = counts[cutoffIndex];
+      cutoffIndex++;
+      int count = countCounts.get(cutoff);
       resultingSizeAtCutoff += count;
     }
+    cutoffIndex--;    
 
-    Iterator<Vector.Element> it2 = userVector.iterateNonZero();
-    while (it2.hasNext()) {
-      Vector.Element e = it2.next();
-      int index = e.index();
-      int count = indexCounts.get(index);
-      if (count > cutoff) {
-        e.set(0.0);
+    if (resultingSizeAtCutoff > MAX_PREFS_CONSIDERED) {
+      int cutoff = counts[cutoffIndex];
+      Iterator<Vector.Element> it2 = userVector.iterateNonZero();
+      while (it2.hasNext()) {
+        Vector.Element e = it2.next();
+        int index = e.index();
+        int count = indexCounts.get(index);
+        if (count >= cutoff) {
+          e.set(0.0);
+        }
       }
     }