You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2008/08/08 05:52:03 UTC

svn commit: r683831 - in /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood: CachingUserNeighborhood.java NearestNUserNeighborhood.java ThresholdUserNeighborhood.java

Author: srowen
Date: Thu Aug  7 20:52:02 2008
New Revision: 683831

URL: http://svn.apache.org/viewvc?rev=683831&view=rev
Log:
Factor out caching in UserNeighborhood implementations into a wrapper

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java?rev=683831&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java Thu Aug  7 20:52:02 2008
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.model.User;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.Cache;
+import org.apache.mahout.cf.taste.impl.common.Retriever;
+
+import java.util.Collection;
+
+/**
+ * A caching wrapper around an underlying {@link UserNeighborhood} implementation. Lookups are served through a {@link Cache} whose {@link Retriever} delegates to the wrapped instance.
+ */
+public final class CachingUserNeighborhood implements UserNeighborhood {
+  // Wrapped delegate, plus the cache (keyed by user ID) that fronts it.
+  private final UserNeighborhood neighborhood;
+  private final Cache<Object, Collection<User>> neighborhoodCache;
+  // Rejects a null delegate with IllegalArgumentException, then builds the cache around a retriever for that delegate.
+  public CachingUserNeighborhood(UserNeighborhood neighborhood) {
+    if (neighborhood == null) {
+      throw new IllegalArgumentException("neighborhood is null");
+    }
+    this.neighborhood = neighborhood;
+    this.neighborhoodCache = new Cache<Object, Collection<User>>(new NeighborhoodRetriever(neighborhood));
+  }
+  // Looks up userID in the cache only; presumably Cache calls the retriever on a miss (Cache is not shown here -- confirm).
+  public Collection<User> getUserNeighborhood(Object userID) throws TasteException {
+    return neighborhoodCache.get(userID);
+  }
+  // Empties the cache first, then forwards the refresh to the wrapped neighborhood.
+  public void refresh() {
+    neighborhoodCache.clear();
+    neighborhood.refresh();
+  }
+  // Retriever given to the cache: get(key) simply asks the wrapped neighborhood for that key's neighborhood.
+  private static final class NeighborhoodRetriever implements Retriever<Object, Collection<User>> {
+    private final UserNeighborhood neighborhood;
+    private NeighborhoodRetriever(UserNeighborhood neighborhood) {
+      this.neighborhood = neighborhood;
+    }
+    public Collection<User> get(Object key) throws TasteException {
+      return neighborhood.getUserNeighborhood(key);
+    }
+  }
+
+}

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java?rev=683831&r1=683830&r2=683831&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java Thu Aug  7 20:52:02 2008
@@ -19,7 +19,6 @@
 
 import org.apache.mahout.cf.taste.common.TasteException;
 import org.apache.mahout.cf.taste.correlation.UserCorrelation;
-import org.apache.mahout.cf.taste.impl.common.Cache;
 import org.apache.mahout.cf.taste.model.DataModel;
 import org.apache.mahout.cf.taste.model.User;
 import org.slf4j.Logger;
@@ -40,7 +39,7 @@
 
   private static final Logger log = LoggerFactory.getLogger(NearestNUserNeighborhood.class);
 
