You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2008/05/09 23:35:17 UTC
svn commit: r654943 [6/9] - in /lucene/mahout/trunk/core: ./ lib/
src/main/examples/org/ src/main/examples/org/apache/
src/main/examples/org/apache/mahout/ src/main/examples/org/apache/mahout/cf/
src/main/examples/org/apache/mahout/cf/taste/ src/main/e...
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TreeClusteringRecommender2.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TreeClusteringRecommender2.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TreeClusteringRecommender2.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TreeClusteringRecommender2.java Fri May 9 14:35:12 2008
@@ -0,0 +1,507 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Item;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.User;
+import org.apache.mahout.cf.taste.recommender.ClusteringRecommender;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Rescorer;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.ListIterator;
+import java.util.Map;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * <p>A {@link org.apache.mahout.cf.taste.recommender.Recommender} that clusters
+ * {@link org.apache.mahout.cf.taste.model.User}s, then determines
+ * the clusters' top recommendations. This implementation builds clusters by repeatedly merging clusters
+ * until only a certain number remain, meaning that each cluster is sort of a tree of other clusters.</p>
+ *
+ * <p>This {@link org.apache.mahout.cf.taste.recommender.Recommender} therefore has a few properties to note:</p>
+ * <ul>
+ * <li>For all {@link org.apache.mahout.cf.taste.model.User}s in a cluster, recommendations will be the same</li>
+ * <li>{@link #estimatePreference(Object, Object)} may well return {@link Double#NaN}; it does so when asked
+ * to estimate preference for an {@link org.apache.mahout.cf.taste.model.Item} for which no preference is expressed in the
+ * {@link org.apache.mahout.cf.taste.model.User}s in the cluster.</li>
+ * </ul>
+ *
+ * <p>This is an <em>experimental</em> implementation which tries to gain a lot of speed at the cost of
+ * accuracy in building clusters, compared to {@link org.apache.mahout.cf.taste.impl.recommender.TreeClusteringRecommender}.
+ * It will sometimes cluster two other clusters together that may not be the exact closest two clusters
+ * in existence. This may not affect the recommendation quality much, but it potentially speeds up the
+ * clustering process dramatically.</p>
+ */
+public final class TreeClusteringRecommender2 extends AbstractRecommender implements ClusteringRecommender {
+
+ private static final Logger log = Logger.getLogger(TreeClusteringRecommender2.class.getName());
+
+ private final ClusterSimilarity clusterSimilarity;
+ private final int numClusters;
+ private final double clusteringThreshold;
+ private final boolean clusteringByThreshold;
+ private Map<Object, List<RecommendedItem>> topRecsByUserID;
+ private Collection<Collection<User>> allClusters;
+ private Map<Object, Collection<User>> clustersByUserID;
+ private boolean clustersBuilt;
+ private final ReentrantLock refreshLock;
+ private final ReentrantLock buildClustersLock;
+
+ /**
+ * @param dataModel {@link org.apache.mahout.cf.taste.model.DataModel} which provdes {@link org.apache.mahout.cf.taste.model.User}s
+ * @param clusterSimilarity {@link org.apache.mahout.cf.taste.impl.recommender.ClusterSimilarity} used to compute
+ * cluster similarity
+ * @param numClusters desired number of clusters to create
+ * @throws IllegalArgumentException if arguments are <code>null</code>, or <code>numClusters</code> is
+ * less than 2
+ */
+ public TreeClusteringRecommender2(DataModel dataModel,
+ ClusterSimilarity clusterSimilarity,
+ int numClusters) {
+ super(dataModel);
+ if (clusterSimilarity == null) {
+ throw new IllegalArgumentException("clusterSimilarity is null");
+ }
+ if (numClusters < 2) {
+ throw new IllegalArgumentException("numClusters must be at least 2");
+ }
+ this.clusterSimilarity = clusterSimilarity;
+ this.numClusters = numClusters;
+ this.clusteringThreshold = Double.NaN;
+ this.clusteringByThreshold = false;
+ this.refreshLock = new ReentrantLock();
+ this.buildClustersLock = new ReentrantLock();
+ }
+
+ /**
+ * @param dataModel {@link org.apache.mahout.cf.taste.model.DataModel} which provdes {@link org.apache.mahout.cf.taste.model.User}s
+ * @param clusterSimilarity {@link org.apache.mahout.cf.taste.impl.recommender.ClusterSimilarity} used to compute
+ * cluster similarity
+ * @param clusteringThreshold clustering similarity threshold; clusters will be aggregated into larger
+ * clusters until the next two nearest clusters' similarity drops below this threshold
+ * @throws IllegalArgumentException if arguments are <code>null</code>, or <code>clusteringThreshold</code> is
+ * {@link Double#NaN}
+ */
+ public TreeClusteringRecommender2(DataModel dataModel,
+ ClusterSimilarity clusterSimilarity,
+ double clusteringThreshold) {
+ super(dataModel);
+ if (clusterSimilarity == null) {
+ throw new IllegalArgumentException("clusterSimilarity is null");
+ }
+ if (Double.isNaN(clusteringThreshold)) {
+ throw new IllegalArgumentException("clusteringThreshold must not be NaN");
+ }
+ this.clusterSimilarity = clusterSimilarity;
+ this.numClusters = Integer.MIN_VALUE;
+ this.clusteringThreshold = clusteringThreshold;
+ this.clusteringByThreshold = true;
+ this.refreshLock = new ReentrantLock();
+ this.buildClustersLock = new ReentrantLock();
+ }
+
+ public List<RecommendedItem> recommend(Object userID, int howMany, Rescorer<Item> rescorer)
+ throws TasteException {
+ if (userID == null || rescorer == null) {
+ throw new IllegalArgumentException("userID or rescorer is null");
+ }
+ if (howMany < 1) {
+ throw new IllegalArgumentException("howMany must be at least 1");
+ }
+ checkClustersBuilt();
+
+ if (log.isLoggable(Level.FINE)) {
+ log.fine("Recommending items for user ID '" + userID + '\'');
+ }
+
+ List<RecommendedItem> recommended = topRecsByUserID.get(userID);
+ if (recommended == null) {
+ return Collections.emptyList();
+ }
+
+ User theUser = getDataModel().getUser(userID);
+ List<RecommendedItem> rescored = new ArrayList<RecommendedItem>(recommended.size());
+ // Only add items the user doesn't already have a preference for.
+ // And that the rescorer doesn't "reject".
+ for (RecommendedItem recommendedItem : recommended) {
+ Item item = recommendedItem.getItem();
+ if (rescorer.isFiltered(item)) {
+ continue;
+ }
+ if (theUser.getPreferenceFor(item.getID()) == null &&
+ !Double.isNaN(rescorer.rescore(item, recommendedItem.getValue()))) {
+ rescored.add(recommendedItem);
+ }
+ }
+ Collections.sort(rescored, new ByRescoreComparator(rescorer));
+
+ return rescored;
+ }
+
+ public double estimatePreference(Object userID, Object itemID) throws TasteException {
+ if (userID == null || itemID == null) {
+ throw new IllegalArgumentException("userID or itemID is null");
+ }
+ DataModel model = getDataModel();
+ User theUser = model.getUser(userID);
+ Preference actualPref = theUser.getPreferenceFor(itemID);
+ if (actualPref != null) {
+ return actualPref.getValue();
+ }
+ checkClustersBuilt();
+ List<RecommendedItem> topRecsForUser = topRecsByUserID.get(userID);
+ if (topRecsForUser != null) {
+ for (RecommendedItem item : topRecsForUser) {
+ if (itemID.equals(item.getItem().getID())) {
+ return item.getValue();
+ }
+ }
+ }
+ // Hmm, we have no idea. The item is not in the user's cluster
+ return Double.NaN;
+ }
+
+ public Collection<User> getCluster(Object userID) throws TasteException {
+ if (userID == null) {
+ throw new IllegalArgumentException("userID is null");
+ }
+ checkClustersBuilt();
+ Collection<User> cluster = clustersByUserID.get(userID);
+ if (cluster == null) {
+ return Collections.emptyList();
+ } else {
+ return cluster;
+ }
+ }
+
  /**
   * Returns all computed clusters of {@link User}s, building them first if necessary.
   * Note this exposes the recommender's internal collection directly; callers should
   * treat it as read-only.
   */
  public Collection<Collection<User>> getClusters() throws TasteException {
    checkClustersBuilt();
    return allClusters;
  }
+
+ private void checkClustersBuilt() throws TasteException {
+ if (!clustersBuilt) {
+ buildClusters();
+ }
+ }
+
+ private static final class ClusterClusterPair implements Comparable<ClusterClusterPair> {
+
+ private final Collection<User> cluster1;
+ private final Collection<User> cluster2;
+ private final double similarity;
+
+ private ClusterClusterPair(Collection<User> cluster1,
+ Collection<User> cluster2,
+ double similarity) {
+ this.cluster1 = cluster1;
+ this.cluster2 = cluster2;
+ this.similarity = similarity;
+ }
+
+ private Collection<User> getCluster1() {
+ return cluster1;
+ }
+
+ private Collection<User> getCluster2() {
+ return cluster2;
+ }
+
+ private double getSimilarity() {
+ return similarity;
+ }
+
+ @Override
+ public int hashCode() {
+ return cluster1.hashCode() ^ cluster2.hashCode() ^ Double.valueOf(similarity).hashCode();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof ClusterClusterPair)) {
+ return false;
+ }
+ ClusterClusterPair other = (ClusterClusterPair) o;
+ return cluster1.equals(other.cluster1) &&
+ cluster2.equals(other.cluster2) &&
+ similarity == other.similarity;
+ }
+
+ public int compareTo(ClusterClusterPair other) {
+ double otherSimilarity = other.similarity;
+ if (similarity > otherSimilarity) {
+ return -1;
+ } else if (similarity < otherSimilarity) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ }
+
+ private void buildClusters() throws TasteException {
+ try {
+ buildClustersLock.lock();
+
+ DataModel model = getDataModel();
+ int numUsers = model.getNumUsers();
+
+ if (numUsers == 0) {
+
+ topRecsByUserID = Collections.emptyMap();
+ clustersByUserID = Collections.emptyMap();
+
+ } else {
+
+ List<Collection<User>> clusters = new LinkedList<Collection<User>>();
+ // Begin with a cluster for each user:
+ for (User user : model.getUsers()) {
+ Collection<User> newCluster = new HashSet<User>();
+ newCluster.add(user);
+ clusters.add(newCluster);
+ }
+
+ boolean done = false;
+ while (!done) {
+
+ // We find a certain number of closest clusters...
+ boolean full = false;
+ LinkedList<ClusterClusterPair> queue = new LinkedList<ClusterClusterPair>();
+ int i = 0;
+ for (Collection<User> cluster1 : clusters) {
+ i++;
+ ListIterator<Collection<User>> it2 = clusters.listIterator(i);
+ while (it2.hasNext()) {
+ Collection<User> cluster2 = it2.next();
+ double similarity = clusterSimilarity.getSimilarity(cluster1, cluster2);
+ if (!Double.isNaN(similarity) &&
+ (!full || similarity > queue.getLast().getSimilarity())) {
+ ListIterator<ClusterClusterPair> queueIterator =
+ queue.listIterator(queue.size());
+ while (queueIterator.hasPrevious()) {
+ if (similarity <= queueIterator.previous().getSimilarity()) {
+ queueIterator.next();
+ break;
+ }
+ }
+ queueIterator.add(new ClusterClusterPair(cluster1, cluster2, similarity));
+ if (full) {
+ queue.removeLast();
+ } else if (queue.size() > numUsers) { // use numUsers as queue size limit
+ full = true;
+ queue.removeLast();
+ }
+ }
+ }
+ }
+
+ // The first one is definitely the closest pair in existence so we can cluster
+ // the two together, put it back into the set of clusters, and start again. Instead
+ // we assume everything else in our list of closest cluster pairs is still pretty good,
+ // and we cluster them too.
+
+ while (!queue.isEmpty()) {
+
+ if (!clusteringByThreshold && clusters.size() <= numClusters) {
+ done = true;
+ break;
+ }
+
+ ClusterClusterPair top = queue.removeFirst();
+
+ if (clusteringByThreshold && top.getSimilarity() < clusteringThreshold) {
+ done = true;
+ break;
+ }
+
+ Collection<User> cluster1 = top.getCluster1();
+ Collection<User> cluster2 = top.getCluster2();
+
+ // Pull out current two clusters from clusters
+ Iterator<Collection<User>> clusterIterator = clusters.iterator();
+ boolean removed1 = false;
+ boolean removed2 = false;
+ while (clusterIterator.hasNext() && !(removed1 && removed2)) {
+ Collection<User> current = clusterIterator.next();
+ // Yes, use == here
+ if (!removed1 && cluster1 == current) {
+ clusterIterator.remove();
+ removed1 = true;
+ } else if (!removed2 && cluster2 == current) {
+ clusterIterator.remove();
+ removed2 = true;
+ }
+ }
+
+ // The only catch is if a cluster showed it twice in the list of best cluster pairs;
+ // have to remove the others. Pull out anything referencing these clusters from queue
+ for (Iterator<ClusterClusterPair> queueIterator = queue.iterator();
+ queueIterator.hasNext();) {
+ ClusterClusterPair pair = queueIterator.next();
+ Collection<User> pair1 = pair.getCluster1();
+ Collection<User> pair2 = pair.getCluster2();
+ if (pair1 == cluster1 || pair1 == cluster2 || pair2 == cluster1 || pair2 == cluster2) {
+ queueIterator.remove();
+ }
+ }
+
+ // Make new merged cluster
+ Collection<User> merged = new HashSet<User>(cluster1.size() + cluster2.size());
+ merged.addAll(cluster1);
+ merged.addAll(cluster2);
+
+ // Compare against other clusters; update queue if needed
+ // That new pair we're just adding might be pretty close to something else, so
+ // catch that case here and put it back into our queue
+ for (Collection<User> cluster : clusters) {
+ double similarity = clusterSimilarity.getSimilarity(merged, cluster);
+ if (similarity > queue.getLast().getSimilarity()) {
+ ListIterator<ClusterClusterPair> queueIterator = queue.listIterator();
+ while (queueIterator.hasNext()) {
+ if (similarity > queueIterator.next().getSimilarity()) {
+ queueIterator.previous();
+ break;
+ }
+ }
+ queueIterator.add(new ClusterClusterPair(merged, cluster, similarity));
+ }
+ }
+
+ // Finally add new cluster to list
+ clusters.add(merged);
+
+ }
+
+ }
+
+ topRecsByUserID = computeTopRecsPerUserID(clusters);
+ clustersByUserID = computeClustersPerUserID(clusters);
+ allClusters = clusters;
+
+ }
+
+ clustersBuilt = true;
+ } finally {
+ buildClustersLock.unlock();
+ }
+ }
+
+ private static Map<Object, List<RecommendedItem>> computeTopRecsPerUserID(Iterable<Collection<User>> clusters)
+ throws TasteException {
+ Map<Object, List<RecommendedItem>> recsPerUser = new HashMap<Object, List<RecommendedItem>>();
+ for (Collection<User> cluster : clusters) {
+ List<RecommendedItem> recs = computeTopRecsForCluster(cluster);
+ for (User user : cluster) {
+ recsPerUser.put(user.getID(), recs);
+ }
+ }
+ return Collections.unmodifiableMap(recsPerUser);
+ }
+
+ private static List<RecommendedItem> computeTopRecsForCluster(Collection<User> cluster)
+ throws TasteException {
+
+ Collection<Item> allItems = new HashSet<Item>();
+ for (User user : cluster) {
+ Preference[] prefs = user.getPreferencesAsArray();
+ for (int i = 0; i < prefs.length; i++) {
+ allItems.add(prefs[i].getItem());
+ }
+ }
+
+ TopItems.Estimator<Item> estimator = new Estimator(cluster);
+
+ List<RecommendedItem> topItems =
+ TopItems.getTopItems(Integer.MAX_VALUE, allItems, NullRescorer.getItemInstance(), estimator);
+
+ if (log.isLoggable(Level.FINE)) {
+ log.fine("Recommendations are: " + topItems);
+ }
+ return Collections.unmodifiableList(topItems);
+ }
+
+ private static Map<Object, Collection<User>> computeClustersPerUserID(Collection<Collection<User>> clusters) {
+ Map<Object, Collection<User>> clustersPerUser = new HashMap<Object, Collection<User>>(clusters.size());
+ for (Collection<User> cluster : clusters) {
+ for (User user : cluster) {
+ clustersPerUser.put(user.getID(), cluster);
+ }
+ }
+ return clustersPerUser;
+ }
+
+ @Override
+ public void refresh() {
+ if (refreshLock.isLocked()) {
+ return;
+ }
+ try {
+ refreshLock.lock();
+ super.refresh();
+ clusterSimilarity.refresh();
+ try {
+ buildClusters();
+ } catch (TasteException te) {
+ log.log(Level.WARNING, "Unexpected excpetion while refreshing", te);
+ }
+ } finally {
+ refreshLock.unlock();
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "TreeClusteringRecommender2[clusterSimilarity:" + clusterSimilarity + ']';
+ }
+
+ private static class Estimator implements TopItems.Estimator<Item> {
+
+ private final Collection<User> cluster;
+
+ private Estimator(Collection<User> cluster) {
+ this.cluster = cluster;
+ }
+
+ public double estimate(Item item) {
+ RunningAverage average = new FullRunningAverage();
+ for (User user : cluster) {
+ Preference pref = user.getPreferenceFor(item.getID());
+ if (pref != null) {
+ average.addDatum(pref.getValue());
+ }
+ }
+ return average.getAverage();
+ }
+ }
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverage.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverage.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverage.java Fri May 9 14:35:12 2008
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.slopeone;
+
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+
/**
 * <p>A read-only view of another {@link RunningAverage} whose average is negated.
 * Used by {@link MemoryDiffStorage} to report an item-item diff looked up in the
 * opposite key order without storing it twice.</p>
 *
 * <p>All mutating methods throw {@link UnsupportedOperationException}: mutate the
 * underlying delegate instead.</p>
 */
final class InvertedRunningAverage implements RunningAverage {

  private final RunningAverage delegate;

  InvertedRunningAverage(RunningAverage delegate) {
    this.delegate = delegate;
  }

  /** Unsupported: this view is read-only. */
  public void addDatum(double datum) {
    throw new UnsupportedOperationException();
  }

  /** Unsupported: this view is read-only. */
  public void removeDatum(double datum) {
    throw new UnsupportedOperationException();
  }

  /** Unsupported: this view is read-only. */
  public void changeDatum(double delta) {
    throw new UnsupportedOperationException();
  }

  /** @return datum count of the underlying average */
  public int getCount() {
    return delegate.getCount();
  }

  /** @return negation of the underlying average */
  public double getAverage() {
    return -delegate.getAverage();
  }

}
\ No newline at end of file
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverageAndStdDev.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverageAndStdDev.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverageAndStdDev.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/InvertedRunningAverageAndStdDev.java Fri May 9 14:35:12 2008
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.slopeone;
+
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+
/**
 * <p>A read-only view of another {@link RunningAverageAndStdDev} whose average is
 * negated. The standard deviation passes through unchanged, since negating every
 * datum does not alter spread.</p>
 *
 * <p>All mutating methods throw {@link UnsupportedOperationException}: mutate the
 * underlying delegate instead.</p>
 */
final class InvertedRunningAverageAndStdDev implements RunningAverageAndStdDev {

  private final RunningAverageAndStdDev delegate;

  InvertedRunningAverageAndStdDev(RunningAverageAndStdDev delegate) {
    this.delegate = delegate;
  }

  /** Unsupported: this view is read-only. */
  public void addDatum(double datum) {
    throw new UnsupportedOperationException();
  }

  /** Unsupported: this view is read-only. */
  public void removeDatum(double datum) {
    throw new UnsupportedOperationException();
  }

  /** Unsupported: this view is read-only. */
  public void changeDatum(double delta) {
    throw new UnsupportedOperationException();
  }

  /** @return datum count of the underlying average */
  public int getCount() {
    return delegate.getCount();
  }

  /** @return negation of the underlying average */
  public double getAverage() {
    return -delegate.getAverage();
  }

  /** @return standard deviation of the underlying average, unchanged by negation */
  public double getStandardDeviation() {
    return delegate.getStandardDeviation();
  }

}
\ No newline at end of file
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorage.java Fri May 9 14:35:12 2008
@@ -0,0 +1,297 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.slopeone;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.CompactRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.CompactRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.FastMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Item;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.User;
+import org.apache.mahout.cf.taste.recommender.slopeone.DiffStorage;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * <p>An implementation of {@link DiffStorage} that merely stores item-item diffs in memory.
+ * It is fast, but can consume a great deal of memory.</p>
+ */
+public final class MemoryDiffStorage implements DiffStorage {
+
+ private static final Logger log = Logger.getLogger(MemoryDiffStorage.class.getName());
+
+ private final DataModel dataModel;
+ private final boolean stdDevWeighted;
+ private final boolean compactAverages;
+ private final long maxEntries;
+ private final Map<Object, Map<Object, RunningAverage>> averageDiffs;
+ private final Map<Object, RunningAverage> averageItemPref;
+ private final ReadWriteLock buildAverageDiffsLock;
+ private final ReentrantLock refreshLock;
+
  /**
   * <p>Creates a new {@link MemoryDiffStorage}.</p>
   *
   * <p>See {@link org.apache.mahout.cf.taste.impl.recommender.slopeone.SlopeOneRecommender} for the
   * meaning of <code>stdDevWeighted</code>. If <code>compactAverages</code>
   * is set, this uses alternate data structures ({@link CompactRunningAverage} versus
   * {@link FullRunningAverage}) that use almost 50% less memory but store item-item
   * averages less accurately.</p>
   *
   * <p><code>maxEntries</code> controls the maximum number of item-item average
   * preference differences that will be tracked internally. After the limit is reached,
   * if a new item-item pair is observed in the data it will be ignored. This is recommended for
   * large datasets. The first <code>maxEntries</code>
   * item-item pairs observed in the data are tracked. Assuming that item ratings are reasonably
   * distributed among users, this should only ignore item-item pairs that are very infrequently
   * co-rated by a user. The intuition is that data on these infrequently co-rated item-item pairs
   * is less reliable and should be the first that is ignored. This parameter can be used to limit
   * the memory requirements of {@link SlopeOneRecommender}, which otherwise grow as
   * the square of the number of items that exist in the {@link DataModel}. Memory requirements
   * can reach gigabytes with only about 10000 items, so this may be necessary on larger
   * datasets.</p>
   *
   * @param dataModel source of the users, items, and preferences to build diffs from
   * @param stdDevWeighted see {@link org.apache.mahout.cf.taste.impl.recommender.slopeone.SlopeOneRecommender}
   * @param compactAverages if <code>true</code>,
   *  use {@link CompactRunningAverage} instead of {@link FullRunningAverage} internally
   * @param maxEntries maximum number of item-item average preference differences to track internally
   * @throws IllegalArgumentException if <code>maxEntries</code> is not positive or <code>dataModel</code>
   *  is null
   * @throws TasteException if an error occurs while building the initial diffs
   */
  public MemoryDiffStorage(DataModel dataModel,
                           boolean stdDevWeighted,
                           boolean compactAverages,
                           long maxEntries) throws TasteException {
    if (dataModel == null) {
      throw new IllegalArgumentException("dataModel is null");
    }
    if (maxEntries <= 0L) {
      throw new IllegalArgumentException("maxEntries must be positive");
    }
    this.dataModel = dataModel;
    this.stdDevWeighted = stdDevWeighted;
    this.compactAverages = compactAverages;
    this.maxEntries = maxEntries;
    // Initial FastMap capacities as originally tuned
    this.averageDiffs = new FastMap<Object, Map<Object, RunningAverage>>(1003);
    this.averageItemPref = new FastMap<Object, RunningAverage>(101);
    this.buildAverageDiffsLock = new ReentrantReadWriteLock();
    this.refreshLock = new ReentrantLock();
    buildAverageDiffs();
  }
+
+ public RunningAverage getDiff(Object itemID1, Object itemID2) {
+ Map<Object, RunningAverage> level2Map = averageDiffs.get(itemID1);
+ RunningAverage average = null;
+ if (level2Map != null) {
+ average = level2Map.get(itemID2);
+ }
+ boolean inverted = false;
+ if (average == null) {
+ level2Map = averageDiffs.get(itemID2);
+ if (level2Map != null) {
+ average = level2Map.get(itemID1);
+ inverted = true;
+ }
+ }
+ if (inverted) {
+ if (average == null) {
+ return null;
+ }
+ return stdDevWeighted ?
+ new InvertedRunningAverageAndStdDev((RunningAverageAndStdDev) average) :
+ new InvertedRunningAverage(average);
+ } else {
+ return average;
+ }
+ }
+
+ public RunningAverage[] getDiffs(Object userID, Object itemID, Preference[] prefs) {
+ try {
+ buildAverageDiffsLock.readLock().lock();
+ int size = prefs.length;
+ RunningAverage[] result = new RunningAverage[size];
+ for (int i = 0; i < size; i++) {
+ result[i] = getDiff(prefs[i].getItem().getID(), itemID);
+ }
+ return result;
+ } finally {
+ buildAverageDiffsLock.readLock().unlock();
+ }
+ }
+
  /**
   * Returns the running average of all preference values expressed for the item,
   * or <code>null</code> if the item was not seen when diffs were built.
   */
  public RunningAverage getAverageItemPref(Object itemID) {
    return averageItemPref.get(itemID);
  }
+
  /**
   * Applies a change (or removal) of a single preference value for <code>itemID</code>
   * to every stored item-item diff involving that item, and to the item's own
   * average preference.
   *
   * @param itemID item whose preference changed
   * @param prefDelta the change in the preference value (or the removed value)
   * @param remove if <code>true</code>, a preference was removed rather than changed
   * @throws UnsupportedOperationException when updating (not removing) while
   *  <code>stdDevWeighted</code> is set, since a standard deviation cannot be
   *  maintained from a delta alone
   */
  public void updateItemPref(Object itemID, double prefDelta, boolean remove) {
    if (!remove && stdDevWeighted) {
      throw new UnsupportedOperationException("Can't update only when stdDevWeighted is set");
    }
    try {
      // NOTE(review): this mutates RunningAverages while holding only the READ lock --
      // presumably to allow updates concurrent with lookups while excluding only a
      // full rebuild. Confirm the RunningAverage implementations tolerate this.
      buildAverageDiffsLock.readLock().lock();
      for (Map.Entry<Object, Map<Object, RunningAverage>> entry : averageDiffs.entrySet()) {
        boolean matchesItemID1 = itemID.equals(entry.getKey());
        for (Map.Entry<Object, RunningAverage> entry2 : entry.getValue().entrySet()) {
          RunningAverage average = entry2.getValue();
          if (matchesItemID1) {
            // Diff under key (itemID, other) is (other - itemID), so a positive
            // change to this item's preference moves the diff down
            if (remove) {
              average.removeDatum(prefDelta);
            } else {
              average.changeDatum(-prefDelta);
            }
          } else if (itemID.equals(entry2.getKey())) {
            // Diff under key (other, itemID) is (itemID - other): opposite sign
            if (remove) {
              average.removeDatum(-prefDelta);
            } else {
              average.changeDatum(prefDelta);
            }
          }
        }
      }
      RunningAverage itemAverage = averageItemPref.get(itemID);
      if (itemAverage != null) {
        itemAverage.changeDatum(prefDelta);
      }
    } finally {
      buildAverageDiffsLock.readLock().unlock();
    }
  }
+
+ public Set<Item> getRecommendableItems(Object userID) throws TasteException {
+ User user = dataModel.getUser(userID);
+ Set<Item> result = new HashSet<Item>(dataModel.getNumItems());
+ for (Item item : dataModel.getItems()) {
+ // If not already preferred by the user, add it
+ if (user.getPreferenceFor(item.getID()) == null) {
+ result.add(item);
+ }
+ }
+ return result;
+ }
+
  /**
   * Rebuilds all item-item diffs and per-item averages from scratch under the write
   * lock. For each user, every ordered pair (A, B) of the user's preferences
   * contributes (B's value - A's value) to the diff stored under key (A, B); at most
   * maxEntries distinct pairs are tracked. Afterwards, near-zero diffs are pruned.
   *
   * @throws TasteException if an error occurs while accessing the data model
   */
  private void buildAverageDiffs() throws TasteException {
    log.info("Building average diffs...");
    try {
      // NOTE(review): convention is to call lock() BEFORE try; as written, if lock()
      // itself threw, the finally block would unlock an unheld lock
      buildAverageDiffsLock.writeLock().lock();
      // Counts distinct item-item pairs created, to enforce maxEntries
      long averageCount = 0L;
      for (User user : dataModel.getUsers()) {
        if (log.isLoggable(Level.FINE)) {
          log.fine("Processing prefs for user " + user + "...");
        }
        // Save off prefs for the life of this loop iteration
        Preference[] userPreferences = user.getPreferencesAsArray();
        int length = userPreferences.length;
        for (int i = 0; i < length; i++) {
          Preference prefA = userPreferences[i];
          double prefAValue = prefA.getValue();
          Object itemIDA = prefA.getItem().getID();
          Map<Object, RunningAverage> aMap = averageDiffs.get(itemIDA);
          if (aMap == null) {
            aMap = new HashMap<Object, RunningAverage>();
            averageDiffs.put(itemIDA, aMap);
          }
          // Only pairs j > i are stored; getDiff() handles the reverse orientation
          for (int j = i + 1; j < length; j++) {
            // This is a performance-critical block
            Preference prefB = userPreferences[j];
            Object itemIDB = prefB.getItem().getID();
            RunningAverage average = aMap.get(itemIDB);
            if (average == null && averageCount < maxEntries) {
              average = buildRunningAverage();
              aMap.put(itemIDB, average);
              averageCount++;
            }
            // average stays null once maxEntries is hit: the new pair is ignored
            if (average != null) {
              average.addDatum(prefB.getValue() - prefAValue);
            }

          }
          // Also accumulate the item's overall average preference
          RunningAverage itemAverage = averageItemPref.get(itemIDA);
          if (itemAverage == null) {
            itemAverage = buildRunningAverage();
            averageItemPref.put(itemIDA, itemAverage);
          }
          itemAverage.addDatum(prefAValue);
        }
      }

      // Go back and prune inconsequential diffs. "Inconsequential" means, here, an average
      // so small (< 1 / numItems^3) that it contributes very little to computations
      double numItems = (double) dataModel.getNumItems();
      double threshold = 1.0 / numItems / numItems / numItems;
      for (Iterator<Map<Object, RunningAverage>> it1 = averageDiffs.values().iterator(); it1.hasNext();) {
        Map<Object, RunningAverage> map = it1.next();
        for (Iterator<RunningAverage> it2 = map.values().iterator(); it2.hasNext();) {
          RunningAverage average = it2.next();
          if (Math.abs(average.getAverage()) < threshold) {
            it2.remove();
          }
        }
        // Drop the outer entry entirely once all its diffs are pruned
        if (map.isEmpty()) {
          it1.remove();
        }
      }

    } finally {
      buildAverageDiffsLock.writeLock().unlock();
    }
  }
+
+ private RunningAverage buildRunningAverage() {
+ if (stdDevWeighted) {
+ return compactAverages ? new CompactRunningAverageAndStdDev() : new FullRunningAverageAndStdDev();
+ } else {
+ return compactAverages ? new CompactRunningAverage() : new FullRunningAverage();
+ }
+ }
+
+ public void refresh() {
+ if (refreshLock.isLocked()) {
+ return;
+ }
+ try {
+ refreshLock.lock();
+ dataModel.refresh();
+ try {
+ buildAverageDiffs();
+ } catch (TasteException te) {
+ log.log(Level.WARNING, "Unexpected exception while refreshing", te);
+ }
+ } finally {
+ refreshLock.unlock();
+ }
+ }
+
  /**
   * @return a fixed, minimal description of this storage
   */
  @Override
  public String toString() {
    return "MemoryDiffStorage";
  }
+
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/SlopeOneRecommender.java Fri May 9 14:35:12 2008
@@ -0,0 +1,221 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.slopeone;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
+import org.apache.mahout.cf.taste.impl.recommender.TopItems;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Item;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.User;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Rescorer;
+import org.apache.mahout.cf.taste.recommender.slopeone.DiffStorage;
+
+import java.util.List;
+import java.util.NoSuchElementException;
+import java.util.Set;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * <p>A basic "slope one" recommender. (See an <a href="http://www.daniel-lemire.com/fr/abstracts/SDM2005.html">
+ * excellent summary here</a> for example.) This {@link org.apache.mahout.cf.taste.recommender.Recommender} is especially
+ * suitable when user preferences are updating frequently as it can incorporate this information without
+ * expensive recomputation.</p>
+ *
+ * <p>This implementation can also be used as a "weighted slope one" recommender.</p>
+ */
+public final class SlopeOneRecommender extends AbstractRecommender {
+
+  private static final Logger log = Logger.getLogger(SlopeOneRecommender.class.getName());
+
+  // weighted: weight each item-item diff by its observation count;
+  // stdDevWeighted: additionally down-weight diffs with high standard deviation
+  private final boolean weighted;
+  private final boolean stdDevWeighted;
+  // Source of the precomputed item-item preference-difference averages
+  private final DiffStorage diffStorage;
+
+  /**
+   * <p>Creates a default (weighted) {@link SlopeOneRecommender} based on the given {@link DataModel}.</p>
+   *
+   * @param dataModel data model
+   */
+  public SlopeOneRecommender(DataModel dataModel) throws TasteException {
+    this(dataModel, true, true, new MemoryDiffStorage(dataModel, true, false, Long.MAX_VALUE));
+  }
+
+  /**
+   * <p>Creates a {@link SlopeOneRecommender} based on the given {@link DataModel}.</p>
+   *
+   * <p>If <code>weighted</code> is set, acts as a weighted slope one recommender.
+   * This implementation also includes an experimental "standard deviation" weighting which weights
+   * item-item ratings diffs with lower standard deviation more highly, on the theory that they are more
+   * reliable.</p>
+   *
+   * @param dataModel data model from which preference data is read
+   * @param weighted if <code>true</code>, acts as a weighted slope one recommender
+   * @param stdDevWeighted use optional standard deviation weighting of diffs
+   * @param diffStorage storage in which item-item preference diffs are maintained
+   * @throws IllegalArgumentException if <code>diffStorage</code> is null, or stdDevWeighted is set
+   * when weighted is not set
+   */
+  public SlopeOneRecommender(DataModel dataModel,
+                             boolean weighted,
+                             boolean stdDevWeighted,
+                             DiffStorage diffStorage) {
+    super(dataModel);
+    // Std-dev weighting only makes sense on top of count weighting
+    if (stdDevWeighted && !weighted) {
+      throw new IllegalArgumentException("weighted required when stdDevWeighted is set");
+    }
+    if (diffStorage == null) {
+      throw new IllegalArgumentException("diffStorage is null");
+    }
+    this.weighted = weighted;
+    this.stdDevWeighted = stdDevWeighted;
+    this.diffStorage = diffStorage;
+  }
+
+  public List<RecommendedItem> recommend(Object userID, int howMany, Rescorer<Item> rescorer)
+          throws TasteException {
+    if (userID == null) {
+      throw new IllegalArgumentException("userID is null");
+    }
+    if (howMany < 1) {
+      throw new IllegalArgumentException("howMany must be at least 1");
+    }
+    // Note: a non-null Rescorer is required by this implementation
+    if (rescorer == null) {
+      throw new IllegalArgumentException("rescorer is null");
+    }
+    if (log.isLoggable(Level.FINE)) {
+      log.fine("Recommending items for user ID '" + userID + '\'');
+    }
+
+    // Candidates come from the diff storage: items connected by a diff to something
+    // the user has rated
+    User theUser = getDataModel().getUser(userID);
+    Set<Item> allItems = diffStorage.getRecommendableItems(userID);
+
+    TopItems.Estimator<Item> estimator = new Estimator(theUser);
+
+    List<RecommendedItem> topItems = TopItems.getTopItems(howMany, allItems, rescorer, estimator);
+
+    if (log.isLoggable(Level.FINE)) {
+      log.fine("Recommendations are: " + topItems);
+    }
+    return topItems;
+  }
+
+  public double estimatePreference(Object userID, Object itemID) throws TasteException {
+    // If the user has expressed an actual preference for this item, return it directly
+    // instead of estimating
+    DataModel model = getDataModel();
+    User theUser = model.getUser(userID);
+    Preference actualPref = theUser.getPreferenceFor(itemID);
+    if (actualPref != null) {
+      return actualPref.getValue();
+    }
+    return doEstimatePreference(theUser, itemID);
+  }
+
+  private double doEstimatePreference(User theUser, Object itemID) throws TasteException {
+    // Slope-one estimate: for each item the user has rated, the diff average predicts
+    // target = pref + avgDiff. Combine all such predictions, optionally weighting each
+    // by its observation count and (experimentally) by 1 / (1 + stddev).
+    double count = 0.0;
+    double totalPreference = 0.0;
+    Preference[] prefs = theUser.getPreferencesAsArray();
+    RunningAverage[] averages = diffStorage.getDiffs(theUser.getID(), itemID, prefs);
+    for (int i = 0; i < prefs.length; i++) {
+      RunningAverage averageDiff = averages[i];
+      if (averageDiff != null) {
+        Preference pref = prefs[i];
+        double averageDiffValue = averageDiff.getAverage();
+        if (weighted) {
+          double weight = (double) averageDiff.getCount();
+          if (stdDevWeighted) {
+            double stdev = ((RunningAverageAndStdDev) averageDiff).getStandardDeviation();
+            if (!Double.isNaN(stdev)) {
+              weight /= 1.0 + stdev;
+            }
+            // If stdev is NaN, then it is because count is 1. Because we're weighting by count,
+            // the weight is already relatively low. We effectively assume stdev is 0.0 here and
+            // that is reasonable enough. Otherwise, dividing by NaN would yield a weight of NaN
+            // and disqualify this pref entirely
+            // (Thanks Daemmon)
+          }
+          totalPreference += weight * (pref.getValue() + averageDiffValue);
+          count += weight;
+        } else {
+          totalPreference += pref.getValue() + averageDiffValue;
+          count += 1.0;
+        }
+      }
+    }
+    if (count <= 0.0) {
+      // No diffs relate this item to anything the user rated; fall back to the
+      // item's overall average preference, or NaN if even that is unavailable
+      RunningAverage itemAverage = diffStorage.getAverageItemPref(itemID);
+      return itemAverage == null ? Double.NaN : itemAverage.getAverage();
+    } else {
+      return totalPreference / count;
+    }
+  }
+
+  @Override
+  public void setPreference(Object userID, Object itemID, double value) throws TasteException {
+    // Compute the change relative to any existing preference so the diff storage
+    // can be updated incrementally rather than rebuilt
+    DataModel dataModel = getDataModel();
+    double prefDelta;
+    try {
+      User theUser = dataModel.getUser(userID);
+      Preference oldPref = theUser.getPreferenceFor(itemID);
+      prefDelta = oldPref == null ? value : value - oldPref.getValue();
+    } catch (NoSuchElementException nsee) {
+      // Unknown user: the whole value is new
+      prefDelta = value;
+    }
+    super.setPreference(userID, itemID, value);
+    diffStorage.updateItemPref(itemID, prefDelta, false);
+  }
+
+  @Override
+  public void removePreference(Object userID, Object itemID) throws TasteException {
+    // Capture the old value before removal so the diff storage can back it out
+    DataModel dataModel = getDataModel();
+    User theUser = dataModel.getUser(userID);
+    Preference oldPref = theUser.getPreferenceFor(itemID);
+    super.removePreference(userID, itemID);
+    if (oldPref != null) {
+      diffStorage.updateItemPref(itemID, oldPref.getValue(), true);
+    }
+  }
+
+  @Override
+  public void refresh() {
+    diffStorage.refresh();
+  }
+
+  @Override
+  public String toString() {
+    return "SlopeOneRecommender[weighted:" + weighted + ", stdDevWeighted:" + stdDevWeighted +
+           ", diffStorage:" + diffStorage + ']';
+  }
+
+  // Adapts doEstimatePreference() for use with TopItems, fixing the user
+  private final class Estimator implements TopItems.Estimator<Item> {
+
+    private final User theUser;
+
+    private Estimator(User theUser) {
+      this.theUser = theUser;
+    }
+
+    public double estimate(Item item) throws TasteException {
+      return doEstimatePreference(theUser, item.getID());
+    }
+  }
+
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java Fri May 9 14:35:12 2008
@@ -0,0 +1,363 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.slopeone.jdbc;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.IOUtils;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.Item;
+import org.apache.mahout.cf.taste.model.JDBCDataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.recommender.slopeone.DiffStorage;
+
+import javax.sql.DataSource;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.locks.ReentrantLock;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * <p>A {@link DiffStorage} which stores diffs in a database. Database-specific implementations subclass
+ * this abstract class. Note that this implementation has a fairly particular dependence on the
+ * {@link org.apache.mahout.cf.taste.model.DataModel} used; it needs a {@link JDBCDataModel} attached to the same
+ * database since its efficient operation depends on accessing preference data in the database directly.</p>
+ */
+public abstract class AbstractJDBCDiffStorage implements DiffStorage {
+
+  private static final Logger log = Logger.getLogger(AbstractJDBCDiffStorage.class.getName());
+
+  // Default names for the table and columns holding precomputed item-item diffs
+  public static final String DEFAULT_DIFF_TABLE = "taste_slopeone_diffs";
+  public static final String DEFAULT_ITEM_A_COLUMN = "item_id_a";
+  public static final String DEFAULT_ITEM_B_COLUMN = "item_id_b";
+  public static final String DEFAULT_COUNT_COLUMN = "count";
+  public static final String DEFAULT_AVERAGE_DIFF_COLUMN = "average_diff";
+
+  private final JDBCDataModel dataModel;
+  private final DataSource dataSource;
+  // SQL statements are supplied by the concrete subclass; see MySQLJDBCDiffStorage
+  private final String getDiffSQL;
+  private final String getDiffsSQL;
+  private final String getAverageItemPrefSQL;
+  private final String[] updateDiffSQLs;
+  private final String[] removeDiffSQLs;
+  private final String getRecommendableItemsSQL;
+  private final String deleteDiffsSQL;
+  private final String createDiffsSQL;
+  private final String diffsExistSQL;
+  private final int minDiffCount;
+  private final ReentrantLock refreshLock;
+
+  /**
+   * @throws TasteException if no diffs exist yet and computing them fails
+   * @throws IllegalArgumentException if dataModel is null or minDiffCount is negative
+   */
+  protected AbstractJDBCDiffStorage(JDBCDataModel dataModel,
+                                    String getDiffSQL,
+                                    String getDiffsSQL,
+                                    String getAverageItemPrefSQL,
+                                    String[] updateDiffSQLs,
+                                    String[] removeDiffSQLs,
+                                    String getRecommendableItemsSQL,
+                                    String deleteDiffsSQL,
+                                    String createDiffsSQL,
+                                    String diffsExistSQL,
+                                    int minDiffCount) throws TasteException {
+    if (dataModel == null) {
+      throw new IllegalArgumentException("dataModel is null");
+    }
+    if (minDiffCount < 0) {
+      // Message matches the check: negative values are rejected, zero is allowed
+      throw new IllegalArgumentException("minDiffCount is negative");
+    }
+    this.dataModel = dataModel;
+    this.dataSource = dataModel.getDataSource();
+    this.getDiffSQL = getDiffSQL;
+    this.getDiffsSQL = getDiffsSQL;
+    this.getAverageItemPrefSQL = getAverageItemPrefSQL;
+    this.updateDiffSQLs = updateDiffSQLs;
+    this.removeDiffSQLs = removeDiffSQLs;
+    this.getRecommendableItemsSQL = getRecommendableItemsSQL;
+    this.deleteDiffsSQL = deleteDiffsSQL;
+    this.createDiffsSQL = createDiffsSQL;
+    this.diffsExistSQL = diffsExistSQL;
+    this.minDiffCount = minDiffCount;
+    this.refreshLock = new ReentrantLock();
+    if (isDiffsExist()) {
+      log.info("Diffs already exist in database; using them instead of recomputing");
+    } else {
+      log.info("No diffs exist in database; recomputing...");
+      buildAverageDiffs();
+    }
+  }
+
+  public RunningAverage getDiff(Object itemID1, Object itemID2) throws TasteException {
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getDiffSQL);
+      // Four parameters: the SQL checks both (itemID1, itemID2) orderings
+      stmt.setObject(1, itemID1);
+      stmt.setObject(2, itemID2);
+      stmt.setObject(3, itemID2);
+      stmt.setObject(4, itemID1);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL query: " + getDiffSQL);
+      }
+      rs = stmt.executeQuery();
+      if (rs.next()) {
+        return new FixedRunningAverage(rs.getInt(1), rs.getDouble(2));
+      } else {
+        return null;
+      }
+    } catch (SQLException sqle) {
+      log.log(Level.WARNING, "Exception while retrieving diff", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(rs, stmt, conn);
+    }
+  }
+
+  public RunningAverage[] getDiffs(Object userID, Object itemID, Preference[] prefs)
+          throws TasteException {
+    int size = prefs.length;
+    RunningAverage[] result = new RunningAverage[size];
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getDiffsSQL);
+      stmt.setObject(1, itemID);
+      stmt.setObject(2, userID);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL query: " + getDiffsSQL);
+      }
+      rs = stmt.executeQuery();
+      // We should have up to one result for each Preference in prefs
+      // They are both ordered by item. Step through and create a RunningAverage[]
+      // with nulls for Preferences that have no corresponding result row
+      // NOTE(review): matching relies on item IDs comparing equal to the String form
+      // returned by rs.getString(3), and on both orderings agreeing -- confirm for
+      // non-String item IDs
+      int i = 0;
+      while (rs.next()) {
+        String nextResultItemID = rs.getString(3);
+        while (!prefs[i].getItem().getID().equals(nextResultItemID)) {
+          i++;
+          // result[i] is null for these values of i
+        }
+        result[i] = new FixedRunningAverage(rs.getInt(1), rs.getDouble(2));
+        i++;
+      }
+    } catch (SQLException sqle) {
+      log.log(Level.WARNING, "Exception while retrieving diff", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(rs, stmt, conn);
+    }
+    return result;
+  }
+
+  public RunningAverage getAverageItemPref(Object itemID) throws TasteException {
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getAverageItemPrefSQL);
+      stmt.setObject(1, itemID);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL query: " + getAverageItemPrefSQL);
+      }
+      rs = stmt.executeQuery();
+      if (rs.next()) {
+        int count = rs.getInt(1);
+        if (count > 0) {
+          return new FixedRunningAverage(count, rs.getDouble(2));
+        }
+      }
+      return null;
+    } catch (SQLException sqle) {
+      log.log(Level.WARNING, "Exception while retrieving average item pref", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(rs, stmt, conn);
+    }
+  }
+
+  public void updateItemPref(Object itemID, double prefDelta, boolean remove)
+          throws TasteException {
+    // Two partial updates: one for rows where the item appears as item A, one where it
+    // appears as item B. Each statement is closed inside doPartialUpdate(); the previous
+    // version reassigned a single local variable and leaked the first PreparedStatement.
+    String[] sqls = remove ? removeDiffSQLs : updateDiffSQLs;
+    Connection conn = null;
+    try {
+      conn = dataSource.getConnection();
+      doPartialUpdate(sqls[0], itemID, prefDelta, conn);
+      doPartialUpdate(sqls[1], itemID, prefDelta, conn);
+    } catch (SQLException sqle) {
+      log.log(Level.WARNING, "Exception while updating item diff", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(null, null, conn);
+    }
+  }
+
+  // Executes one of the two halves of an item-pref update; closes its statement
+  private static void doPartialUpdate(String sql,
+                                      Object itemID,
+                                      double prefDelta,
+                                      Connection conn) throws SQLException {
+    PreparedStatement stmt = conn.prepareStatement(sql);
+    try {
+      stmt.setDouble(1, prefDelta);
+      stmt.setObject(2, itemID);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL update: " + sql);
+      }
+      stmt.executeUpdate();
+    } finally {
+      IOUtils.safeClose(null, stmt, null);
+    }
+  }
+
+  public Set<Item> getRecommendableItems(Object userID) throws TasteException {
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(getRecommendableItemsSQL);
+      stmt.setObject(1, userID);
+      stmt.setObject(2, userID);
+      stmt.setObject(3, userID);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL query: " + getRecommendableItemsSQL);
+      }
+      rs = stmt.executeQuery();
+      Set<Item> items = new HashSet<Item>();
+      while (rs.next()) {
+        items.add(dataModel.getItem(rs.getObject(1), true));
+      }
+      return items;
+    } catch (SQLException sqle) {
+      log.log(Level.WARNING, "Exception while retrieving recommendable items", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(rs, stmt, conn);
+    }
+  }
+
+  // Drops all stored diffs, then recomputes them with the subclass-supplied SQL
+  private void buildAverageDiffs() throws TasteException {
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(deleteDiffsSQL);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL update: " + deleteDiffsSQL);
+      }
+      stmt.executeUpdate();
+    } catch (SQLException sqle) {
+      log.log(Level.WARNING, "Exception while deleting diffs", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(null, stmt, conn);
+    }
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(createDiffsSQL);
+      // Pairs co-rated by fewer than minDiffCount users are excluded
+      stmt.setInt(1, minDiffCount);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL update: " + createDiffsSQL);
+      }
+      stmt.executeUpdate();
+    } catch (SQLException sqle) {
+      log.log(Level.WARNING, "Exception while creating diffs", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(null, stmt, conn);
+    }
+  }
+
+  private boolean isDiffsExist() throws TasteException {
+    Connection conn = null;
+    PreparedStatement stmt = null;
+    ResultSet rs = null;
+    try {
+      conn = dataSource.getConnection();
+      stmt = conn.prepareStatement(diffsExistSQL);
+      if (log.isLoggable(Level.FINE)) {
+        log.fine("Executing SQL query: " + diffsExistSQL);
+      }
+      rs = stmt.executeQuery();
+      rs.next();
+      return rs.getInt(1) > 0;
+    } catch (SQLException sqle) {
+      // Log message corrected: this method checks for diffs, it does not delete them
+      log.log(Level.WARNING, "Exception while checking whether diffs exist", sqle);
+      throw new TasteException(sqle);
+    } finally {
+      IOUtils.safeClose(rs, stmt, conn);
+    }
+  }
+
+  public void refresh() {
+    // tryLock() makes the skip-if-busy check atomic; the previous isLocked()-then-lock()
+    // sequence was a check-then-act race allowing two concurrent refreshes
+    if (!refreshLock.tryLock()) {
+      return;
+    }
+    try {
+      dataModel.refresh();
+      try {
+        buildAverageDiffs();
+      } catch (TasteException te) {
+        log.log(Level.WARNING, "Unexpected exception while refreshing", te);
+      }
+    } finally {
+      refreshLock.unlock();
+    }
+  }
+
+  // Immutable snapshot of a count/average pair read from the database; mutators are
+  // unsupported because the database, not this object, is the source of truth
+  private static class FixedRunningAverage implements RunningAverage {
+
+    private final int count;
+    private final double average;
+
+    private FixedRunningAverage(int count, double average) {
+      this.count = count;
+      this.average = average;
+    }
+
+    public void addDatum(double datum) {
+      throw new UnsupportedOperationException();
+    }
+
+    public void removeDatum(double datum) {
+      throw new UnsupportedOperationException();
+    }
+
+    public void changeDatum(double delta) {
+      throw new UnsupportedOperationException();
+    }
+
+    public int getCount() {
+      return count;
+    }
+
+    public double getAverage() {
+      return average;
+    }
+  }
+
+}
\ No newline at end of file
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java Fri May 9 14:35:12 2008
@@ -0,0 +1,154 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.slopeone.jdbc;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.model.jdbc.AbstractJDBCDataModel;
+import org.apache.mahout.cf.taste.impl.model.jdbc.MySQLJDBCDataModel;
+
+/**
+ * <p>MySQL-specific implementation. Should be used in conjunction with a
+ * {@link MySQLJDBCDataModel}. This implementation stores item-item diffs in a MySQL
+ * database and encapsulates some other slope-one-specific operations that are needed
+ * on the preference data in the database. It assumes the database has a schema like:</p>
+ *
+ * <table>
+ * <tr><th>item_id_a</th><th>item_id_b</th><th>average_diff</th><th>count</th></tr>
+ * <tr><td>123</td><td>234</td><td>0.5</td><td>5</td></tr>
+ * <tr><td>123</td><td>789</td><td>-1.33</td><td>3</td></tr>
+ * <tr><td>234</td><td>789</td><td>2.1</td><td>1</td></tr>
+ * </table>
+ *
+ * <p><code>item_id_a</code> and <code>item_id_b</code> must have type compatible with
+ * the Java <code>String</code> type. <code>average_diff</code> must be compatible with
+ * <code>double</code> and <code>count</code> must be compatible with <code>int</code>.</p>
+ *
+ * <p>The following command sets up a suitable table in MySQL:</p>
+ *
+ * <pre>
+ * CREATE TABLE taste_slopeone_diffs (
+ * item_id_a VARCHAR(10) NOT NULL,
+ * item_id_b VARCHAR(10) NOT NULL,
+ * average_diff FLOAT NOT NULL,
+ * count INT NOT NULL,
+ * PRIMARY KEY (item_id_a, item_id_b),
+ * INDEX (item_id_a),
+ * INDEX (item_id_b)
+ * )
+ * </pre>
+ */
+public final class MySQLJDBCDiffStorage extends AbstractJDBCDiffStorage {
+
+  // Item pairs co-rated by fewer than this many users are excluded when diffs are built
+  // (enforced via the HAVING clause in createDiffsSQL below)
+  private static final int DEFAULT_MIN_DIFF_COUNT = 2;
+
+  /**
+   * <p>Creates a {@link MySQLJDBCDiffStorage} using the default table and column names
+   * and a minimum diff count of 2.</p>
+   */
+  public MySQLJDBCDiffStorage(MySQLJDBCDataModel dataModel) throws TasteException {
+    this(dataModel,
+         AbstractJDBCDataModel.DEFAULT_PREFERENCE_TABLE,
+         AbstractJDBCDataModel.DEFAULT_USER_ID_COLUMN,
+         AbstractJDBCDataModel.DEFAULT_ITEM_ID_COLUMN,
+         AbstractJDBCDataModel.DEFAULT_PREFERENCE_COLUMN,
+         DEFAULT_DIFF_TABLE,
+         DEFAULT_ITEM_A_COLUMN,
+         DEFAULT_ITEM_B_COLUMN,
+         DEFAULT_COUNT_COLUMN,
+         DEFAULT_AVERAGE_DIFF_COLUMN,
+         DEFAULT_MIN_DIFF_COUNT);
+  }
+
+  /**
+   * <p>Creates a {@link MySQLJDBCDiffStorage} with fully customized table and column names,
+   * building the MySQL-dialect SQL strings required by {@link AbstractJDBCDiffStorage}.</p>
+   */
+  public MySQLJDBCDiffStorage(MySQLJDBCDataModel dataModel,
+                              String preferenceTable,
+                              String userIDColumn,
+                              String itemIDColumn,
+                              String preferenceColumn,
+                              String diffsTable,
+                              String itemIDAColumn,
+                              String itemIDBColumn,
+                              String countColumn,
+                              String avgColumn,
+                              int minDiffCount) throws TasteException {
+    super(dataModel,
+          // getDiffSQL: UNION covers both (A,B) and (B,A) orderings of the pair
+          "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable +
+          " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=? UNION " +
+          "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable +
+          " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=?",
+          // getDiffsSQL
+          "SELECT " + countColumn + ", " + avgColumn + ", " + itemIDAColumn + " FROM " + diffsTable + ", " +
+          preferenceTable + " WHERE " + itemIDBColumn + "=? AND " + itemIDAColumn + " = " + itemIDColumn +
+          " AND " + userIDColumn + "=? ORDER BY " + itemIDAColumn,
+          // getAverageItemPrefSQL
+          "SELECT COUNT(1), AVG(" + preferenceColumn + ") FROM " + preferenceTable +
+          " WHERE " + itemIDColumn + "=?",
+          // updateDiffSQLs
+          new String[]{
+              "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " - (? / " + countColumn +
+              ") WHERE " + itemIDAColumn + "=?",
+              "UPDATE " + diffsTable + " SET " + avgColumn + " = " + avgColumn + " + (? / " + countColumn +
+              ") WHERE " + itemIDBColumn + "=?"
+          },
+          // removeDiffSQL
+          // NOTE(review): these UPDATEs appear to rely on MySQL evaluating SET assignments
+          // left-to-right, so the average expression sees the already-decremented count --
+          // confirm against MySQL's documented UPDATE semantics
+          new String[]{
+              "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn + "-1, " +
+              avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) / CAST(" + countColumn +
+              " AS DECIMAL)) + ? / CAST(" + countColumn + " AS DECIMAL) WHERE " + itemIDAColumn + "=?",
+              "UPDATE " + diffsTable + " SET " + countColumn + " = " + countColumn + "-1, " +
+              avgColumn + " = " + avgColumn + " * ((" + countColumn + " + 1) / CAST(" + countColumn +
+              " AS DECIMAL)) - ? / CAST(" + countColumn + " AS DECIMAL) WHERE " + itemIDBColumn + "=?"
+          },
+          // getRecommendableItemsSQL: items linked by a diff to something the user rated,
+          // minus items the user has already rated
+          "SELECT id FROM " +
+          "(SELECT " + itemIDAColumn + " AS id FROM " + diffsTable + ", " + preferenceTable +
+          " WHERE " + itemIDBColumn + " = item_id AND " + userIDColumn + "=? UNION DISTINCT" +
+          " SELECT " + itemIDBColumn + " AS id FROM " + diffsTable + ", " + preferenceTable +
+          " WHERE " + itemIDAColumn + " = item_id AND " + userIDColumn +
+          "=?) possible_item_ids WHERE id NOT IN (SELECT " + itemIDColumn + " FROM " + preferenceTable +
+          " WHERE " + userIDColumn + "=?)",
+          // deleteDiffsSQL
+          "DELETE FROM " + diffsTable,
+          // createDiffsSQL: self-join pairs each user's ratings; ordering A < B keeps one
+          // row per pair; HAVING applies the minimum co-rating count
+          "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn +
+          ", " + countColumn + ") SELECT prefsA." + itemIDColumn + ", prefsB." + itemIDColumn + ',' +
+          " AVG(prefsB." + preferenceColumn + " - prefsA." + preferenceColumn + ")," +
+          " COUNT(1) AS count FROM " + preferenceTable + " prefsA, " + preferenceTable + " prefsB WHERE prefsA." +
+          userIDColumn + " = prefsB." + userIDColumn + " AND prefsA." + itemIDColumn + " < prefsB." +
+          itemIDColumn + ' ' + " GROUP BY prefsA." + itemIDColumn +
+          ", prefsB." + itemIDColumn + " HAVING count >=?",
+          // diffsExistSQL
+          "SELECT COUNT(1) FROM " + diffsTable,
+          minDiffCount);
+  }
+
+  // Ad-hoc benchmark retained for reference; not part of the public API
+  /*
+  public static void main(final String... args) throws Exception {
+    Logger.getLogger("org.apache.mahout.cf.taste").setLevel(Level.FINE);
+    final MysqlDataSource dataSource = new MysqlDataSource();
+    dataSource.setUser("mysql");
+    dataSource.setDatabaseName("test");
+    dataSource.setServerName("localhost");
+    final DataSource pooledDataSource = new ConnectionPoolDataSource(dataSource);
+    final MySQLJDBCDataModel model = new MySQLJDBCDataModel(pooledDataSource);
+    final MySQLJDBCDiffStorage diffStorage = new MySQLJDBCDiffStorage(model);
+    final Recommender slopeOne = new SlopeOneRecommender(model, true, false, diffStorage);
+    long start = System.currentTimeMillis();
+    System.out.println(slopeOne.recommend(args[0], 20));
+    long end = System.currentTimeMillis();
+    System.out.println(end - start);
+  }
+  */
+
+}
\ No newline at end of file
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/CaseAmplification.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/CaseAmplification.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/CaseAmplification.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/CaseAmplification.java Fri May 9 14:35:12 2008
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.transforms;
+
+import org.apache.mahout.cf.taste.transforms.CorrelationTransform;
+
+/**
+ * <p>Applies "case amplification" to correlations. This essentially makes big values bigger
+ * and small values smaller by raising each score to a power. It could however be used to achieve the
+ * opposite effect.</p>
+ */
+public final class CaseAmplification implements CorrelationTransform<Object> {
+
+  /** Exponent applied to every correlation value. */
+  private final double factor;
+
+  /**
+   * <p>Creates a {@link CaseAmplification} transformation based on the given factor.</p>
+   *
+   * @param factor transformation factor
+   * @throws IllegalArgumentException if factor is 0.0 or {@link Double#NaN}
+   */
+  public CaseAmplification(double factor) {
+    if (factor == 0.0 || Double.isNaN(factor)) {
+      throw new IllegalArgumentException("factor is 0 or NaN");
+    }
+    this.factor = factor;
+  }
+
+  /**
+   * <p>Transforms one correlation value. Each value is transformed in isolation,
+   * so the "thing" parameters play no role here.</p>
+   *
+   * @param thing1 unused
+   * @param thing2 unused
+   * @param value correlation to transform
+   * @return <code>value<sup>factor</sup></code> if value is nonnegative;
+   *  <code>-value<sup>-factor</sup></code> otherwise
+   */
+  public double transformCorrelation(Object thing1, Object thing2, double value) {
+    // Raise the magnitude to the configured power, preserving the sign so that
+    // negative correlations are amplified symmetrically with positive ones
+    if (value < 0.0) {
+      return -Math.pow(-value, factor);
+    }
+    return Math.pow(value, factor);
+  }
+
+  public void refresh() {
+    // Stateless: there is nothing to refresh
+  }
+
+  @Override
+  public String toString() {
+    return "CaseAmplification[factor:" + factor + ']';
+  }
+
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/Counters.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/Counters.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/Counters.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/Counters.java Fri May 9 14:35:12 2008
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.transforms;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * <p>A lightweight map from keys to integer counts, avoiding repeated boxing by
+ * storing a mutable holder per key.</p>
+ */
+final class Counters<T> {
+
+  private final Map<T, MutableInteger> counts = new HashMap<T, MutableInteger>(1009);
+
+  /** Adds one to the count for the given key, starting from zero if unseen. */
+  void increment(T key) {
+    MutableInteger existing = counts.get(key);
+    if (existing != null) {
+      existing.value++;
+    } else {
+      MutableInteger first = new MutableInteger();
+      first.value = 1;
+      counts.put(key, first);
+    }
+  }
+
+  /** Returns the current count for the key, or 0 if the key was never incremented. */
+  int getCount(T key) {
+    MutableInteger counter = counts.get(key);
+    if (counter == null) {
+      return 0;
+    }
+    return counter.value;
+  }
+
+  /** Returns the number of distinct keys counted so far. */
+  int size() {
+    return counts.size();
+  }
+
+  /** Exposes the underlying entries for iteration; the holders remain live. */
+  Iterable<Map.Entry<T, MutableInteger>> getEntrySet() {
+    return counts.entrySet();
+  }
+
+  @Override
+  public String toString() {
+    return "Counters[" + counts + ']';
+  }
+
+  static final class MutableInteger {
+
+    // This is intentionally package-private in order to allow access from the containing Counters class
+    // without making the compiler generate a synthetic accessor
+    int value;
+
+    @Override
+    public String toString() {
+      return "MutableInteger[" + value + ']';
+    }
+  }
+
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/InverseUserFrequency.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/InverseUserFrequency.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/InverseUserFrequency.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/InverseUserFrequency.java Fri May 9 14:35:12 2008
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.transforms;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Item;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.User;
+import org.apache.mahout.cf.taste.transforms.PreferenceTransform2;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * <p>Implements an "inverse user frequency" transformation, which boosts preference values for items for which few
+ * users have expressed a preference, and reduces preference values for items for which many users have expressed
+ * a preference. The idea is that these "rare" {@link Item}s are more useful in deciding how similar two users'
+ * tastes are, and so should be emphasized in other calculations. This idea is mentioned in
+ * <a href="ftp://ftp.research.microsoft.com/pub/tr/tr-98-12.pdf">Empirical Analysis of Predictive Algorithms for
+ * Collaborative Filtering</a>.</p>
+ *
+ * <p>A scaling factor is computed for each {@link Item} by dividing the total number of users by the number of
+ * users expressing a preference for that item, and taking the log of that value. The log base of this calculation
+ * can be controlled in the constructor. Intuitively, the right value for the base is equal to the average
+ * number of users who express a preference for each item in your model. If each item has about 100 preferences
+ * on average, 100.0 is a good log base.</p>
+ */
+public final class InverseUserFrequency implements PreferenceTransform2 {
+
+  private static final Logger log = Logger.getLogger(InverseUserFrequency.class.getName());
+
+  private final DataModel dataModel;
+  private final double logBase;
+  // Holds an immutable item -> scaling-factor map; refresh() builds a new map and
+  // swaps it in atomically so readers never observe a partially-built map
+  private final AtomicReference<Map<Item, Double>> iufFactors;
+
+  /**
+   * <p>Creates a {@link InverseUserFrequency} transformation. Computations use the given log base.</p>
+   *
+   * @param dataModel {@link DataModel} from which to calculate user frequencies
+   * @param logBase calculation logarithm base
+   * @throws IllegalArgumentException if dataModel is <code>null</code> or logBase is {@link Double#NaN} or &lt;= 1.0
+   */
+  public InverseUserFrequency(DataModel dataModel, double logBase) {
+    if (dataModel == null) {
+      throw new IllegalArgumentException("dataModel is null");
+    }
+    if (Double.isNaN(logBase) || logBase <= 1.0) {
+      throw new IllegalArgumentException("logBase is NaN or <= 1.0");
+    }
+    this.dataModel = dataModel;
+    this.logBase = logBase;
+    this.iufFactors = new AtomicReference<Map<Item, Double>>(new HashMap<Item, Double>(1009));
+    refresh();
+  }
+
+  /**
+   * @return log base used in this object's calculations
+   */
+  public double getLogBase() {
+    return logBase;
+  }
+
+  /**
+   * <p>Returns the preference value scaled by the item's inverse-user-frequency factor,
+   * or the unscaled value if no factor has been computed for the item.</p>
+   */
+  public double getTransformedValue(Preference pref) {
+    Double factor = iufFactors.get().get(pref.getItem());
+    if (factor != null) {
+      return pref.getValue() * factor;
+    }
+    return pref.getValue();
+  }
+
+  /**
+   * <p>Recomputes all scaling factors from the {@link DataModel}. A {@link TasteException}
+   * during recomputation is logged and the previous factors remain in effect.</p>
+   */
+  public void refresh() {
+    try {
+      Counters<Item> itemPreferenceCounts = new Counters<Item>();
+      // Synchronized so concurrent refresh() calls don't interleave their counting passes
+      synchronized (this) {
+        int numUsers = 0;
+        for (User user : dataModel.getUsers()) {
+          Preference[] prefs = user.getPreferencesAsArray();
+          for (int i = 0; i < prefs.length; i++) {
+            itemPreferenceCounts.increment(prefs[i].getItem());
+          }
+          numUsers++;
+        }
+        Map<Item, Double> newIufFactors =
+            new HashMap<Item, Double>(1 + (4 * itemPreferenceCounts.size()) / 3, 0.75f);
+        // Change-of-base: log_b(x) = ln(x) / ln(b)
+        double logFactor = Math.log(logBase);
+        for (Map.Entry<Item, Counters.MutableInteger> entry : itemPreferenceCounts.getEntrySet()) {
+          newIufFactors.put(entry.getKey(),
+                            Math.log((double) numUsers / (double) entry.getValue().value) / logFactor);
+        }
+        iufFactors.set(Collections.unmodifiableMap(newIufFactors));
+      }
+    } catch (TasteException te) {
+      log.log(Level.WARNING, "Unable to refresh", te);
+    }
+  }
+
+  @Override
+  public String toString() {
+    return "InverseUserFrequency[logBase:" + logBase + ']';
+  }
+
+}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/ZScore.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/ZScore.java?rev=654943&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/ZScore.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/transforms/ZScore.java Fri May 9 14:35:12 2008
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.transforms;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.SoftCache;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.User;
+import org.apache.mahout.cf.taste.transforms.PreferenceTransform2;
+
+/**
+ * <p>Normalizes a {@link User}'s preference values to
+ * <a href="http://mathworld.wolfram.com/z-Score.html">"z-scores"</a>, adjusting for
+ * differences in the mean and variance of each user's ratings.</p>
+ *
+ * <p>For example, one user may rate every movie four or five stars while another uses the
+ * whole one-to-five range. After this transform both users' preferences have mean 0.0 and
+ * standard deviation 1.0, so the difference in rating scale is normalized away.</p>
+ */
+public final class ZScore implements PreferenceTransform2 {
+
+  // Lazily caches per-user mean/stdev; entries may be reclaimed under memory pressure
+  private final SoftCache<User, RunningAverageAndStdDev> meanAndStdevs;
+
+  public ZScore() {
+    this.meanAndStdevs = new SoftCache<User, RunningAverageAndStdDev>(new MeanStdevRetriever());
+    refresh();
+  }
+
+  /**
+   * <p>Returns the preference's z-score relative to its user's mean and standard deviation,
+   * or 0.0 when the user has at most one preference or zero variance.</p>
+   */
+  public double getTransformedValue(Preference pref) throws TasteException {
+    RunningAverageAndStdDev stats = meanAndStdevs.get(pref.getUser());
+    if (stats.getCount() > 1) {
+      double stdev = stats.getStandardDeviation();
+      // Positive-stdev check also screens out NaN, which would poison the division
+      if (stdev > 0.0) {
+        return (pref.getValue() - stats.getAverage()) / stdev;
+      }
+    }
+    return 0.0;
+  }
+
+  public void refresh() {
+    // do nothing
+  }
+
+  @Override
+  public String toString() {
+    return "ZScore";
+  }
+
+  /** Computes a user's running mean and standard deviation over all of their preferences. */
+  private static class MeanStdevRetriever implements SoftCache.Retriever<User, RunningAverageAndStdDev> {
+
+    public RunningAverageAndStdDev getValue(User user) throws TasteException {
+      RunningAverageAndStdDev average = new FullRunningAverageAndStdDev();
+      for (Preference pref : user.getPreferencesAsArray()) {
+        average.addDatum(pref.getValue());
+      }
+      return average;
+    }
+  }
+
+}