You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/02/03 13:51:07 UTC
svn commit: r1066799 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/
core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/
core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/
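core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/
math/src/main/java/org/apache/mahout/math/stats/
math/src/test/java/org/apache/mahout/math/stats/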
Author: srowen
Date: Thu Feb 3 12:51:06 2011
New Revision: 1066799
URL: http://svn.apache.org/viewvc?rev=1066799&view=rev
Log:
MAHOUT-603 Standardize log-likelihood implementation
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java Thu Feb 3 12:51:06 2011
@@ -27,6 +27,7 @@ import org.apache.mahout.cf.taste.model.
import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.apache.mahout.math.stats.LogLikelihood;
/**
* See <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.14.5962">
@@ -64,10 +65,11 @@ public final class LogLikelihoodSimilari
return Double.NaN;
}
int numItems = dataModel.getNumItems();
- double logLikelihood = twoLogLambda(intersectionSize,
- prefs1Size - intersectionSize,
- prefs2Size,
- numItems - prefs2Size);
+ double logLikelihood =
+ LogLikelihood.logLikelihoodRatio(intersectionSize,
+ prefs1Size - intersectionSize,
+ prefs2Size - intersectionSize,
+ numItems - prefs1Size - prefs2Size + intersectionSize);
return 1.0 - 1.0 / (1.0 + logLikelihood);
}
@@ -96,29 +98,14 @@ public final class LogLikelihoodSimilari
return Double.NaN;
}
int preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2);
- double logLikelihood = twoLogLambda(preferring1and2,
- preferring1 - preferring1and2,
- preferring2,
- numUsers - preferring2);
+ double logLikelihood =
+ LogLikelihood.logLikelihoodRatio(preferring1and2,
+ preferring1 - preferring1and2,
+ preferring2 - preferring1and2,
+ numUsers - preferring1 - preferring2 + preferring1and2);
return 1.0 - 1.0 / (1.0 + logLikelihood);
}
-
- static double twoLogLambda(double k1, double k2, double n1, double n2) {
- double p = (k1 + k2) / (n1 + n2);
- return 2.0 * (logL(k1 / n1, k1, n1)
- + logL(k2 / n2, k2, n2)
- - logL(p, k1, n1)
- - logL(p, k2, n2));
- }
-
- private static double logL(double p, double k, double n) {
- return k * safeLog(p) + (n - k) * safeLog(1.0 - p);
- }
-
- private static double safeLog(double d) {
- return d <= 0.0 ? 0.0 : Math.log(d);
- }
-
+
@Override
public void refresh(Collection<Refreshable> alreadyRefreshed) {
alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedLoglikelihoodVectorSimilarity.java Thu Feb 3 12:51:06 2011
@@ -19,6 +19,7 @@ package org.apache.mahout.math.hadoop.si
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.hadoop.similarity.Cooccurrence;
+import org.apache.mahout.math.stats.LogLikelihood;
/**
* distributed implementation of loglikelihood as vector similarity measure
@@ -38,10 +39,11 @@ public class DistributedLoglikelihoodVec
int occurrencesA = (int) weightOfVectorA;
int occurrencesB = (int) weightOfVectorB;
- double logLikelihood = twoLogLambda(cooccurrenceCount,
- occurrencesA - cooccurrenceCount,
- occurrencesB,
- numberOfColumns - occurrencesB);
+ double logLikelihood =
+ LogLikelihood.logLikelihoodRatio(cooccurrenceCount,
+ occurrencesA - cooccurrenceCount,
+ occurrencesB - cooccurrenceCount,
+ numberOfColumns - occurrencesA - occurrencesB + cooccurrenceCount);
return 1.0 - 1.0 / (1.0 + logLikelihood);
}
@@ -51,19 +53,4 @@ public class DistributedLoglikelihoodVec
return (double) countElements(v.iterateNonZero());
}
- private static double twoLogLambda(double k1, double k2, double n1, double n2) {
- double p = (k1 + k2) / (n1 + n2);
- return 2.0 * (logL(k1 / n1, k1, n1)
- + logL(k2 / n2, k2, n2)
- - logL(p, k1, n1)
- - logL(p, k2, n2));
- }
-
- private static double logL(double p, double k, double n) {
- return k * safeLog(p) + (n - k) * safeLog(1.0 - p);
- }
-
- private static double safeLog(double d) {
- return d <= 0.0 ? 0.0 : Math.log(d);
- }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java Thu Feb 3 12:51:06 2011
@@ -35,23 +35,40 @@ public final class LogLikelihoodSimilari
{null, 1.0, 1.0, 1.0, 1.0},
});
- double correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(1, 0);
- assertCorrelationEquals(0.12160727029227925, correlation);
+ LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel);
- correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(0, 1);
- assertCorrelationEquals(0.12160727029227925, correlation);
+ assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(1, 0));
+ assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(0, 1));
- correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(2, 1);
- assertCorrelationEquals(0.5423213660693733, correlation);
+ assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(1, 2));
+ assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(2, 1));
- correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(2, 3);
- assertCorrelationEquals(0.6905400104897509, correlation);
+ assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(2, 3));
+ assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(3, 2));
- correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(3, 4);
- assertCorrelationEquals(0.8706358464330881, correlation);
+ assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(3, 4));
+ assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(4, 3));
+ }
+
+ @Test
+ public void testNoSimilarity() throws Exception {
+
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3, 4},
+ new Double[][] {
+ {1.0, null, 1.0, 1.0},
+ {1.0, null, 1.0, 1.0},
+ {null, 1.0, 1.0, 1.0},
+ {null, 1.0, 1.0, 1.0},
+ });
+
+ LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel);
+
+ assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(1, 0));
+ assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(0, 1));
- correlation = new LogLikelihoodSimilarity(dataModel).itemSimilarity(4, 3);
- assertCorrelationEquals(0.8706358464330881, correlation);
+ assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(2, 3));
+ assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(3, 2));
}
@Test
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/vector/DistributedVectorSimilarityTestCase.java Thu Feb 3 12:51:06 2011
@@ -66,6 +66,10 @@ public abstract class DistributedVectorS
}
double result = similarity.similarity(rowA, rowB, cooccurrences, weightA, weightB, numberOfColumns);
- assertEquals(expectedSimilarity, result, EPSILON);
+ if (Double.isNaN(expectedSimilarity)) {
+ assertTrue(Double.isNaN(result));
+ } else {
+ assertEquals(expectedSimilarity, result, EPSILON);
+ }
}
}
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/LogLikelihood.java Thu Feb 3 12:51:06 2011
@@ -24,6 +24,7 @@ import com.google.common.collect.Orderin
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;
+import java.util.Queue;
/**
* Utility methods for working with log-likelihood
@@ -126,7 +127,10 @@ public final class LogLikelihood {
* in a than b. Use -Double.MAX_VALUE (not Double.MIN_VALUE !) to not use a threshold.
* @return A list of scored items with their scores.
*/
- public static <T> List<ScoredItem<T>> compareFrequencies(Multiset<T> a, Multiset<T> b, int maxReturn, double threshold) {
+ public static <T> List<ScoredItem<T>> compareFrequencies(Multiset<T> a,
+ Multiset<T> b,
+ int maxReturn,
+ double threshold) {
int totalA = a.size();
int totalB = b.size();
@@ -135,7 +139,7 @@ public final class LogLikelihood {
return Double.compare(tScoredItem.score, tScoredItem1.score);
}
};
- PriorityQueue<ScoredItem<T>> best = new PriorityQueue<ScoredItem<T>>(maxReturn + 1, byScoreAscending);
+ Queue<ScoredItem<T>> best = new PriorityQueue<ScoredItem<T>>(maxReturn + 1, byScoreAscending);
for (T t : a.elementSet()) {
compareAndAdd(a, b, maxReturn, threshold, totalA, totalB, best, t);
@@ -156,7 +160,14 @@ public final class LogLikelihood {
return r;
}
- private static <T> void compareAndAdd(Multiset<T> a, Multiset<T> b, int maxReturn, double threshold, int totalA, int totalB, PriorityQueue<ScoredItem<T>> best, T t) {
+ private static <T> void compareAndAdd(Multiset<T> a,
+ Multiset<T> b,
+ int maxReturn,
+ double threshold,
+ int totalA,
+ int totalB,
+ Queue<ScoredItem<T>> best,
+ T t) {
int kA = a.count(t);
int kB = b.count(t);
double score = rootLogLikelihoodRatio(kA, totalA - kA, kB, totalB - kB);
@@ -169,13 +180,21 @@ public final class LogLikelihood {
}
}
- public final static class ScoredItem<T> {
- public T item;
- public double score;
+ public static final class ScoredItem<T> {
+ private final T item;
+ private final double score;
public ScoredItem(T item, double score) {
this.item = item;
this.score = score;
}
+
+ public double getScore() {
+ return score;
+ }
+
+ public T getItem() {
+ return item;
+ }
}
}
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java?rev=1066799&r1=1066798&r2=1066799&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/LogLikelihoodTest.java Thu Feb 3 12:51:06 2011
@@ -114,19 +114,19 @@ public final class LogLikelihoodTest ext
// comparing frequencies, we should be able to find 8 items with score > 0
List<LogLikelihood.ScoredItem<Integer>> r = LogLikelihood.compareFrequencies(w1, w2, 8, 0);
assertTrue(r.size() <= 8);
- assertTrue(r.size() > 0);
+ assertFalse(r.isEmpty());
for (LogLikelihood.ScoredItem<Integer> item : r) {
- assertTrue(item.score >= 0);
+ assertTrue(item.getScore() >= 0);
}
// the most impressive should be 7
- assertEquals(7, (int) r.get(0).item);
+ assertEquals(7, (int) r.get(0).getItem());
// make sure scores are descending
- double lastScore = r.get(0).score;
+ double lastScore = r.get(0).getScore();
for (LogLikelihood.ScoredItem<Integer> item : r) {
- assertTrue(item.score <= lastScore);
- lastScore = item.score;
+ assertTrue(item.getScore() <= lastScore);
+ lastScore = item.getScore();
}
// now as many as have score >= 1
@@ -134,14 +134,14 @@ public final class LogLikelihoodTest ext
// only the boosted items should make the cut
assertEquals(3, r.size());
- assertEquals(7, (int) r.get(0).item);
- assertEquals(5, (int) r.get(1).item);
- assertEquals(6, (int) r.get(2).item);
+ assertEquals(7, (int) r.get(0).getItem());
+ assertEquals(5, (int) r.get(1).getItem());
+ assertEquals(6, (int) r.get(2).getItem());
r = LogLikelihood.compareFrequencies(w1, w2, 1000, -100);
Multiset<Integer> k = HashMultiset.create();
for (LogLikelihood.ScoredItem<Integer> item : r) {
- k.add(item.item);
+ k.add(item.getItem());
}
for (int i = 0; i < 25; i++) {
assertTrue("i = " + i, k.count(i) == 1 || w2.count(i) == 0);
@@ -149,18 +149,18 @@ public final class LogLikelihoodTest ext
// all values that had non-zero counts in larger set should have result scores
assertEquals(w2.elementSet().size(), r.size());
- assertEquals(7, (int) r.get(0).item);
- assertEquals(5, (int) r.get(1).item);
- assertEquals(6, (int) r.get(2).item);
+ assertEquals(7, (int) r.get(0).getItem());
+ assertEquals(5, (int) r.get(1).getItem());
+ assertEquals(6, (int) r.get(2).getItem());
// the last item should definitely have negative score
- assertTrue(r.get(r.size() - 1).score < 0);
+ assertTrue(r.get(r.size() - 1).getScore() < 0);
// make sure scores are descending
- lastScore = r.get(0).score;
+ lastScore = r.get(0).getScore();
for (LogLikelihood.ScoredItem<Integer> item : r) {
- assertTrue(item.score <= lastScore);
- lastScore = item.score;
+ assertTrue(item.getScore() <= lastScore);
+ lastScore = item.getScore();
}
}
@@ -170,7 +170,7 @@ public final class LogLikelihoodTest ext
* @param rand A random number generator.
* @return A single sample from the multinomial distribution.
*/
- private int sample(Vector p, Random rand) {
+ private static int sample(Vector p, Random rand) {
double u = rand.nextDouble();
// simple sequential algorithm. Not the fastest, but we don't care