-  private final Cache<Object, Collection<User>> cache;
+  private final int n;
 
   /**
    * @param n neighborhood size
@@ -71,67 +70,53 @@
     if (n < 1) {
       throw new IllegalArgumentException("n must be at least 1");
     }
-    this.cache = new Cache<Object, Collection<User>>(new Retriever(n), dataModel.getNumUsers());
+    this.n = n;
   }
 
   public Collection<User> getUserNeighborhood(Object userID) throws TasteException {
-    return cache.get(userID);
-  }
-
-  @Override
-  public String toString() {
-    return "NearestNUserNeighborhood";
-  }
-
-
-  private final class Retriever implements org.apache.mahout.cf.taste.impl.common.Retriever<Object, Collection<User>> {
-
-    private final int n;
-
-    private Retriever(int n) {
-      this.n = n;
-    }
+    log.trace("Computing neighborhood around user ID '{}'", userID);
 
-    public Collection<User> get(Object key) throws TasteException {
-      log.trace("Computing neighborhood around user ID '{}'", key);
-
-      DataModel dataModel = getDataModel();
-      User theUser = dataModel.getUser(key);
-      UserCorrelation userCorrelationImpl = getUserCorrelation();
-
-      LinkedList<UserCorrelationPair> queue = new LinkedList<UserCorrelationPair>();
-      boolean full = false;
-      for (User user : dataModel.getUsers()) {
-        if (sampleForUser() && !key.equals(user.getID())) {
-          double theCorrelation = userCorrelationImpl.userCorrelation(theUser, user);
-          if (!Double.isNaN(theCorrelation) && (!full || theCorrelation > queue.getLast().theCorrelation)) {
-            ListIterator<UserCorrelationPair> iterator = queue.listIterator(queue.size());
-            while (iterator.hasPrevious()) {
-              if (theCorrelation <= iterator.previous().theCorrelation) {
-                iterator.next();
-                break;
-              }
-            }
-            iterator.add(new UserCorrelationPair(user, theCorrelation));
-            if (full) {
-              queue.removeLast();
-            } else if (queue.size() > n) {
-              full = true;
-              queue.removeLast();
+    DataModel dataModel = getDataModel();
+    User theUser = dataModel.getUser(userID);
+    UserCorrelation userCorrelationImpl = getUserCorrelation();
+
+    LinkedList<UserCorrelationPair> queue = new LinkedList<UserCorrelationPair>();
+    boolean full = false;
+    for (User user : dataModel.getUsers()) {
+      if (sampleForUser() && !userID.equals(user.getID())) {
+        double theCorrelation = userCorrelationImpl.userCorrelation(theUser, user);
+        if (!Double.isNaN(theCorrelation) && (!full || theCorrelation > queue.getLast().theCorrelation)) {
+          ListIterator<UserCorrelationPair> iterator = queue.listIterator(queue.size());
+          while (iterator.hasPrevious()) {
+            if (theCorrelation <= iterator.previous().theCorrelation) {
+              iterator.next();
+              break;
             }
           }
+          iterator.add(new UserCorrelationPair(user, theCorrelation));
+          if (full) {
+            queue.removeLast();
+          } else if (queue.size() > n) {
+            full = true;
+            queue.removeLast();
+          }
         }
       }
+    }
 
-      List<User> neighborhood = new ArrayList<User>(queue.size());
-      for (UserCorrelationPair pair : queue) {
-        neighborhood.add(pair.user);
-      }
+    List<User> neighborhood = new ArrayList<User>(queue.size());
+    for (UserCorrelationPair pair : queue) {
+      neighborhood.add(pair.user);
+    }
 
-      log.trace("UserNeighborhood around user ID '{}' is: {}", key, neighborhood);
+    log.trace("UserNeighborhood around user ID '{}' is: {}", userID, neighborhood);
 
-      return Collections.unmodifiableList(neighborhood);
-    }
+    return Collections.unmodifiableList(neighborhood);
+  }
+
+  @Override
+  public String toString() {
+    return "NearestNUserNeighborhood";
   }
 
   private static final class UserCorrelationPair implements Comparable<UserCorrelationPair> {

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java?rev=683831&r1=683830&r2=683831&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java Thu Aug  7 20:52:02 2008
@@ -19,7 +19,6 @@
 
 import org.apache.mahout.cf.taste.common.TasteException;
 import org.apache.mahout.cf.taste.correlation.UserCorrelation;
-import org.apache.mahout.cf.taste.impl.common.Cache;
 import org.apache.mahout.cf.taste.model.DataModel;
 import org.apache.mahout.cf.taste.model.User;
 import org.slf4j.Logger;
@@ -40,7 +39,7 @@
 
   private static final Logger log = LoggerFactory.getLogger(ThresholdUserNeighborhood.class);
 
-  private final Cache<Object, Collection<User>> cache;
+  private final double threshold;
 
   /**
    * @param threshold similarity threshold
@@ -74,50 +73,36 @@
     if (Double.isNaN(threshold)) {
       throw new IllegalArgumentException("threshold must not be NaN");
     }
-    this.cache = new Cache<Object, Collection<User>>(new Retriever(threshold), dataModel.getNumUsers());
+    this.threshold = threshold;
   }
 
   public Collection<User> getUserNeighborhood(Object userID) throws TasteException {
-    return cache.get(userID);
-  }
-
-  @Override
-  public String toString() {
-    return "ThresholdUserNeighborhood";
-  }
-
-
-  private final class Retriever implements org.apache.mahout.cf.taste.impl.common.Retriever<Object, Collection<User>> {
+    log.trace("Computing neighborhood around user ID '{}'", userID);
 
-    private final double threshold;
-
-    private Retriever(double threshold) {
-      this.threshold = threshold;
-    }
-
-    public Collection<User> get(Object key) throws TasteException {
-      log.trace("Computing neighborhood around user ID '{}'", key);
-
-      DataModel dataModel = getDataModel();
-      User theUser = dataModel.getUser(key);
-      List<User> neighborhood = new ArrayList<User>();
-      Iterator<? extends User> users = dataModel.getUsers().iterator();
-      UserCorrelation userCorrelationImpl = getUserCorrelation();
-
-      while (users.hasNext()) {
-        User user = users.next();
-        if (sampleForUser() && !key.equals(user.getID())) {
-          double theCorrelation = userCorrelationImpl.userCorrelation(theUser, user);
-          if (!Double.isNaN(theCorrelation) && theCorrelation >= threshold) {
-            neighborhood.add(user);
-          }
+    DataModel dataModel = getDataModel();
+    User theUser = dataModel.getUser(userID);
+    List<User> neighborhood = new ArrayList<User>();
+    Iterator<? extends User> users = dataModel.getUsers().iterator();
+    UserCorrelation userCorrelationImpl = getUserCorrelation();
+
+    while (users.hasNext()) {
+      User user = users.next();
+      if (sampleForUser() && !userID.equals(user.getID())) {
+        double theCorrelation = userCorrelationImpl.userCorrelation(theUser, user);
+        if (!Double.isNaN(theCorrelation) && theCorrelation >= threshold) {
+          neighborhood.add(user);
         }
       }
+    }
 
-      log.trace("UserNeighborhood around user ID '{}' is: {}", key, neighborhood);
+    log.trace("UserNeighborhood around user ID '{}' is: {}", userID, neighborhood);
 
-      return Collections.unmodifiableList(neighborhood);
-    }
+    return Collections.unmodifiableList(neighborhood);
+  }
+
+  @Override
+  public String toString() {
+    return "ThresholdUserNeighborhood";
   }
 
 }