You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/05/04 20:31:22 UTC
svn commit: r940992 -
/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java
Author: srowen
Date: Tue May 4 18:31:22 2010
New Revision: 940992
URL: http://svn.apache.org/viewvc?rev=940992&view=rev
Log:
Fix a possible loop-logic problem here, and do better by iterating over the observed count values actually seen in the user vector
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java?rev=940992&r1=940991&r2=940992&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorToCooccurrenceMapper.java Tue May 4 18:31:22 2010
@@ -18,6 +18,7 @@
package org.apache.mahout.cf.taste.hadoop.item;
import java.io.IOException;
+import java.util.Arrays;
import java.util.Iterator;
import org.apache.hadoop.io.IntWritable;
@@ -28,6 +29,7 @@ import org.apache.hadoop.mapred.OutputCo
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.list.IntArrayList;
import org.apache.mahout.math.map.OpenIntIntHashMap;
public final class UserVectorToCooccurrenceMapper extends MapReduceBase implements
@@ -79,21 +81,31 @@ public final class UserVectorToCooccurre
countCounts.adjustOrPutValue(count, 1, 1);
}
+ IntArrayList countsList = new IntArrayList(countCounts.size());
+ countCounts.keys(countsList);
+ int[] counts = countsList.elements();
+ Arrays.sort(counts);
+
int resultingSizeAtCutoff = 0;
- int cutoff = 0;
- while (resultingSizeAtCutoff <= MAX_PREFS_CONSIDERED) {
- cutoff++;
- int count = indexCounts.get(cutoff);
+ int cutoffIndex = 0;
+ while (cutoffIndex < counts.length && resultingSizeAtCutoff <= MAX_PREFS_CONSIDERED) {
+ int cutoff = counts[cutoffIndex];
+ cutoffIndex++;
+ int count = countCounts.get(cutoff);
resultingSizeAtCutoff += count;
}
+ cutoffIndex--;
- Iterator<Vector.Element> it2 = userVector.iterateNonZero();
- while (it2.hasNext()) {
- Vector.Element e = it2.next();
- int index = e.index();
- int count = indexCounts.get(index);
- if (count > cutoff) {
- e.set(0.0);
+ if (resultingSizeAtCutoff > MAX_PREFS_CONSIDERED) {
+ int cutoff = counts[cutoffIndex];
+ Iterator<Vector.Element> it2 = userVector.iterateNonZero();
+ while (it2.hasNext()) {
+ Vector.Element e = it2.next();
+ int index = e.index();
+ int count = indexCounts.get(index);
+ if (count >= cutoff) {
+ e.set(0.0);
+ }
}
}