You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/02/03 13:51:07 UTC

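svn commit: r1066799 - in /mahout/trunk: core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/ core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/ core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/ core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/ math/src/main/java/org/apache/mahout/math/stats/ math/src/test/java/org/apache/mahout/math/stats/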

Author: srowen
Date: Thu Feb  3 12:51:06 2011
New Revision: 1066799

URL: http://svn.apache.org/viewvc?rev=1066799&view=rev
Log:
MAHOUT-603 Standardize log-likelihood implementation

Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java
    mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java
    mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java Thu Feb  3 12:51:06 2011
@@ -27,6 +27,7 @@ import org.apache.mahout.cf.taste.model.
 import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
 import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
 import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.apache.mahout.math.stats.LogLikelihood;
 
 /**
  * See <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.14.5962">
@@ -64,10 +65,11 @@ public final class LogLikelihoodSimilari
       return Double.NaN;
     }
     int numItems = dataModel.getNumItems();
-    double logLikelihood = twoLogLambda(intersectionSize,
-                                        prefs1Size - intersectionSize,
-                                        prefs2Size,
-                                        numItems - prefs2Size);
+    double logLikelihood =
+        LogLikelihood.logLikelihoodRatio(intersectionSize,
+                                         prefs1Size - intersectionSize,
+                                         prefs2Size - intersectionSize,
+                                         numItems - prefs1Size - prefs2Size + intersectionSize);
     return 1.0 - 1.0 / (1.0 + logLikelihood);
   }
   
@@ -96,29 +98,14 @@ public final class LogLikelihoodSimilari
       return Double.NaN;
     }
     int preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2);
-    double logLikelihood = twoLogLambda(preferring1and2,
-                                        preferring1 - preferring1and2,
-                                        preferring2,
-                                        numUsers - preferring2);
+    double logLikelihood =
+        LogLikelihood.logLikelihoodRatio(preferring1and2,
+                                         preferring1 - preferring1and2,
+                                         preferring2 - preferring1and2,
+                                         numUsers - preferring1 - preferring2 + preferring1and2);
     return 1.0 - 1.0 / (1.0 + logLikelihood);
   }
-  
-  static double twoLogLambda(double k1, double k2, double n1, double n2) {
-    double p = (k1 + k2) / (n1 + n2);
-    return 2.0 * (logL(k1 / n1, k1, n1)
-                  + logL(k2 / n2, k2, n2)
-                  - logL(p, k1, n1)
-                  - logL(p, k2, n2));
-  }
-  
-  private static double logL(double p, double k, double n) {
-    return k * safeLog(p) + (n - k) * safeLog(1.0 - p);
-  }
-  
-  private static double safeLog(double d) {
-    return d <= 0.0 ? 0.0 : Math.log(d);
-  }
-  
+
   @Override
   public void refresh(Collection<Refreshable> alreadyRefreshed) {
     alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java Thu Feb  3 12:51:06 2011
@@ -19,6 +19,7 @@ package org.apache.mahout.math.hadoop.si
 
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.hadoop.similarity.Cooccurrence;
+import org.apache.mahout.math.stats.LogLikelihood;
 
 /**
  * distributed implementation of loglikelihood as vector similarity measure
@@ -38,10 +39,11 @@ public class DistributedLoglikelihoodVec
     int occurrencesA = (int) weightOfVectorA;
     int occurrencesB = (int) weightOfVectorB;
 
-    double logLikelihood = twoLogLambda(cooccurrenceCount,
-                                        occurrencesA - cooccurrenceCount,
-                                        occurrencesB,
-                                        numberOfColumns - occurrencesB);
+    double logLikelihood =
+        LogLikelihood.logLikelihoodRatio(cooccurrenceCount,
+                                         occurrencesA - cooccurrenceCount,
+                                         occurrencesB - cooccurrenceCount,
+                                         numberOfColumns - occurrencesA - occurrencesB + cooccurrenceCount);
 
     return 1.0 - 1.0 / (1.0 + logLikelihood);
   }
@@ -51,19 +53,4 @@ public class DistributedLoglikelihoodVec
     return (double) countElements(v.iterateNonZero());
   }
 
-  private static double twoLogLambda(double k1, double k2, double n1, double n2) {
-    double p = (k1 + k2) / (n1 + n2);
-    return 2.0 * (logL(k1 / n1, k1, n1)
-                  + logL(k2 / n2, k2, n2)
-                  - logL(p, k1, n1)
-                  - logL(p, k2, n2));
-  }
-
-  private static double logL(double p, double k, double n) {
-    return k * safeLog(p) + (n - k) * safeLog(1.0 - p);
-  }
-
-  private static double safeLog(double d) {
-    return d <= 0.0 ? 0.0 : Math.log(d);
-  }
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java Thu Feb  3 12:51:06 2011
@@ -35,23 +35,40 @@ public final class LogLikelihoodSimilari
                     {null, 1.0, 1.0, 1.0, 1.0},
             });
 
-    double correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(1, 0);
-    assertCorrelationEquals(0.12160727029227925, correlation);
+    LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel);
 
-    correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(0, 1);
-    assertCorrelationEquals(0.12160727029227925, correlation);
+    assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(1, 0));
+    assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(0, 1));
 
-    correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(2, 1);
-    assertCorrelationEquals(0.5423213660693733, correlation);
+    assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(1, 2));
+    assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(2, 1));
 
-    correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(2, 3);
-    assertCorrelationEquals(0.6905400104897509, correlation);
+    assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(2, 3));
+    assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(3, 2));
 
-    correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(3, 4);
-    assertCorrelationEquals(0.8706358464330881, correlation);
+    assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(3, 4));
+    assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(4, 3));
+  }
+
+  @Test
+  public void testNoSimilarity() throws Exception {
+
+    DataModel dataModel = getDataModel(
+        new long[] {1, 2, 3, 4},
+        new Double[][] {
+                {1.0, null, 1.0, 1.0},
+                {1.0, null, 1.0, 1.0},
+                {null, 1.0, 1.0, 1.0},
+                {null, 1.0, 1.0, 1.0},
+        });
+
+    LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel);
+
+    assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(1, 0));
+    assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(0, 1));
 
-    correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(4, 3);
-    assertCorrelationEquals(0.8706358464330881, correlation);
+    assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(2, 3));
+    assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(3, 2));
   }
 
   @Test

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java Thu Feb  3 12:51:06 2011
@@ -66,6 +66,10 @@ public abstract class DistributedVectorS
     }
 
     double result = similarity.similarity(rowA, rowB, cooccurrences, weightA, weightB, numberOfColumns);
-    assertEquals(expectedSimilarity, result, EPSILON);
+    if (Double.isNaN(expectedSimilarity)) {
+      assertTrue(Double.isNaN(result));
+    } else {
+      assertEquals(expectedSimilarity, result, EPSILON);
+    }
   }
 }

Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java Thu Feb  3 12:51:06 2011
@@ -24,6 +24,7 @@ import com.google.common.collect.Orderin
 import java.util.Collections;
 import java.util.List;
 import java.util.PriorityQueue;
+import java.util.Queue;
 
 /**
  * Utility methods for working with log-likelihood
@@ -126,7 +127,10 @@ public final class LogLikelihood {
    * in a than b.  Use -Double.MAX_VALUE (not Double.MIN_VALUE !) to not use a threshold.
    * @return  A list of scored items with their scores.
    */
-  public static <T> List<ScoredItem<T>> compareFrequencies(Multiset<T> a, Multiset<T> b, int maxReturn, double threshold) {
+  public static <T> List<ScoredItem<T>> compareFrequencies(Multiset<T> a,
+                                                           Multiset<T> b,
+                                                           int maxReturn,
+                                                           double threshold) {
     int totalA = a.size();
     int totalB = b.size();
 
@@ -135,7 +139,7 @@ public final class LogLikelihood {
         return Double.compare(tScoredItem.score, tScoredItem1.score);
       }
     };
-    PriorityQueue<ScoredItem<T>> best = new PriorityQueue<ScoredItem<T>>(maxReturn + 1, byScoreAscending);
+    Queue<ScoredItem<T>> best = new PriorityQueue<ScoredItem<T>>(maxReturn + 1, byScoreAscending);
 
     for (T t : a.elementSet()) {
       compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t);
@@ -156,7 +160,14 @@ public final class LogLikelihood {
     return r;
   }
 
-  private static <T> void compareAndAdd(Multiset<T> a, Multiset<T> b, int maxReturn, double threshold, int totalA, int totalB, PriorityQueue<ScoredItem<T>> best, T t) {
+  private static <T> void compareAndAdd(Multiset<T> a,
+                                        Multiset<T> b,
+                                        int maxReturn,
+                                        double threshold,
+                                        int totalA,
+                                        int totalB,
+                                        Queue<ScoredItem<T>> best,
+                                        T t) {
     int kA = a.count(t);
     int kB = b.count(t);
     double score = rootLogLikelihoodRatio(kA, totalA - kA, kB, totalB - kB);
@@ -169,13 +180,21 @@ public final class LogLikelihood {
     }
   }
 
-  public final static class ScoredItem<T> {
-    public T item;
-    public double score;
+  public static final class ScoredItem<T> {
+    private final T item;
+    private final double score;
 
     public ScoredItem(T item, double score) {
       this.item = item;
       this.score = score;
     }
+
+    public double getScore() {
+      return score;
+    }
+
+    public T getItem() {
+      return item;
+    }
   }
 }

Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java Thu Feb  3 12:51:06 2011
@@ -114,19 +114,19 @@ public final class LogLikelihoodTest ext
     // comparing frequencies, we should be able to find 8 items with score > 0
     List<LogLikelihood.ScoredItem<Integer>> r = LogLikelihood.compareFrequencies(w1, w2, 8, 0);
     assertTrue(r.size() <= 8);
-    assertTrue(r.size() > 0);
+    assertFalse(r.isEmpty());
     for (LogLikelihood.ScoredItem<Integer> item : r) {
-      assertTrue(item.score >= 0);
+      assertTrue(item.getScore() >= 0);
     }
 
     // the most impressive should be 7
-    assertEquals(7, (int) r.get(0).item);
+    assertEquals(7, (int) r.get(0).getItem());
 
     // make sure scores are descending
-    double lastScore = r.get(0).score;
+    double lastScore = r.get(0).getScore();
     for (LogLikelihood.ScoredItem<Integer> item : r) {
-      assertTrue(item.score <= lastScore);
-      lastScore = item.score;
+      assertTrue(item.getScore() <= lastScore);
+      lastScore = item.getScore();
     }
 
     // now as many as have score >= 1
@@ -134,14 +134,14 @@ public final class LogLikelihoodTest ext
 
     // only the boosted items should make the cut
     assertEquals(3, r.size());
-    assertEquals(7, (int) r.get(0).item);
-    assertEquals(5, (int) r.get(1).item);
-    assertEquals(6, (int) r.get(2).item);
+    assertEquals(7, (int) r.get(0).getItem());
+    assertEquals(5, (int) r.get(1).getItem());
+    assertEquals(6, (int) r.get(2).getItem());
 
     r = LogLikelihood.compareFrequencies(w1, w2, 1000, -100);
     Multiset<Integer> k = HashMultiset.create();
     for (LogLikelihood.ScoredItem<Integer> item : r) {
-      k.add(item.item);
+      k.add(item.getItem());
     }
     for (int i = 0; i < 25; i++) {
       assertTrue("i = " + i, k.count(i) == 1 || w2.count(i) == 0);
@@ -149,18 +149,18 @@ public final class LogLikelihoodTest ext
 
     // all values that had non-zero counts in larger set should have result scores
     assertEquals(w2.elementSet().size(), r.size());
-    assertEquals(7, (int) r.get(0).item);
-    assertEquals(5, (int) r.get(1).item);
-    assertEquals(6, (int) r.get(2).item);
+    assertEquals(7, (int) r.get(0).getItem());
+    assertEquals(5, (int) r.get(1).getItem());
+    assertEquals(6, (int) r.get(2).getItem());
     
     // the last item should definitely have negative score
-    assertTrue(r.get(r.size() - 1).score < 0);
+    assertTrue(r.get(r.size() - 1).getScore() < 0);
 
     // make sure scores are descending
-    lastScore = r.get(0).score;
+    lastScore = r.get(0).getScore();
     for (LogLikelihood.ScoredItem<Integer> item : r) {
-      assertTrue(item.score <= lastScore);
-      lastScore = item.score;
+      assertTrue(item.getScore() <= lastScore);
+      lastScore = item.getScore();
     }
   }
 
@@ -170,7 +170,7 @@ public final class LogLikelihoodTest ext
    * @param rand   A random number generator.
    * @return  A single sample from the multinomial distribution.
    */
-  private int sample(Vector p, Random rand) {
+  private static int sample(Vector p, Random rand) {
     double u = rand.nextDouble();
 
     // simple sequential algorithm.  Not the fastest, but we don't care