You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ro...@apache.org on 2010/02/13 21:54:31 UTC
svn commit: r909912 [8/10] - in
/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste: common/
eval/ hadoop/ hadoop/cooccurence/ hadoop/item/ hadoop/pseudo/
hadoop/slopeone/ impl/common/ impl/common/jdbc/ impl/eval/ impl/model/
impl/model/...
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java Sat Feb 13 20:54:05 2010
@@ -17,6 +17,13 @@
package org.apache.mahout.cf.taste.impl.recommender.slopeone;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.concurrent.Callable;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.common.Weighting;
@@ -38,21 +45,16 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.concurrent.Callable;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
-
/**
- * <p>An implementation of {@link DiffStorage} that merely stores item-item diffs in memory. It is fast, but can consume
- * a great deal of memory.</p>
+ * <p>
+ * An implementation of {@link DiffStorage} that merely stores item-item diffs in memory. It is fast, but can
+ * consume a great deal of memory.
+ * </p>
*/
public final class MemoryDiffStorage implements DiffStorage {
-
+
private static final Logger log = LoggerFactory.getLogger(MemoryDiffStorage.class);
-
+
private final DataModel dataModel;
private final boolean stdDevWeighted;
private final boolean compactAverages;
@@ -62,28 +64,36 @@
private final FastIDSet allRecommendableItemIDs;
private final ReadWriteLock buildAverageDiffsLock;
private final RefreshHelper refreshHelper;
-
+
/**
- * <p>Creates a new {@link MemoryDiffStorage}.</p>
- *
- * <p>See {@link org.apache.mahout.cf.taste.impl.recommender.slopeone.SlopeOneRecommender} for the meaning of
- * <code>stdDevWeighted</code>. If <code>compactAverages</code> is set, this uses alternate data structures ({@link
- * CompactRunningAverage} versus {@link FullRunningAverage}) that use almost 50% less memory but store item-item
- * averages less accurately. <code>maxEntries</code> controls the maximum number of item-item average preference
- * differences that will be tracked internally. After the limit is reached, if a new item-item pair is observed in the
- * data it will be ignored. This is recommended for large datasets. The first <code>maxEntries</code> item-item pairs
- * observed in the data are tracked. Assuming that item ratings are reasonably distributed among users, this should
- * only ignore item-item pairs that are very infrequently co-rated by a user. The intuition is that data on these
- * infrequently co-rated item-item pairs is less reliable and should be the first that is ignored. This parameter can
- * be used to limit the memory requirements of {@link SlopeOneRecommender}, which otherwise grow as the square of the
- * number of items that exist in the {@link DataModel}. Memory requirements can reach gigabytes with only about 10000
+ * <p>
+ * Creates a new {@link MemoryDiffStorage}.
+ * </p>
+ *
+ * <p>
+ * See {@link org.apache.mahout.cf.taste.impl.recommender.slopeone.SlopeOneRecommender} for the meaning of
+ * <code>stdDevWeighted</code>. If <code>compactAverages</code> is set, this uses alternate data structures
+ * ({@link CompactRunningAverage} versus {@link FullRunningAverage}) that use almost 50% less memory but
+ * store item-item averages less accurately. <code>maxEntries</code> controls the maximum number of
+ * item-item average preference differences that will be tracked internally. After the limit is reached, if
+ * a new item-item pair is observed in the data it will be ignored. This is recommended for large datasets.
+ * The first <code>maxEntries</code> item-item pairs observed in the data are tracked. Assuming that item
+ * ratings are reasonably distributed among users, this should only ignore item-item pairs that are very
+ * infrequently co-rated by a user. The intuition is that data on these infrequently co-rated item-item
+ * pairs is less reliable and should be the first that is ignored. This parameter can be used to limit the
+ * memory requirements of {@link SlopeOneRecommender}, which otherwise grow as the square of the number of
+ * items that exist in the {@link DataModel}. Memory requirements can reach gigabytes with only about 10000
* items, so this may be necessary on larger datasets.
- *
- * @param stdDevWeighted see {@link org.apache.mahout.cf.taste.impl.recommender.slopeone.SlopeOneRecommender}
- * @param compactAverages if <code>true</code>, use {@link CompactRunningAverage} instead of {@link
- * FullRunningAverage} internally
- * @param maxEntries maximum number of item-item average preference differences to track internally
- * @throws IllegalArgumentException if <code>maxEntries</code> is not positive or <code>dataModel</code> is null
+ *
+ * @param stdDevWeighted
+ * see {@link org.apache.mahout.cf.taste.impl.recommender.slopeone.SlopeOneRecommender}
+ * @param compactAverages
+ * if <code>true</code>, use {@link CompactRunningAverage} instead of {@link FullRunningAverage}
+ * internally
+ * @param maxEntries
+ * maximum number of item-item average preference differences to track internally
+ * @throws IllegalArgumentException
+ * if <code>maxEntries</code> is not positive or <code>dataModel</code> is null
*/
public MemoryDiffStorage(DataModel dataModel,
Weighting stdDevWeighted,
@@ -116,10 +126,10 @@
refreshHelper.addDependency(dataModel);
buildAverageDiffs();
}
-
+
@Override
public RunningAverage getDiff(long itemID1, long itemID2) {
-
+
boolean inverted = false;
if (itemID1 > itemID2) {
inverted = true;
@@ -127,7 +137,7 @@
itemID1 = itemID2;
itemID2 = temp;
}
-
+
FastByIDMap<RunningAverage> level2Map;
try {
buildAverageDiffsLock.readLock().lock();
@@ -143,14 +153,13 @@
if (average == null) {
return null;
}
- return stdDevWeighted ?
- new InvertedRunningAverageAndStdDev((RunningAverageAndStdDev) average) :
- new InvertedRunningAverage(average);
+ return stdDevWeighted ? new InvertedRunningAverageAndStdDev((RunningAverageAndStdDev) average)
+ : new InvertedRunningAverage(average);
} else {
return average;
}
}
-
+
@Override
public RunningAverage[] getDiffs(long userID, long itemID, PreferenceArray prefs) {
try {
@@ -165,12 +174,12 @@
buildAverageDiffsLock.readLock().unlock();
}
}
-
+
@Override
public RunningAverage getAverageItemPref(long itemID) {
return averageItemPref.get(itemID);
}
-
+
@Override
public void updateItemPref(long itemID, float prefDelta, boolean remove) {
if (!remove && stdDevWeighted) {
@@ -178,9 +187,9 @@
}
try {
buildAverageDiffsLock.readLock().lock();
- for (Map.Entry<Long, FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
+ for (Map.Entry<Long,FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
boolean matchesItemID1 = itemID == entry.getKey();
- for (Map.Entry<Long, RunningAverage> entry2 : entry.getValue().entrySet()) {
+ for (Map.Entry<Long,RunningAverage> entry2 : entry.getValue().entrySet()) {
RunningAverage average = entry2.getValue();
if (matchesItemID1) {
if (remove) {
@@ -205,7 +214,7 @@
buildAverageDiffsLock.readLock().unlock();
}
}
-
+
@Override
public FastIDSet getRecommendableItemIDs(long userID) throws TasteException {
FastIDSet result;
@@ -223,9 +232,9 @@
}
return result;
}
-
+
private void buildAverageDiffs() throws TasteException {
- log.info("Building average diffs...");
+ MemoryDiffStorage.log.info("Building average diffs...");
try {
buildAverageDiffsLock.writeLock().lock();
averageDiffs.clear();
@@ -234,22 +243,22 @@
while (it.hasNext()) {
averageCount = processOneUser(averageCount, it.nextLong());
}
-
+
pruneInconsequentialDiffs();
updateAllRecommendableItems();
-
+
} finally {
buildAverageDiffsLock.writeLock().unlock();
}
}
-
+
private void pruneInconsequentialDiffs() {
// Go back and prune inconsequential diffs. "Inconsequential" means, here, only represented by one
// data point, so possibly unreliable
- Iterator<Map.Entry<Long, FastByIDMap<RunningAverage>>> it1 = averageDiffs.entrySet().iterator();
+ Iterator<Map.Entry<Long,FastByIDMap<RunningAverage>>> it1 = averageDiffs.entrySet().iterator();
while (it1.hasNext()) {
FastByIDMap<RunningAverage> map = it1.next().getValue();
- Iterator<Map.Entry<Long, RunningAverage>> it2 = map.entrySet().iterator();
+ Iterator<Map.Entry<Long,RunningAverage>> it2 = map.entrySet().iterator();
while (it2.hasNext()) {
RunningAverage average = it2.next().getValue();
if (average.getCount() <= 1) {
@@ -264,10 +273,10 @@
}
averageDiffs.rehash();
}
-
+
private void updateAllRecommendableItems() throws TasteException {
FastIDSet ids = new FastIDSet(dataModel.getNumItems());
- for (Map.Entry<Long, FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
+ for (Map.Entry<Long,FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
ids.add(entry.getKey());
LongPrimitiveIterator it = entry.getValue().keySetIterator();
while (it.hasNext()) {
@@ -278,9 +287,9 @@
allRecommendableItemIDs.addAll(ids);
allRecommendableItemIDs.rehash();
}
-
+
private long processOneUser(long averageCount, long userID) throws TasteException {
- log.debug("Processing prefs for user {}", userID);
+ MemoryDiffStorage.log.debug("Processing prefs for user {}", userID);
// Save off prefs for the life of this loop iteration
PreferenceArray userPreferences = dataModel.getPreferencesFromUser(userID);
int length = userPreferences.length();
@@ -296,7 +305,7 @@
// This is a performance-critical block
long itemIDB = userPreferences.getItemID(j);
RunningAverage average = aMap.get(itemIDB);
- if (average == null && averageCount < maxEntries) {
+ if ((average == null) && (averageCount < maxEntries)) {
average = buildRunningAverage();
aMap.put(itemIDB, average);
averageCount++;
@@ -304,7 +313,7 @@
if (average != null) {
average.addDatum(userPreferences.getValue(j) - prefAValue);
}
-
+
}
RunningAverage itemAverage = averageItemPref.get(itemIDA);
if (itemAverage == null) {
@@ -315,7 +324,7 @@
}
return averageCount;
}
-
+
private RunningAverage buildRunningAverage() {
if (stdDevWeighted) {
return compactAverages ? new CompactRunningAverageAndStdDev() : new FullRunningAverageAndStdDev();
@@ -323,15 +332,15 @@
return compactAverages ? new CompactRunningAverage() : new FullRunningAverage();
}
}
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
refreshHelper.refresh(alreadyRefreshed);
}
-
+
@Override
public String toString() {
return "MemoryDiffStorage";
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java Sat Feb 13 20:54:05 2010
@@ -17,6 +17,9 @@
package org.apache.mahout.cf.taste.impl.recommender.slopeone;
+import java.util.Collection;
+import java.util.List;
+
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
@@ -35,55 +38,63 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.Collection;
-import java.util.List;
-
/**
- * <p>A basic "slope one" recommender. (See an <a href="http://www.daniel-lemire.com/fr/abstracts/SDM2005.html">
+ * <p>
+ * A basic "slope one" recommender. (See an <a href="http://www.daniel-lemire.com/fr/abstracts/SDM2005.html">
* excellent summary here</a> for example.) This {@link org.apache.mahout.cf.taste.recommender.Recommender} is
- * especially suitable when user preferences are updating frequently as it can incorporate this information without
- * expensive recomputation.</p>
- *
- * <p>This implementation can also be used as a "weighted slope one" recommender.</p>
+ * especially suitable when user preferences are updating frequently as it can incorporate this information
+ * without expensive recomputation.
+ * </p>
+ *
+ * <p>
+ * This implementation can also be used as a "weighted slope one" recommender.
+ * </p>
*/
public final class SlopeOneRecommender extends AbstractRecommender {
-
+
private static final Logger log = LoggerFactory.getLogger(SlopeOneRecommender.class);
-
+
private final boolean weighted;
private final boolean stdDevWeighted;
private final DiffStorage diffStorage;
-
+
/**
- * <p>Creates a default (weighted) {@link SlopeOneRecommender} based on the given {@link DataModel}.</p>
- *
- * @param dataModel data model
+ * <p>
+ * Creates a default (weighted) {@link SlopeOneRecommender} based on the given {@link DataModel}.
+ * </p>
+ *
+ * @param dataModel
+ * data model
*/
public SlopeOneRecommender(DataModel dataModel) throws TasteException {
- this(dataModel,
- Weighting.WEIGHTED,
- Weighting.WEIGHTED,
- new MemoryDiffStorage(dataModel, Weighting.WEIGHTED, false, Long.MAX_VALUE));
+ this(dataModel, Weighting.WEIGHTED, Weighting.WEIGHTED, new MemoryDiffStorage(dataModel,
+ Weighting.WEIGHTED, false, Long.MAX_VALUE));
}
-
+
/**
- * <p>Creates a {@link SlopeOneRecommender} based on the given {@link DataModel}.</p>
- *
- * <p>If <code>weighted</code> is set, acts as a weighted slope one recommender. This implementation also includes an
- * experimental "standard deviation" weighting which weights item-item ratings diffs with lower standard deviation
- * more highly, on the theory that they are more reliable.</p>
- *
- * @param weighting if {@link Weighting#WEIGHTED}, acts as a weighted slope one recommender
- * @param stdDevWeighting use optional standard deviation weighting of diffs
- * @throws IllegalArgumentException if <code>diffStorage</code> is null, or stdDevWeighted is set when weighted is not
- * set
+ * <p>
+ * Creates a {@link SlopeOneRecommender} based on the given {@link DataModel}.
+ * </p>
+ *
+ * <p>
+ * If <code>weighted</code> is set, acts as a weighted slope one recommender. This implementation also
+ * includes an experimental "standard deviation" weighting which weights item-item ratings diffs with lower
+ * standard deviation more highly, on the theory that they are more reliable.
+ * </p>
+ *
+ * @param weighting
+ * if {@link Weighting#WEIGHTED}, acts as a weighted slope one recommender
+ * @param stdDevWeighting
+ * use optional standard deviation weighting of diffs
+ * @throws IllegalArgumentException
+ * if <code>diffStorage</code> is null, or stdDevWeighted is set when weighted is not set
*/
public SlopeOneRecommender(DataModel dataModel,
Weighting weighting,
Weighting stdDevWeighting,
DiffStorage diffStorage) {
super(dataModel);
- if (stdDevWeighting == Weighting.WEIGHTED && weighting == Weighting.UNWEIGHTED) {
+ if ((stdDevWeighting == Weighting.WEIGHTED) && (weighting == Weighting.UNWEIGHTED)) {
throw new IllegalArgumentException("weighted required when stdDevWeighted is set");
}
if (diffStorage == null) {
@@ -93,26 +104,26 @@
this.stdDevWeighted = stdDevWeighting == Weighting.WEIGHTED;
this.diffStorage = diffStorage;
}
-
+
@Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer)
- throws TasteException {
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
if (howMany < 1) {
throw new IllegalArgumentException("howMany must be at least 1");
}
-
- log.debug("Recommending items for user ID '{}'", userID);
-
+
+ SlopeOneRecommender.log.debug("Recommending items for user ID '{}'", userID);
+
FastIDSet possibleItemIDs = diffStorage.getRecommendableItemIDs(userID);
-
+
TopItems.Estimator<Long> estimator = new Estimator(userID);
-
- List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer, estimator);
-
- log.debug("Recommendations are: {}", topItems);
+
+ List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
+ estimator);
+
+ SlopeOneRecommender.log.debug("Recommendations are: {}", topItems);
return topItems;
}
-
+
@Override
public float estimatePreference(long userID, long itemID) throws TasteException {
DataModel model = getDataModel();
@@ -122,7 +133,7 @@
}
return doEstimatePreference(userID, itemID);
}
-
+
private float doEstimatePreference(long userID, long itemID) throws TasteException {
double count = 0.0;
double totalPreference = 0.0;
@@ -134,7 +145,7 @@
if (averageDiff != null) {
double averageDiffValue = averageDiff.getAverage();
if (weighted) {
- double weight = (double) averageDiff.getCount();
+ double weight = averageDiff.getCount();
if (stdDevWeighted) {
double stdev = ((RunningAverageAndStdDev) averageDiff).getStandardDeviation();
if (!Double.isNaN(stdev)) {
@@ -161,7 +172,7 @@
return (float) (totalPreference / count);
}
}
-
+
@Override
public void setPreference(long userID, long itemID, float value) throws TasteException {
DataModel dataModel = getDataModel();
@@ -175,7 +186,7 @@
super.setPreference(userID, itemID, value);
diffStorage.updateItemPref(itemID, prefDelta, false);
}
-
+
@Override
public void removePreference(long userID, long itemID) throws TasteException {
DataModel dataModel = getDataModel();
@@ -185,31 +196,31 @@
diffStorage.updateItemPref(itemID, oldPref, true);
}
}
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
RefreshHelper.maybeRefresh(alreadyRefreshed, diffStorage);
}
-
+
@Override
public String toString() {
- return "SlopeOneRecommender[weighted:" + weighted + ", stdDevWeighted:" + stdDevWeighted +
- ", diffStorage:" + diffStorage + ']';
+ return "SlopeOneRecommender[weighted:" + weighted + ", stdDevWeighted:" + stdDevWeighted
+ + ", diffStorage:" + diffStorage + ']';
}
-
+
private final class Estimator implements TopItems.Estimator<Long> {
-
+
private final long userID;
-
+
private Estimator(long userID) {
this.userID = userID;
}
-
+
@Override
public double estimate(Long itemID) throws TasteException {
return doEstimatePreference(userID, itemID);
}
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/file/FileDiffStorage.java Sat Feb 13 20:54:05 2010
@@ -17,6 +17,15 @@
package org.apache.mahout.cf.taste.impl.recommender.slopeone.file;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
@@ -31,32 +40,26 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.File;
-import java.io.FileNotFoundException;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
-
/**
- * <p>{@link DiffStorage} which reads pre-computed diffs from a file and stores
- * in memory. The file should have one diff per line:</p>
- *
+ * <p>
+ * {@link DiffStorage} which reads pre-computed diffs from a file and stores in memory. The file should have
+ * one diff per line:
+ * </p>
+ *
* {@code itemID1,itemID2,diff}
- *
- * <p>Commas or tabs can be delimiters. This is intended for use in conjuction
- * with the output of
- * {@link org.apache.mahout.cf.taste.hadoop.slopeone.SlopeOneAverageDiffsJob}.</p>
+ *
+ * <p>
+ * Commas or tabs can be delimiters. This is intended for use in conjunction with the output of
+ * {@link org.apache.mahout.cf.taste.hadoop.slopeone.SlopeOneAverageDiffsJob}.
+ * </p>
*/
public final class FileDiffStorage implements DiffStorage {
-
+
private static final Logger log = LoggerFactory.getLogger(FileDiffStorage.class);
-
+
private static final long MIN_RELOAD_INTERVAL_MS = 60 * 1000L; // 1 minute?
private static final char COMMENT_CHAR = '#';
-
+
private final File dataFile;
private long lastModified;
private boolean loaded;
@@ -64,11 +67,14 @@
private final FastByIDMap<FastByIDMap<RunningAverage>> averageDiffs;
private final FastIDSet allRecommendableItemIDs;
private final ReadWriteLock buildAverageDiffsLock;
-
+
/**
- * @param dataFile diffs file
- * @param maxEntries maximum number of diffs to store
- * @throws FileNotFoundException if data file does not exist or is a directory
+ * @param dataFile
+ * diffs file
+ * @param maxEntries
+ * maximum number of diffs to store
+ * @throws FileNotFoundException
+ * if data file does not exist or is a directory
*/
public FileDiffStorage(File dataFile, long maxEntries) throws FileNotFoundException {
if (dataFile == null) {
@@ -80,9 +86,9 @@
if (maxEntries <= 0L) {
throw new IllegalArgumentException("maxEntries must be positive");
}
-
- log.info("Creating FileDataModel for file {}", dataFile);
-
+
+ FileDiffStorage.log.info("Creating FileDataModel for file {}", dataFile);
+
this.dataFile = dataFile.getAbsoluteFile();
this.lastModified = dataFile.lastModified();
this.maxEntries = maxEntries;
@@ -90,17 +96,17 @@
this.allRecommendableItemIDs = new FastIDSet();
this.buildAverageDiffsLock = new ReentrantReadWriteLock();
}
-
+
private void buildDiffs() {
if (buildAverageDiffsLock.writeLock().tryLock()) {
try {
-
+
averageDiffs.clear();
allRecommendableItemIDs.clear();
-
+
FileLineIterator iterator = new FileLineIterator(dataFile, false);
String firstLine = iterator.peek();
- while (firstLine.length() == 0 || firstLine.charAt(0) == COMMENT_CHAR) {
+ while ((firstLine.length() == 0) || (firstLine.charAt(0) == FileDiffStorage.COMMENT_CHAR)) {
iterator.next();
firstLine = iterator.peek();
}
@@ -109,50 +115,50 @@
while (iterator.hasNext()) {
averageCount = processLine(iterator.next(), delimiter, averageCount);
}
-
+
pruneInconsequentialDiffs();
updateAllRecommendableItems();
-
+
} catch (IOException ioe) {
- log.warn("Exception while reloading", ioe);
+ FileDiffStorage.log.warn("Exception while reloading", ioe);
} finally {
buildAverageDiffsLock.writeLock().unlock();
}
}
}
-
+
private long processLine(String line, char delimiter, long averageCount) {
-
- if (line.length() == 0 || line.charAt(0) == COMMENT_CHAR) {
+
+ if ((line.length() == 0) || (line.charAt(0) == FileDiffStorage.COMMENT_CHAR)) {
return averageCount;
}
-
- int delimiterOne = line.indexOf((int) delimiter);
+
+ int delimiterOne = line.indexOf(delimiter);
if (delimiterOne < 0) {
throw new IllegalArgumentException("Bad line: " + line);
}
- int delimiterTwo = line.indexOf((int) delimiter, delimiterOne + 1);
+ int delimiterTwo = line.indexOf(delimiter, delimiterOne + 1);
if (delimiterTwo < 0) {
throw new IllegalArgumentException("Bad line: " + line);
}
-
+
long itemID1 = Long.parseLong(line.substring(0, delimiterOne));
long itemID2 = Long.parseLong(line.substring(delimiterOne + 1, delimiterTwo));
double diff = Double.parseDouble(line.substring(delimiterTwo + 1));
-
+
if (itemID1 > itemID2) {
long temp = itemID1;
itemID1 = itemID2;
itemID2 = temp;
}
-
+
FastByIDMap<RunningAverage> level1Map = averageDiffs.get(itemID1);
if (level1Map == null) {
level1Map = new FastByIDMap<RunningAverage>();
averageDiffs.put(itemID1, level1Map);
}
RunningAverage average = level1Map.get(itemID2);
- if (average == null && averageCount < maxEntries) {
+ if ((average == null) && (averageCount < maxEntries)) {
average = new FullRunningAverage();
level1Map.put(itemID2, average);
averageCount++;
@@ -160,20 +166,20 @@
if (average != null) {
average.addDatum(diff);
}
-
+
allRecommendableItemIDs.add(itemID1);
allRecommendableItemIDs.add(itemID2);
-
+
return averageCount;
}
-
+
private void pruneInconsequentialDiffs() {
// Go back and prune inconsequential diffs. "Inconsequential" means, here, only represented by one
// data point, so possibly unreliable
- Iterator<Map.Entry<Long, FastByIDMap<RunningAverage>>> it1 = averageDiffs.entrySet().iterator();
+ Iterator<Map.Entry<Long,FastByIDMap<RunningAverage>>> it1 = averageDiffs.entrySet().iterator();
while (it1.hasNext()) {
FastByIDMap<RunningAverage> map = it1.next().getValue();
- Iterator<Map.Entry<Long, RunningAverage>> it2 = map.entrySet().iterator();
+ Iterator<Map.Entry<Long,RunningAverage>> it2 = map.entrySet().iterator();
while (it2.hasNext()) {
RunningAverage average = it2.next().getValue();
if (average.getCount() <= 1) {
@@ -188,9 +194,9 @@
}
averageDiffs.rehash();
}
-
+
private void updateAllRecommendableItems() {
- for (Map.Entry<Long, FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
+ for (Map.Entry<Long,FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
allRecommendableItemIDs.add(entry.getKey());
LongPrimitiveIterator it = entry.getValue().keySetIterator();
while (it.hasNext()) {
@@ -199,18 +205,18 @@
}
allRecommendableItemIDs.rehash();
}
-
+
private void checkLoaded() {
if (!loaded) {
buildDiffs();
loaded = true;
}
}
-
+
@Override
public RunningAverage getDiff(long itemID1, long itemID2) {
checkLoaded();
-
+
boolean inverted = false;
if (itemID1 > itemID2) {
inverted = true;
@@ -218,7 +224,7 @@
itemID1 = itemID2;
itemID2 = temp;
}
-
+
FastByIDMap<RunningAverage> level2Map;
try {
buildAverageDiffsLock.readLock().lock();
@@ -239,7 +245,7 @@
return average;
}
}
-
+
@Override
public RunningAverage[] getDiffs(long userID, long itemID, PreferenceArray prefs) {
checkLoaded();
@@ -255,21 +261,21 @@
buildAverageDiffsLock.readLock().unlock();
}
}
-
+
@Override
public RunningAverage getAverageItemPref(long itemID) {
checkLoaded();
return null; // TODO can't do this without a DataModel
}
-
+
@Override
public void updateItemPref(long itemID, float prefDelta, boolean remove) {
checkLoaded();
try {
buildAverageDiffsLock.readLock().lock();
- for (Map.Entry<Long, FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
+ for (Map.Entry<Long,FastByIDMap<RunningAverage>> entry : averageDiffs.entrySet()) {
boolean matchesItemID1 = itemID == entry.getKey();
- for (Map.Entry<Long, RunningAverage> entry2 : entry.getValue().entrySet()) {
+ for (Map.Entry<Long,RunningAverage> entry2 : entry.getValue().entrySet()) {
RunningAverage average = entry2.getValue();
if (matchesItemID1) {
if (remove) {
@@ -286,15 +292,15 @@
}
}
}
- //RunningAverage itemAverage = averageItemPref.get(itemID);
- //if (itemAverage != null) {
- // itemAverage.changeDatum(prefDelta);
- //}
+ // RunningAverage itemAverage = averageItemPref.get(itemID);
+ // if (itemAverage != null) {
+ // itemAverage.changeDatum(prefDelta);
+ // }
} finally {
buildAverageDiffsLock.readLock().unlock();
}
}
-
+
@Override
public FastIDSet getRecommendableItemIDs(long userID) {
checkLoaded();
@@ -305,15 +311,15 @@
buildAverageDiffsLock.readLock().unlock();
}
}
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
long mostRecentModification = dataFile.lastModified();
- if (mostRecentModification > lastModified + MIN_RELOAD_INTERVAL_MS) {
- log.debug("File has changed; reloading...");
+ if (mostRecentModification > lastModified + FileDiffStorage.MIN_RELOAD_INTERVAL_MS) {
+ FileDiffStorage.log.debug("File has changed; reloading...");
lastModified = mostRecentModification;
buildDiffs();
}
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java Sat Feb 13 20:54:05 2010
@@ -17,44 +17,47 @@
package org.apache.mahout.cf.taste.impl.recommender.slopeone.jdbc;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
+import javax.sql.DataSource;
+
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-import org.apache.mahout.common.IOUtils;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent;
import org.apache.mahout.cf.taste.model.JDBCDataModel;
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.cf.taste.recommender.slopeone.DiffStorage;
+import org.apache.mahout.common.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import javax.sql.DataSource;
-import java.sql.Connection;
-import java.sql.PreparedStatement;
-import java.sql.ResultSet;
-import java.sql.SQLException;
-import java.sql.Statement;
-import java.util.Collection;
-import java.util.concurrent.Callable;
-
/**
- * <p>A {@link DiffStorage} which stores diffs in a database. Database-specific implementations subclass this abstract
- * class. Note that this implementation has a fairly particular dependence on the {@link
- * org.apache.mahout.cf.taste.model.DataModel} used; it needs a {@link JDBCDataModel} attached to the same database
- * since its efficent operation depends on accessing preference data in the database directly.</p>
+ * <p>
+ * A {@link DiffStorage} which stores diffs in a database. Database-specific implementations subclass this
+ * abstract class. Note that this implementation has a fairly particular dependence on the
+ * {@link org.apache.mahout.cf.taste.model.DataModel} used; it needs a {@link JDBCDataModel} attached to the
+ * same database since its efficient operation depends on accessing preference data in the database directly.
+ * </p>
*/
public abstract class AbstractJDBCDiffStorage extends AbstractJDBCComponent implements DiffStorage {
-
+
private static final Logger log = LoggerFactory.getLogger(AbstractJDBCDiffStorage.class);
-
+
public static final String DEFAULT_DIFF_TABLE = "taste_slopeone_diffs";
public static final String DEFAULT_ITEM_A_COLUMN = "item_id_a";
public static final String DEFAULT_ITEM_B_COLUMN = "item_id_b";
public static final String DEFAULT_COUNT_COLUMN = "count";
public static final String DEFAULT_AVERAGE_DIFF_COLUMN = "average_diff";
-
+
private final DataSource dataSource;
private final String getDiffSQL;
private final String getDiffsSQL;
@@ -67,7 +70,7 @@
private final String diffsExistSQL;
private final int minDiffCount;
private final RefreshHelper refreshHelper;
-
+
protected AbstractJDBCDiffStorage(JDBCDataModel dataModel,
String getDiffSQL,
String getDiffsSQL,
@@ -79,18 +82,18 @@
String createDiffsSQL,
String diffsExistSQL,
int minDiffCount) throws TasteException {
-
- checkNotNullAndLog("dataModel", dataModel);
- checkNotNullAndLog("getDiffSQL", getDiffSQL);
- checkNotNullAndLog("getDiffsSQL", getDiffsSQL);
- checkNotNullAndLog("getAverageItemPrefSQL", getAverageItemPrefSQL);
- checkNotNullAndLog("updateDiffSQLs", updateDiffSQLs);
- checkNotNullAndLog("removeDiffSQLs", removeDiffSQLs);
- checkNotNullAndLog("getRecommendableItemsSQL", getRecommendableItemsSQL);
- checkNotNullAndLog("deleteDiffsSQL", deleteDiffsSQL);
- checkNotNullAndLog("createDiffsSQL", createDiffsSQL);
- checkNotNullAndLog("diffsExistSQL", diffsExistSQL);
-
+
+ AbstractJDBCComponent.checkNotNullAndLog("dataModel", dataModel);
+ AbstractJDBCComponent.checkNotNullAndLog("getDiffSQL", getDiffSQL);
+ AbstractJDBCComponent.checkNotNullAndLog("getDiffsSQL", getDiffsSQL);
+ AbstractJDBCComponent.checkNotNullAndLog("getAverageItemPrefSQL", getAverageItemPrefSQL);
+ AbstractJDBCComponent.checkNotNullAndLog("updateDiffSQLs", updateDiffSQLs);
+ AbstractJDBCComponent.checkNotNullAndLog("removeDiffSQLs", removeDiffSQLs);
+ AbstractJDBCComponent.checkNotNullAndLog("getRecommendableItemsSQL", getRecommendableItemsSQL);
+ AbstractJDBCComponent.checkNotNullAndLog("deleteDiffsSQL", deleteDiffsSQL);
+ AbstractJDBCComponent.checkNotNullAndLog("createDiffsSQL", createDiffsSQL);
+ AbstractJDBCComponent.checkNotNullAndLog("diffsExistSQL", diffsExistSQL);
+
if (minDiffCount < 0) {
throw new IllegalArgumentException("minDiffCount is not positive");
}
@@ -114,13 +117,13 @@
});
refreshHelper.addDependency(dataModel);
if (isDiffsExist()) {
- log.info("Diffs already exist in database; using them instead of recomputing");
+ AbstractJDBCDiffStorage.log.info("Diffs already exist in database; using them instead of recomputing");
} else {
- log.info("No diffs exist in database; recomputing...");
+ AbstractJDBCDiffStorage.log.info("No diffs exist in database; recomputing...");
buildAverageDiffs();
}
}
-
+
@Override
public RunningAverage getDiff(long itemID1, long itemID2) throws TasteException {
Connection conn = null;
@@ -135,20 +138,19 @@
stmt.setLong(2, itemID2);
stmt.setLong(3, itemID2);
stmt.setLong(4, itemID1);
- log.debug("Executing SQL query: {}", getDiffSQL);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL query: {}", getDiffSQL);
rs = stmt.executeQuery();
return rs.next() ? new FixedRunningAverage(rs.getInt(1), rs.getDouble(2)) : null;
} catch (SQLException sqle) {
- log.warn("Exception while retrieving diff", sqle);
+ AbstractJDBCDiffStorage.log.warn("Exception while retrieving diff", sqle);
throw new TasteException(sqle);
} finally {
IOUtils.quietClose(rs, stmt, conn);
}
}
-
+
@Override
- public RunningAverage[] getDiffs(long userID, long itemID, PreferenceArray prefs)
- throws TasteException {
+ public RunningAverage[] getDiffs(long userID, long itemID, PreferenceArray prefs) throws TasteException {
int size = prefs.length();
RunningAverage[] result = new RunningAverage[size];
Connection conn = null;
@@ -161,7 +163,7 @@
stmt.setFetchSize(getFetchSize());
stmt.setLong(1, itemID);
stmt.setLong(2, userID);
- log.debug("Executing SQL query: {}", getDiffsSQL);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL query: {}", getDiffsSQL);
rs = stmt.executeQuery();
// We should have up to one result for each Preference in prefs
// They are both ordered by item. Step through and create a RunningAverage[]
@@ -177,14 +179,14 @@
i++;
}
} catch (SQLException sqle) {
- log.warn("Exception while retrieving diff", sqle);
+ AbstractJDBCDiffStorage.log.warn("Exception while retrieving diff", sqle);
throw new TasteException(sqle);
} finally {
IOUtils.quietClose(rs, stmt, conn);
}
return result;
}
-
+
@Override
public RunningAverage getAverageItemPref(long itemID) throws TasteException {
Connection conn = null;
@@ -192,11 +194,12 @@
ResultSet rs = null;
try {
conn = dataSource.getConnection();
- stmt = conn.prepareStatement(getAverageItemPrefSQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+ stmt = conn.prepareStatement(getAverageItemPrefSQL, ResultSet.TYPE_FORWARD_ONLY,
+ ResultSet.CONCUR_READ_ONLY);
stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
stmt.setFetchSize(getFetchSize());
stmt.setLong(1, itemID);
- log.debug("Executing SQL query: {}", getAverageItemPrefSQL);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL query: {}", getAverageItemPrefSQL);
rs = stmt.executeQuery();
if (rs.next()) {
int count = rs.getInt(1);
@@ -206,47 +209,45 @@
}
return null;
} catch (SQLException sqle) {
- log.warn("Exception while retrieving average item pref", sqle);
+ AbstractJDBCDiffStorage.log.warn("Exception while retrieving average item pref", sqle);
throw new TasteException(sqle);
} finally {
IOUtils.quietClose(rs, stmt, conn);
}
}
-
+
@Override
- public void updateItemPref(long itemID, float prefDelta, boolean remove)
- throws TasteException {
+ public void updateItemPref(long itemID, float prefDelta, boolean remove) throws TasteException {
Connection conn = null;
try {
conn = dataSource.getConnection();
if (remove) {
- doPartialUpdate(removeDiffSQLs[0], itemID, prefDelta, conn);
- doPartialUpdate(removeDiffSQLs[1], itemID, prefDelta, conn);
+ AbstractJDBCDiffStorage.doPartialUpdate(removeDiffSQLs[0], itemID, prefDelta, conn);
+ AbstractJDBCDiffStorage.doPartialUpdate(removeDiffSQLs[1], itemID, prefDelta, conn);
} else {
- doPartialUpdate(updateDiffSQLs[0], itemID, prefDelta, conn);
- doPartialUpdate(updateDiffSQLs[1], itemID, prefDelta, conn);
+ AbstractJDBCDiffStorage.doPartialUpdate(updateDiffSQLs[0], itemID, prefDelta, conn);
+ AbstractJDBCDiffStorage.doPartialUpdate(updateDiffSQLs[1], itemID, prefDelta, conn);
}
} catch (SQLException sqle) {
- log.warn("Exception while updating item diff", sqle);
+ AbstractJDBCDiffStorage.log.warn("Exception while updating item diff", sqle);
throw new TasteException(sqle);
} finally {
IOUtils.quietClose(conn);
}
}
-
- private static void doPartialUpdate(String sql, long itemID, double prefDelta, Connection conn)
- throws SQLException {
+
+ private static void doPartialUpdate(String sql, long itemID, double prefDelta, Connection conn) throws SQLException {
PreparedStatement stmt = conn.prepareStatement(sql);
try {
stmt.setDouble(1, prefDelta);
stmt.setLong(2, itemID);
- log.debug("Executing SQL update: {}", sql);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL update: {}", sql);
stmt.executeUpdate();
} finally {
IOUtils.quietClose(stmt);
}
}
-
+
@Override
public FastIDSet getRecommendableItemIDs(long userID) throws TasteException {
Connection conn = null;
@@ -254,13 +255,14 @@
ResultSet rs = null;
try {
conn = dataSource.getConnection();
- stmt = conn.prepareStatement(getRecommendableItemsSQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+ stmt = conn.prepareStatement(getRecommendableItemsSQL, ResultSet.TYPE_FORWARD_ONLY,
+ ResultSet.CONCUR_READ_ONLY);
stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
stmt.setFetchSize(getFetchSize());
stmt.setLong(1, userID);
stmt.setLong(2, userID);
stmt.setLong(3, userID);
- log.debug("Executing SQL query: {}", getRecommendableItemsSQL);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL query: {}", getRecommendableItemsSQL);
rs = stmt.executeQuery();
FastIDSet itemIDs = new FastIDSet();
while (rs.next()) {
@@ -268,13 +270,13 @@
}
return itemIDs;
} catch (SQLException sqle) {
- log.warn("Exception while retrieving recommendable items", sqle);
+ AbstractJDBCDiffStorage.log.warn("Exception while retrieving recommendable items", sqle);
throw new TasteException(sqle);
} finally {
IOUtils.quietClose(rs, stmt, conn);
}
}
-
+
private void buildAverageDiffs() throws TasteException {
Connection conn = null;
try {
@@ -282,7 +284,7 @@
PreparedStatement stmt = null;
try {
stmt = conn.prepareStatement(deleteDiffsSQL);
- log.debug("Executing SQL update: {}", deleteDiffsSQL);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL update: {}", deleteDiffsSQL);
stmt.executeUpdate();
} finally {
IOUtils.quietClose(stmt);
@@ -290,19 +292,19 @@
try {
stmt = conn.prepareStatement(createDiffsSQL);
stmt.setInt(1, minDiffCount);
- log.debug("Executing SQL update: {}", createDiffsSQL);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL update: {}", createDiffsSQL);
stmt.executeUpdate();
} finally {
IOUtils.quietClose(stmt);
}
} catch (SQLException sqle) {
- log.warn("Exception while updating/deleting diffs", sqle);
+ AbstractJDBCDiffStorage.log.warn("Exception while updating/deleting diffs", sqle);
throw new TasteException(sqle);
} finally {
IOUtils.quietClose(conn);
}
}
-
+
private boolean isDiffsExist() throws TasteException {
Connection conn = null;
Statement stmt = null;
@@ -312,57 +314,57 @@
stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
stmt.setFetchSize(getFetchSize());
- log.debug("Executing SQL query: {}", diffsExistSQL);
+ AbstractJDBCDiffStorage.log.debug("Executing SQL query: {}", diffsExistSQL);
rs = stmt.executeQuery(diffsExistSQL);
rs.next();
return rs.getInt(1) > 0;
} catch (SQLException sqle) {
- log.warn("Exception while deleting diffs", sqle);
+ AbstractJDBCDiffStorage.log.warn("Exception while deleting diffs", sqle);
throw new TasteException(sqle);
} finally {
IOUtils.quietClose(rs, stmt, conn);
}
}
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
refreshHelper.refresh(alreadyRefreshed);
}
-
+
private static class FixedRunningAverage implements RunningAverage {
-
+
private final int count;
private final double average;
-
+
private FixedRunningAverage(int count, double average) {
this.count = count;
this.average = average;
}
-
+
@Override
public void addDatum(double datum) {
throw new UnsupportedOperationException();
}
-
+
@Override
public void removeDatum(double datum) {
throw new UnsupportedOperationException();
}
-
+
@Override
public void changeDatum(double delta) {
throw new UnsupportedOperationException();
}
-
+
@Override
public int getCount() {
return count;
}
-
+
@Override
public double getAverage() {
return average;
}
}
-
+
}
\ No newline at end of file
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java Sat Feb 13 20:54:05 2010
@@ -22,21 +22,53 @@
import org.apache.mahout.cf.taste.impl.model.jdbc.MySQLJDBCDataModel;
/**
- * <p>MySQL-specific implementation. Should be used in conjunction with a {@link MySQLJDBCDataModel}. This
- * implementation stores item-item diffs in a MySQL database and encapsulates some other slope-one-specific operations
- * that are needed on the preference data in the database. It assumes the database has a schema like:</p>
- *
- * <table> <tr><th>item_id_a</th><th>item_id_b</th><th>average_diff</th><th>count</th></tr>
- * <tr><td>123</td><td>234</td><td>0.5</td><td>5</td></tr> <tr><td>123</td><td>789</td><td>-1.33</td><td>3</td></tr>
- * <tr><td>234</td><td>789</td><td>2.1</td><td>1</td></tr> </table>
- *
- * <p><code>item_id_a</code> and <code>item_id_b</code> should have types compatible with the long
- * primitive type. <code>average_diff</code> must be compatible with <code>float</code> and
- * <code>count</code> must be compatible with <code>int</code>.</p>
- *
- * <p>The following command sets up a suitable table in MySQL:</p>
- *
- * <p><pre>
+ * <p>
+ * MySQL-specific implementation. Should be used in conjunction with a {@link MySQLJDBCDataModel}. This
+ * implementation stores item-item diffs in a MySQL database and encapsulates some other slope-one-specific
+ * operations that are needed on the preference data in the database. It assumes the database has a schema
+ * like:
+ * </p>
+ *
+ * <table>
+ * <tr>
+ * <th>item_id_a</th>
+ * <th>item_id_b</th>
+ * <th>average_diff</th>
+ * <th>count</th>
+ * </tr>
+ * <tr>
+ * <td>123</td>
+ * <td>234</td>
+ * <td>0.5</td>
+ * <td>5</td>
+ * </tr>
+ * <tr>
+ * <td>123</td>
+ * <td>789</td>
+ * <td>-1.33</td>
+ * <td>3</td>
+ * </tr>
+ * <tr>
+ * <td>234</td>
+ * <td>789</td>
+ * <td>2.1</td>
+ * <td>1</td>
+ * </tr>
+ * </table>
+ *
+ * <p>
+ * <code>item_id_a</code> and <code>item_id_b</code> should have types compatible with the long primitive
+ * type. <code>average_diff</code> must be compatible with <code>float</code> and <code>count</code> must be
+ * compatible with <code>int</code>.
+ * </p>
+ *
+ * <p>
+ * The following command sets up a suitable table in MySQL:
+ * </p>
+ *
+ * <p>
+ *
+ * <pre>
* CREATE TABLE taste_slopeone_diffs (
* item_id_a BIGINT NOT NULL,
* item_id_b BIGINT NOT NULL,
@@ -46,22 +78,21 @@
* INDEX (item_id_a),
* INDEX (item_id_b)
* )
- * </pre></p>
+ * </pre>
+ *
+ * </p>
*/
public final class MySQLJDBCDiffStorage extends AbstractJDBCDiffStorage {
-
+
private static final int DEFAULT_MIN_DIFF_COUNT = 2;
-
+
public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel) throws TasteException {
- this(dataModel,
- DEFAULT_DIFF_TABLE,
- DEFAULT_ITEM_A_COLUMN,
- DEFAULT_ITEM_B_COLUMN,
- DEFAULT_COUNT_COLUMN,
- DEFAULT_AVERAGE_DIFF_COLUMN,
- DEFAULT_MIN_DIFF_COUNT);
+ this(dataModel, AbstractJDBCDiffStorage.DEFAULT_DIFF_TABLE,
+ AbstractJDBCDiffStorage.DEFAULT_ITEM_A_COLUMN, AbstractJDBCDiffStorage.DEFAULT_ITEM_B_COLUMN,
+ AbstractJDBCDiffStorage.DEFAULT_COUNT_COLUMN, AbstractJDBCDiffStorage.DEFAULT_AVERAGE_DIFF_COLUMN,
+ MySQLJDBCDiffStorage.DEFAULT_MIN_DIFF_COUNT);
}
-
+
public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel,
String diffsTable,
String itemIDAColumn,
@@ -70,67 +101,65 @@
String avgColumn,
int minDiffCount) throws TasteException {
super(dataModel,
- // getDiffSQL
- "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable +
- " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=? UNION " +
- "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable +
- " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=?",
+ // getDiffSQL
+ "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable + " WHERE " + itemIDAColumn
+ + "=? AND " + itemIDBColumn + "=? UNION " + "SELECT " + countColumn + ", " + avgColumn + " FROM "
+ + diffsTable + " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=?",
// getDiffsSQL
- "SELECT " + countColumn + ", " + avgColumn + ", " + itemIDAColumn + " FROM " + diffsTable + ", " +
- dataModel.getPreferenceTable() + " WHERE " + itemIDBColumn + "=? AND " + itemIDAColumn + " = " +
- dataModel.getItemIDColumn() + " AND " + dataModel.getUserIDColumn() + "=? ORDER BY " + itemIDAColumn,
+ "SELECT " + countColumn + ", " + avgColumn + ", " + itemIDAColumn + " FROM " + diffsTable + ", "
+ + dataModel.getPreferenceTable() + " WHERE " + itemIDBColumn + "=? AND " + itemIDAColumn + " = "
+ + dataModel.getItemIDColumn() + " AND " + dataModel.getUserIDColumn() + "=? ORDER BY "
+ + itemIDAColumn,
// getAverageItemPrefSQL
- "SELECT COUNT(1), AVG(" + dataModel.getPreferenceColumn() + ") FROM " + dataModel.getPreferenceTable() +
- " WHERE " + dataModel.getItemIDColumn() + "=?",
+ "SELECT COUNT(1), AVG(" + dataModel.getPreferenceColumn() + ") FROM "
+ + dataModel.getPreferenceTable() + " WHERE " + dataModel.getItemIDColumn() + "=?",
// updateDiffSQLs
- new String[]{
- "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " - (? / " + countColumn +
- ") WHERE " + itemIDAColumn + "=?",
- "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " + (? / " + countColumn +
- ") WHERE " + itemIDBColumn + "=?"
- },
+ new String[] {
+ "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " - (? / "
+ + countColumn + ") WHERE " + itemIDAColumn + "=?",
+ "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " + (? / "
+ + countColumn + ") WHERE " + itemIDBColumn + "=?"},
// removeDiffSQL
- new String[]{
- "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn + "-1, " +
- avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) / CAST(" + countColumn +
- " AS DECIMAL)) + ? / CAST(" + countColumn + " AS DECIMAL) WHERE " + itemIDAColumn + "=?",
- "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn + "-1, " +
- avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) / CAST(" + countColumn +
- " AS DECIMAL)) - ? / CAST(" + countColumn + " AS DECIMAL) WHERE " + itemIDBColumn + "=?"
- },
+ new String[] {
+ "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn + "-1, "
+ + avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) / CAST("
+ + countColumn + " AS DECIMAL)) + ? / CAST(" + countColumn + " AS DECIMAL) WHERE "
+ + itemIDAColumn + "=?",
+ "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn + "-1, "
+ + avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) / CAST("
+ + countColumn + " AS DECIMAL)) - ? / CAST(" + countColumn + " AS DECIMAL) WHERE "
+ + itemIDBColumn + "=?"},
// getRecommendableItemsSQL
- "SELECT id FROM " +
- "(SELECT " + itemIDAColumn + " AS id FROM " + diffsTable + ", " + dataModel.getPreferenceTable() +
- " WHERE " + itemIDBColumn + " = " + dataModel.getItemIDColumn() +
- " AND " + dataModel.getUserIDColumn() + "=? UNION DISTINCT" +
- " SELECT " + itemIDBColumn + " AS id FROM " + diffsTable + ", " + dataModel.getPreferenceTable() +
- " WHERE " + itemIDAColumn + " = " + dataModel.getItemIDColumn() +
- " AND " + dataModel.getUserIDColumn() +"=?) " +
- "possible_item_ids WHERE id NOT IN (SELECT " + dataModel.getItemIDColumn() + " FROM " +
- dataModel.getPreferenceTable() + " WHERE " + dataModel.getUserIDColumn() + "=?)",
+ "SELECT id FROM " + "(SELECT " + itemIDAColumn + " AS id FROM " + diffsTable + ", "
+ + dataModel.getPreferenceTable() + " WHERE " + itemIDBColumn + " = "
+ + dataModel.getItemIDColumn() + " AND " + dataModel.getUserIDColumn() + "=? UNION DISTINCT"
+ + " SELECT " + itemIDBColumn + " AS id FROM " + diffsTable + ", "
+ + dataModel.getPreferenceTable() + " WHERE " + itemIDAColumn + " = "
+ + dataModel.getItemIDColumn() + " AND " + dataModel.getUserIDColumn() + "=?) "
+ + "possible_item_ids WHERE id NOT IN (SELECT " + dataModel.getItemIDColumn() + " FROM "
+ + dataModel.getPreferenceTable() + " WHERE " + dataModel.getUserIDColumn() + "=?)",
// deleteDiffsSQL
"TRUNCATE " + diffsTable,
// createDiffsSQL
- "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn +
- ", " + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn() + ", prefsB." +
- dataModel.getItemIDColumn() + ',' +" AVG(prefsB." + dataModel.getPreferenceColumn() +
- " - prefsA." + dataModel.getPreferenceColumn() + ")," + " COUNT(1) AS count FROM " +
- dataModel.getPreferenceTable() + " prefsA, " + dataModel.getPreferenceTable() + " prefsB WHERE prefsA." +
- dataModel.getUserIDColumn() + " = prefsB." + dataModel.getUserIDColumn() +
- " AND prefsA." + dataModel.getItemIDColumn() + " < prefsB." +
- dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA." + dataModel.getItemIDColumn() +
- ", prefsB." + dataModel.getItemIDColumn() + " HAVING count >=?",
+ "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn + ", "
+ + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn() + ", prefsB."
+ + dataModel.getItemIDColumn() + ',' + " AVG(prefsB." + dataModel.getPreferenceColumn()
+ + " - prefsA." + dataModel.getPreferenceColumn() + ")," + " COUNT(1) AS count FROM "
+ + dataModel.getPreferenceTable() + " prefsA, " + dataModel.getPreferenceTable()
+ + " prefsB WHERE prefsA." + dataModel.getUserIDColumn() + " = prefsB."
+ + dataModel.getUserIDColumn() + " AND prefsA." + dataModel.getItemIDColumn() + " < prefsB."
+ + dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA." + dataModel.getItemIDColumn()
+ + ", prefsB." + dataModel.getItemIDColumn() + " HAVING count >=?",
// diffsExistSQL
- "SELECT COUNT(1) FROM " + diffsTable,
- minDiffCount);
+ "SELECT COUNT(1) FROM " + diffsTable, minDiffCount);
}
-
+
/**
- * @see MySQLJDBCDataModel#getFetchSize()
+ * @see MySQLJDBCDataModel#getFetchSize()
*/
@Override
protected int getFetchSize() {
return Integer.MIN_VALUE;
}
-
+
}
\ No newline at end of file
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVD.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVD.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVD.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ExpectationMaximizationSVD.java Sat Feb 13 20:54:05 2010
@@ -17,59 +17,63 @@
package org.apache.mahout.cf.taste.impl.recommender.svd;
-import org.apache.mahout.common.RandomUtils;
-
import java.util.Random;
+import org.apache.mahout.common.RandomUtils;
+
/** Calculates the SVD using an Expectation Maximization algorithm. */
public final class ExpectationMaximizationSVD {
-
+
private static final Random random = RandomUtils.getRandom();
-
+
private static final double LEARNING_RATE = 0.005;
/** Parameter used to prevent overfitting. 0.02 is a good value. */
private static final double K = 0.02;
/** Random noise applied to starting values. */
private static final double r = 0.005;
-
+
private final int m;
private final int n;
private final int k;
-
+
/** User singular vector. */
private final double[][] leftVector;
-
+
/** Item singular vector. */
private final double[][] rightVector;
-
+
/**
- * @param m number of columns
- * @param n number of rows
- * @param k number of features
- * @param defaultValue default starting values for the SVD vectors
+ * @param m
+ * number of columns
+ * @param n
+ * number of rows
+ * @param k
+ * number of features
+ * @param defaultValue
+ * default starting values for the SVD vectors
*/
public ExpectationMaximizationSVD(int m, int n, int k, double defaultValue) {
- this(m, n, k, defaultValue, r);
+ this(m, n, k, defaultValue, ExpectationMaximizationSVD.r);
}
-
+
public ExpectationMaximizationSVD(int m, int n, int k, double defaultValue, double noise) {
this.m = m;
this.n = n;
this.k = k;
-
+
leftVector = new double[m][k];
rightVector = new double[n][k];
-
+
for (int i = 0; i < k; i++) {
for (int j = 0; j < m; j++) {
- leftVector[j][i] = defaultValue + (random.nextDouble() - 0.5) * noise;
+ leftVector[j][i] = defaultValue + (ExpectationMaximizationSVD.random.nextDouble() - 0.5) * noise;
}
for (int j = 0; j < n; j++) {
- rightVector[j][i] = defaultValue + (random.nextDouble() - 0.5) * noise;
+ rightVector[j][i] = defaultValue + (ExpectationMaximizationSVD.random.nextDouble() - 0.5) * noise;
}
}
}
-
+
public double getDotProduct(int i, int j) {
double result = 1.0;
double[] leftVectorI = leftVector[i];
@@ -79,25 +83,27 @@
}
return result;
}
-
+
public void train(int i, int j, int k, double value) {
double err = value - getDotProduct(i, j);
double[] leftVectorI = leftVector[i];
double[] rightVectorJ = rightVector[j];
- leftVectorI[k] += LEARNING_RATE * (err * rightVectorJ[k] - K * leftVectorI[k]);
- rightVectorJ[k] += LEARNING_RATE * (err * leftVectorI[k] - K * rightVectorJ[k]);
+ leftVectorI[k] += ExpectationMaximizationSVD.LEARNING_RATE
+ * (err * rightVectorJ[k] - ExpectationMaximizationSVD.K * leftVectorI[k]);
+ rightVectorJ[k] += ExpectationMaximizationSVD.LEARNING_RATE
+ * (err * leftVectorI[k] - ExpectationMaximizationSVD.K * rightVectorJ[k]);
}
-
+
int getM() {
return m;
}
-
+
int getN() {
return n;
}
-
+
int getK() {
return k;
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java Sat Feb 13 20:54:05 2010
@@ -17,6 +17,13 @@
package org.apache.mahout.cf.taste.impl.recommender.svd;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Callable;
+
import org.apache.mahout.cf.taste.common.NoSuchItemException;
import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.Refreshable;
@@ -25,103 +32,99 @@
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.recommender.IDRescorer;
-import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
import org.apache.mahout.cf.taste.impl.recommender.TopItems;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
import org.apache.mahout.cf.taste.recommender.RecommendedItem;
import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.common.RandomUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
-import java.util.concurrent.Callable;
-
/**
- * <p>A {@link Recommender} which uses Single Value Decomposition to find the main features of the data set.
+ * <p>
+ * A {@link Recommender} which uses Single Value Decomposition to find the main features of the data set.
* Thanks to Simon Funk for the hints in the implementation.
*/
public final class SVDRecommender extends AbstractRecommender {
-
+
private static final Logger log = LoggerFactory.getLogger(SVDRecommender.class);
private static final Random random = RandomUtils.getRandom();
-
+
private final RefreshHelper refreshHelper;
-
+
/** Number of features */
private final int numFeatures;
-
+
private final FastByIDMap<Integer> userMap;
private final FastByIDMap<Integer> itemMap;
private final ExpectationMaximizationSVD emSvd;
private final List<Preference> cachedPreferences;
-
+
/**
- * @param numFeatures the number of features
- * @param initialSteps number of initial training steps
+ * @param numFeatures
+ * the number of features
+ * @param initialSteps
+ * number of initial training steps
*/
public SVDRecommender(DataModel dataModel, int numFeatures, int initialSteps) throws TasteException {
super(dataModel);
-
+
this.numFeatures = numFeatures;
-
+
int numUsers = dataModel.getNumUsers();
userMap = new FastByIDMap<Integer>(numUsers);
-
+
int idx = 0;
LongPrimitiveIterator userIterator = dataModel.getUserIDs();
while (userIterator.hasNext()) {
userMap.put(userIterator.nextLong(), idx++);
}
-
+
int numItems = dataModel.getNumItems();
itemMap = new FastByIDMap<Integer>(numItems);
-
+
idx = 0;
LongPrimitiveIterator itemIterator = dataModel.getItemIDs();
while (itemIterator.hasNext()) {
itemMap.put(itemIterator.nextLong(), idx++);
}
-
+
double average = getAveragePreference();
- double defaultValue = Math.sqrt((average - 1.0) / (double) numFeatures);
-
+ double defaultValue = Math.sqrt((average - 1.0) / numFeatures);
+
emSvd = new ExpectationMaximizationSVD(numUsers, numItems, numFeatures, defaultValue);
cachedPreferences = new ArrayList<Preference>(numUsers);
recachePreferences();
-
+
refreshHelper = new RefreshHelper(new Callable<Object>() {
@Override
public Object call() throws TasteException {
recachePreferences();
- //TODO: train again
+ // TODO: train again
return null;
}
});
refreshHelper.addDependency(dataModel);
-
+
train(initialSteps);
}
-
+
private void recachePreferences() throws TasteException {
cachedPreferences.clear();
DataModel dataModel = getDataModel();
LongPrimitiveIterator it = dataModel.getUserIDs();
while (it.hasNext()) {
- for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+ for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
cachedPreferences.add(pref);
}
}
}
-
+
private double getAveragePreference() throws TasteException {
RunningAverage average = new FullRunningAverage();
DataModel dataModel = getDataModel();
@@ -133,15 +136,15 @@
}
return average.getAverage();
}
-
+
public void train(int steps) {
for (int i = 0; i < steps; i++) {
nextTrainStep();
}
}
-
+
private void nextTrainStep() {
- Collections.shuffle(cachedPreferences, random);
+ Collections.shuffle(cachedPreferences, SVDRecommender.random);
for (int i = 0; i < numFeatures; i++) {
for (Preference pref : cachedPreferences) {
int useridx = userMap.get(pref.getUserID());
@@ -150,12 +153,11 @@
}
}
}
-
+
private float predictRating(int user, int item) {
return (float) emSvd.getDotProduct(user, item);
}
-
-
+
@Override
public float estimatePreference(long userID, long itemID) throws TasteException {
Integer useridx = userMap.get(userID);
@@ -168,49 +170,48 @@
}
return predictRating(useridx, itemidx);
}
-
+
@Override
- public List<RecommendedItem> recommend(long userID,
- int howMany,
- IDRescorer rescorer) throws TasteException {
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
if (howMany < 1) {
throw new IllegalArgumentException("howMany must be at least 1");
}
-
- log.debug("Recommending items for user ID '{}'", userID);
-
+
+ SVDRecommender.log.debug("Recommending items for user ID '{}'", userID);
+
FastIDSet possibleItemIDs = getAllOtherItems(userID);
-
+
TopItems.Estimator<Long> estimator = new Estimator(userID);
-
- List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer, estimator);
-
- log.debug("Recommendations are: {}", topItems);
+
+ List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
+ estimator);
+
+ SVDRecommender.log.debug("Recommendations are: {}", topItems);
return topItems;
}
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
refreshHelper.refresh(alreadyRefreshed);
}
-
+
@Override
public String toString() {
return "SVDRecommender[numFeatures:" + numFeatures + ']';
}
-
+
private final class Estimator implements TopItems.Estimator<Long> {
-
+
private final long theUserID;
-
+
private Estimator(long theUserID) {
this.theUserID = theUserID;
}
-
+
@Override
public double estimate(Long itemID) throws TasteException {
return estimatePreference(theUserID, itemID);
}
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java Sat Feb 13 20:54:05 2010
@@ -17,6 +17,9 @@
package org.apache.mahout.cf.taste.impl.similarity;
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.common.Weighting;
@@ -29,27 +32,32 @@
import org.apache.mahout.cf.taste.transforms.PreferenceTransform;
import org.apache.mahout.cf.taste.transforms.SimilarityTransform;
-import java.util.Collection;
-import java.util.concurrent.Callable;
-
/** Abstract superclass encapsulating functionality that is common to most implementations in this package. */
abstract class AbstractSimilarity implements UserSimilarity, ItemSimilarity {
-
+
private final DataModel dataModel;
private PreferenceInferrer inferrer;
private PreferenceTransform prefTransform;
private SimilarityTransform similarityTransform;
- private boolean weighted;
+ private final boolean weighted;
private int cachedNumItems;
private int cachedNumUsers;
private final RefreshHelper refreshHelper;
-
- /** <p>Creates a normal (unweighted) {@link AbstractSimilarity}.</p> */
+
+ /**
+ * <p>
+ * Creates a normal (unweighted) {@link AbstractSimilarity}.
+ * </p>
+ */
AbstractSimilarity(DataModel dataModel) throws TasteException {
this(dataModel, Weighting.UNWEIGHTED);
}
-
- /** <p>Creates a possibly weighted {@link AbstractSimilarity}.</p> */
+
+ /**
+ * <p>
+ * Creates a possibly weighted {@link AbstractSimilarity}.
+ * </p>
+ */
AbstractSimilarity(final DataModel dataModel, Weighting weighting) throws TasteException {
if (dataModel == null) {
throw new IllegalArgumentException("dataModel is null");
@@ -68,15 +76,15 @@
});
this.refreshHelper.addDependency(this.dataModel);
}
-
+
final DataModel getDataModel() {
return dataModel;
}
-
+
final PreferenceInferrer getPreferenceInferrer() {
return inferrer;
}
-
+
@Override
public final void setPreferenceInferrer(PreferenceInferrer inferrer) {
if (inferrer == null) {
@@ -86,66 +94,75 @@
refreshHelper.removeDependency(this.inferrer);
this.inferrer = inferrer;
}
-
+
public final PreferenceTransform getPrefTransform() {
return prefTransform;
}
-
+
public final void setPrefTransform(PreferenceTransform prefTransform) {
refreshHelper.addDependency(prefTransform);
refreshHelper.removeDependency(this.prefTransform);
this.prefTransform = prefTransform;
}
-
+
public final SimilarityTransform getSimilarityTransform() {
return similarityTransform;
}
-
+
public final void setSimilarityTransform(SimilarityTransform similarityTransform) {
refreshHelper.addDependency(similarityTransform);
refreshHelper.removeDependency(this.similarityTransform);
this.similarityTransform = similarityTransform;
}
-
+
final boolean isWeighted() {
return weighted;
}
-
+
/**
- * <p>Several subclasses in this package implement this method to actually compute the similarity from figures
- * computed over users or items. Note that the computations in this class "center" the data, such that X and Y's mean
- * are 0.</p>
- *
- * <p>Note that the sum of all X and Y values must then be 0. This value isn't passed down into the standard
- * similarity computations as a result.</p>
- *
- * @param n total number of users or items
- * @param sumXY sum of product of user/item preference values, over all items/users prefererred by both
- * users/items
- * @param sumX2 sum of the square of user/item preference values, over the first item/user
- * @param sumY2 sum of the square of the user/item preference values, over the second item/user
- * @param sumXYdiff2 sum of squares of differences in X and Y values
- * @return similarity value between -1.0 and 1.0, inclusive, or {@link Double#NaN} if no similarity can be computed
- * (e.g. when no items have been rated by both uesrs
+ * <p>
+ * Several subclasses in this package implement this method to actually compute the similarity from figures
+ * computed over users or items. Note that the computations in this class "center" the data, such that X and
+ * Y's mean are 0.
+ * </p>
+ *
+ * <p>
+ * Note that the sum of all X and Y values must then be 0. This value isn't passed down into the standard
+ * similarity computations as a result.
+ * </p>
+ *
+ * @param n
+ * total number of users or items
+ * @param sumXY
+ *            sum of product of user/item preference values, over all items/users preferred by both
+ * users/items
+ * @param sumX2
+ * sum of the square of user/item preference values, over the first item/user
+ * @param sumY2
+ * sum of the square of the user/item preference values, over the second item/user
+ * @param sumXYdiff2
+ * sum of squares of differences in X and Y values
+ * @return similarity value between -1.0 and 1.0, inclusive, or {@link Double#NaN} if no similarity can be
+ *         computed (e.g. when no items have been rated by both users)
*/
abstract double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2);
-
+
@Override
public double userSimilarity(long userID1, long userID2) throws TasteException {
PreferenceArray xPrefs = dataModel.getPreferencesFromUser(userID1);
PreferenceArray yPrefs = dataModel.getPreferencesFromUser(userID2);
int xLength = xPrefs.length();
int yLength = yPrefs.length();
-
- if (xLength == 0 || yLength == 0) {
+
+ if ((xLength == 0) || (yLength == 0)) {
return Double.NaN;
}
-
+
long xIndex = xPrefs.getItemID(0);
long yIndex = yPrefs.getItemID(0);
int xPrefIndex = 0;
int yPrefIndex = 0;
-
+
double sumX = 0.0;
double sumX2 = 0.0;
double sumY = 0.0;
@@ -153,13 +170,13 @@
double sumXY = 0.0;
double sumXYdiff2 = 0.0;
int count = 0;
-
+
boolean hasInferrer = inferrer != null;
boolean hasPrefTransform = prefTransform != null;
-
+
while (true) {
int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
- if (hasInferrer || compare == 0) {
+ if (hasInferrer || (compare == 0)) {
double x;
double y;
if (xIndex == yIndex) {
@@ -176,13 +193,15 @@
// as if the other user expressed that preference
if (compare < 0) {
// X has a value; infer Y's
- x = hasPrefTransform ? prefTransform.getTransformedValue(xPrefs.get(xPrefIndex)) : xPrefs.getValue(xPrefIndex);
+ x = hasPrefTransform ? prefTransform.getTransformedValue(xPrefs.get(xPrefIndex)) : xPrefs
+ .getValue(xPrefIndex);
y = inferrer.inferPreference(userID2, xIndex);
} else {
// compare > 0
// Y has a value; infer X's
x = inferrer.inferPreference(userID1, yIndex);
- y = hasPrefTransform ? prefTransform.getTransformedValue(yPrefs.get(yPrefIndex)) : yPrefs.getValue(yPrefIndex);
+ y = hasPrefTransform ? prefTransform.getTransformedValue(yPrefs.get(yPrefIndex)) : yPrefs
+ .getValue(yPrefIndex);
}
}
sumXY += x * y;
@@ -207,9 +226,9 @@
yIndex = yPrefs.getItemID(yPrefIndex);
}
}
-
+
// "Center" the data. If my math is correct, this'll do it.
- double n = (double) count;
+ double n = count;
double meanX = sumX / n;
double meanY = sumY / n;
// double centeredSumXY = sumXY - meanY * sumX - meanX * sumY + n * meanX * meanY;
@@ -218,35 +237,35 @@
double centeredSumX2 = sumX2 - meanX * sumX;
// double centeredSumY2 = sumY2 - 2.0 * meanY * sumY + n * meanY * meanY;
double centeredSumY2 = sumY2 - meanY * sumY;
-
+
double result = computeResult(count, centeredSumXY, centeredSumX2, centeredSumY2, sumXYdiff2);
-
+
if (similarityTransform != null) {
result = similarityTransform.transformSimilarity(userID1, userID2, result);
}
-
+
if (!Double.isNaN(result)) {
result = normalizeWeightResult(result, count, cachedNumItems);
}
return result;
}
-
+
@Override
public final double itemSimilarity(long itemID1, long itemID2) throws TasteException {
PreferenceArray xPrefs = dataModel.getPreferencesForItem(itemID1);
PreferenceArray yPrefs = dataModel.getPreferencesForItem(itemID2);
int xLength = xPrefs.length();
int yLength = yPrefs.length();
-
- if (xLength == 0 || yLength == 0) {
+
+ if ((xLength == 0) || (yLength == 0)) {
return Double.NaN;
}
-
+
long xIndex = xPrefs.getUserID(0);
long yIndex = yPrefs.getUserID(0);
int xPrefIndex = 0;
int yPrefIndex = 0;
-
+
double sumX = 0.0;
double sumX2 = 0.0;
double sumY = 0.0;
@@ -254,9 +273,9 @@
double sumXY = 0.0;
double sumXYdiff2 = 0.0;
int count = 0;
-
+
// No, pref inferrers and transforms don't appy here. I think.
-
+
while (true) {
int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
if (compare == 0) {
@@ -285,9 +304,9 @@
yIndex = yPrefs.getUserID(yPrefIndex);
}
}
-
+
// See comments above on these computations
- double n = (double) count;
+ double n = count;
double meanX = sumX / n;
double meanY = sumY / n;
// double centeredSumXY = sumXY - meanY * sumX - meanX * sumY + n * meanX * meanY;
@@ -296,19 +315,19 @@
double centeredSumX2 = sumX2 - meanX * sumX;
// double centeredSumY2 = sumY2 - 2.0 * meanY * sumY + n * meanY * meanY;
double centeredSumY2 = sumY2 - meanY * sumY;
-
+
double result = computeResult(count, centeredSumXY, centeredSumX2, centeredSumY2, sumXYdiff2);
-
+
if (similarityTransform != null) {
result = similarityTransform.transformSimilarity(itemID1, itemID2, result);
}
-
+
if (!Double.isNaN(result)) {
result = normalizeWeightResult(result, count, cachedNumUsers);
}
return result;
}
-
+
final double normalizeWeightResult(double result, int count, int num) {
if (weighted) {
double scaleFactor = 1.0 - (double) count / (double) (num + 1);
@@ -326,15 +345,15 @@
}
return result;
}
-
+
@Override
public final void refresh(Collection<Refreshable> alreadyRefreshed) {
refreshHelper.refresh(alreadyRefreshed);
}
-
+
@Override
public final String toString() {
return this.getClass().getSimpleName() + "[dataModel:" + dataModel + ",inferrer:" + inferrer + ']';
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java Sat Feb 13 20:54:05 2010
@@ -17,6 +17,8 @@
package org.apache.mahout.cf.taste.impl.similarity;
+import java.util.Collection;
+
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.Cache;
@@ -27,46 +29,46 @@
import org.apache.mahout.cf.taste.model.PreferenceArray;
import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
-import java.util.Collection;
-
/**
- * <p>Implementations of this interface compute an inferred preference for a user and an item that the
- * user has not expressed any preference for. This might be an average of other preferences scores from that user, for
- * example. This technique is sometimes called "default voting".</p>
+ * <p>
+ * Implementations of this interface compute an inferred preference for a user and an item that the user has
+ * not expressed any preference for. This might be an average of other preferences scores from that user, for
+ * example. This technique is sometimes called "default voting".
+ * </p>
*/
public final class AveragingPreferenceInferrer implements PreferenceInferrer {
-
+
private static final Float ZERO = 0.0f;
-
+
private final DataModel dataModel;
- private final Cache<Long, Float> averagePreferenceValue;
-
+ private final Cache<Long,Float> averagePreferenceValue;
+
public AveragingPreferenceInferrer(DataModel dataModel) throws TasteException {
this.dataModel = dataModel;
- Retriever<Long, Float> retriever = new PrefRetriever();
- averagePreferenceValue = new Cache<Long, Float>(retriever, dataModel.getNumUsers());
+ Retriever<Long,Float> retriever = new PrefRetriever();
+ averagePreferenceValue = new Cache<Long,Float>(retriever, dataModel.getNumUsers());
refresh(null);
}
-
+
@Override
public float inferPreference(long userID, long itemID) throws TasteException {
return averagePreferenceValue.get(userID);
}
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
averagePreferenceValue.clear();
}
-
- private final class PrefRetriever implements Retriever<Long, Float> {
-
+
+ private final class PrefRetriever implements Retriever<Long,Float> {
+
@Override
public Float get(Long key) throws TasteException {
RunningAverage average = new FullRunningAverage();
PreferenceArray prefs = dataModel.getPreferencesFromUser(key);
int size = prefs.length();
if (size == 0) {
- return ZERO;
+ return AveragingPreferenceInferrer.ZERO;
}
for (int i = 0; i < size; i++) {
average.addDatum(prefs.getValue(i));
@@ -74,10 +76,10 @@
return (float) average.getAverage();
}
}
-
+
@Override
public String toString() {
return "AveragingPreferenceInferrer";
}
-
+
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java?rev=909912&r1=909911&r2=909912&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java Sat Feb 13 20:54:05 2010
@@ -17,56 +17,56 @@
package org.apache.mahout.cf.taste.impl.similarity;
+import java.util.Collection;
+
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.Cache;
-import org.apache.mahout.common.LongPair;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.impl.common.Retriever;
import org.apache.mahout.cf.taste.model.DataModel;
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-
-import java.util.Collection;
+import org.apache.mahout.common.LongPair;
/** Caches the results from an underlying {@link ItemSimilarity} implementation. */
public final class CachingItemSimilarity implements ItemSimilarity {
-
+
private final ItemSimilarity similarity;
- private final Cache<LongPair, Double> similarityCache;
-
+ private final Cache<LongPair,Double> similarityCache;
+
public CachingItemSimilarity(ItemSimilarity similarity, DataModel dataModel) throws TasteException {
if (similarity == null) {
throw new IllegalArgumentException("similarity is null");
}
this.similarity = similarity;
int maxCacheSize = dataModel.getNumItems(); // just a dumb heuristic for sizing
- this.similarityCache = new Cache<LongPair, Double>(new SimilarityRetriever(similarity), maxCacheSize);
+ this.similarityCache = new Cache<LongPair,Double>(new SimilarityRetriever(similarity), maxCacheSize);
}
-
+
@Override
public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
LongPair key = itemID1 < itemID2 ? new LongPair(itemID1, itemID2) : new LongPair(itemID2, itemID1);
return similarityCache.get(key);
}
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
similarityCache.clear();
alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
RefreshHelper.maybeRefresh(alreadyRefreshed, similarity);
}
-
- private static final class SimilarityRetriever implements Retriever<LongPair, Double> {
+
+ private static final class SimilarityRetriever implements Retriever<LongPair,Double> {
private final ItemSimilarity similarity;
-
+
private SimilarityRetriever(ItemSimilarity similarity) {
this.similarity = similarity;
}
-
+
@Override
public Double get(LongPair key) throws TasteException {
return similarity.itemSimilarity(key.getFirst(), key.getSecond());
}
}
-
+
}
\ No newline at end of file