You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/10/02 09:03:17 UTC
svn commit: r1178133 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/cf/taste/impl/common/
test/java/org/apache/mahout/cf/taste/impl/common/
test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/
Author: srowen
Date: Sun Oct 2 07:03:17 2011
New Revision: 1178133
URL: http://svn.apache.org/viewvc?rev=1178133&view=rev
Log:
Part of MAHOUT-824 -- some additional tests
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java Sun Oct 2 07:03:17 2011
@@ -33,10 +33,14 @@ public class FullRunningAverage implemen
private double average;
public FullRunningAverage() {
- count = 0;
- average = Double.NaN;
+ this(0, Double.NaN);
}
-
+
+ public FullRunningAverage(int count, double average) { // starts from a caller-supplied count and average
+ this.count = count;
+ this.average = average;
+ }
+
/**
* @param datum
* new item to add to the running average
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java Sun Oct 2 07:03:17 2011
@@ -25,10 +25,23 @@ package org.apache.mahout.cf.taste.impl.
*/
public final class FullRunningAverageAndStdDev extends FullRunningAverage implements RunningAverageAndStdDev {
- private double stdDev = Double.NaN;
+ private double stdDev;
private double mk;
private double sk;
+ public FullRunningAverageAndStdDev() { // empty state: mk and sk zeroed, stdDev derived below
+ mk = 0.0;
+ sk = 0.0;
+ recomputeStdDev(); // stdDev has no field initializer, so it is computed here
+ }
+
+ public FullRunningAverageAndStdDev(int count, double average, double mk, double sk) { // starts from caller-supplied state
+ super(count, average);
+ this.mk = mk;
+ this.sk = sk;
+ recomputeStdDev(); // derives stdDev from the supplied count and sk
+ }
+
@Override
public synchronized double getStandardDeviation() {
return stdDev;
@@ -70,11 +83,7 @@ public final class FullRunningAverageAnd
private synchronized void recomputeStdDev() {
int count = getCount();
- if (count > 1) {
- stdDev = Math.sqrt(sk / (count - 1));
- } else {
- stdDev = Double.NaN;
- }
+ stdDev = count > 1 ? Math.sqrt(sk / (count - 1)) : Double.NaN;
}
@Override
Added: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java?rev=1178133&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java Sun Oct 2 07:03:17 2011
@@ -0,0 +1,147 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import com.google.common.collect.Maps;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.Random;
+
+/** <p>Tests {@link FastByIDMap}.</p> */
+public final class FastByIDMapTest extends TasteTestCase {
+
+ @Test
+ public void testPutAndGet() { // a key reads as null before put() and returns its value afterwards
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ assertNull(map.get(500000L));
+ map.put(500000L, 2L);
+ assertEquals(2L, (long) map.get(500000L)); // unboxed so the comparison is between primitive longs
+ }
+
+ @Test
+ public void testRemove() { // after remove() the map reports empty and the key reads as null
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ map.put(500000L, 2L);
+ map.remove(500000L);
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ assertNull(map.get(500000L));
+ }
+
+ @Test
+ public void testClear() { // after clear() the map reports empty and the key reads as null
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ map.put(500000L, 2L);
+ map.clear();
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ assertNull(map.get(500000L));
+ }
+
+ @Test
+ public void testSizeEmpty() { // follows size()/isEmpty() through empty -> one entry -> empty again
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ map.put(500000L, 2L);
+ assertEquals(1, map.size());
+ assertFalse(map.isEmpty());
+ map.remove(500000L);
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ }
+
+ @Test
+ public void testContains() { // containsKey/containsValue against the three entries from buildTestFastMap()
+ FastByIDMap<String> map = buildTestFastMap();
+ assertTrue(map.containsKey(500000L));
+ assertTrue(map.containsKey(47L));
+ assertTrue(map.containsKey(2L));
+ assertTrue(map.containsValue("alpha"));
+ assertTrue(map.containsValue("bang"));
+ assertTrue(map.containsValue("beta"));
+ assertFalse(map.containsKey(999)); // key that buildTestFastMap() never inserts
+ assertFalse(map.containsValue("something")); // value that buildTestFastMap() never inserts
+ }
+
+ @Test
+ public void testRehash() { // a removed key stays absent and another entry survives an explicit rehash()
+ FastByIDMap<String> map = buildTestFastMap();
+ map.remove(500000L);
+ map.rehash();
+ assertNull(map.get(500000L));
+ assertEquals("bang", map.get(47L));
+ }
+
+ @Test
+ public void testGrow() { // two puts into a map built with (1,1), then checks which entry is still present
+ FastByIDMap<String> map = new FastByIDMap<String>(1,1); // NOTE(review): presumably (size, maxSize) -- confirm against FastByIDMap
+ map.put(500000L, "alpha");
+ map.put(47L, "bang");
+ assertNull(map.get(500000L)); // the first key is expected to be gone after the second put
+ assertEquals("bang", map.get(47L));
+ }
+
+ @Test
+ public void testVersusHashMap() { // applies the same random operations to a FastByIDMap and a reference Map, comparing results
+ FastByIDMap<String> actual = new FastByIDMap<String>();
+ Map<Long, String> expected = Maps.newHashMapWithExpectedSize(1000000);
+ Random r = RandomUtils.getRandom();
+ for (int i = 0; i < 1000000; i++) {
+ double d = r.nextDouble(); // selects the operation for this iteration
+ Long key = (long) r.nextInt(100); // keys are drawn from 0..99 only
+ if (d < 0.4) { // d < 0.4: lookup only
+ assertEquals(expected.get(key), actual.get(key));
+ } else {
+ if (d < 0.7) { // 0.4 <= d < 0.7: put, comparing the returned previous values
+ assertEquals(expected.put(key, "bang"), actual.put(key, "bang"));
+ } else { // d >= 0.7: remove, comparing the returned removed values
+ assertEquals(expected.remove(key), actual.remove(key));
+ }
+ assertEquals(expected.size(), actual.size()); // size and emptiness are compared after every put or remove
+ assertEquals(expected.isEmpty(), actual.isEmpty());
+ }
+ }
+ }
+
+ @Test
+ public void testMaxSize() { // checks size() after inserts and after re-putting an existing key
+ FastByIDMap<String> map = new FastByIDMap<String>(); // NOTE(review): no maximum size is passed here despite the test name -- confirm intent
+ map.put(4, "bang");
+ assertEquals(1, map.size());
+ map.put(47L, "bang");
+ assertEquals(2, map.size());
+ assertNull(map.get(500000L)); // key never inserted in this test
+ map.put(47L, "buzz"); // same key again: size is expected to stay at 2 and the value to be replaced
+ assertEquals(2, map.size());
+ assertEquals("buzz", map.get(47L));
+ }
+
+
+ private static FastByIDMap<String> buildTestFastMap() { // fixture: 500000 -> "alpha", 47 -> "bang", 2 -> "beta"
+ FastByIDMap<String> map = new FastByIDMap<String>();
+ map.put(500000L, "alpha");
+ map.put(47L, "bang");
+ map.put(2L, "beta");
+ return map;
+ }
+
+}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java Sun Oct 2 07:03:17 2011
@@ -74,5 +74,34 @@ public final class RunningAverageAndStdD
assertEquals(1000.0 / Math.sqrt(12.0), average.getStandardDeviation(), SMALL_EPSILON);
}
+
+ @Test
+ public void testStddev() { // adds 1, 1, 7, 5 in turn and checks count, average and standard deviation along the way
+
+ RunningAverageAndStdDev runningAverage = new FullRunningAverageAndStdDev();
+
+ assertEquals(0, runningAverage.getCount());
+ assertTrue(Double.isNaN(runningAverage.getAverage()));
+ runningAverage.addDatum(1.0);
+ assertEquals(1, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ assertTrue(Double.isNaN(runningAverage.getStandardDeviation())); // NaN is expected while only one datum has been added
+ runningAverage.addDatum(1.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(0.0, runningAverage.getStandardDeviation(), EPSILON);
+
+ runningAverage.addDatum(7.0);
+ assertEquals(3, runningAverage.getCount());
+ assertEquals(3.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(3.464101552963257, runningAverage.getStandardDeviation(), EPSILON); // approximately sqrt(24 / 2) for {1, 1, 7}
+
+ runningAverage.addDatum(5.0);
+ assertEquals(4, runningAverage.getCount());
+ assertEquals(3.5, runningAverage.getAverage(), EPSILON);
+ assertEquals(3.0, runningAverage.getStandardDeviation(), EPSILON); // sqrt(27 / 3) for {1, 1, 7, 5}
+
+ }
+
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java Sun Oct 2 07:03:17 2011
@@ -56,5 +56,20 @@ public final class RunningAverageTest ex
assertEquals(2, runningAverage.getCount());
assertEquals(2.0, runningAverage.getAverage(), EPSILON);
}
+
+ @Test
+ public void testCopyConstructor() { // rebuilds an average from another one's count and average and checks both carry over
+ RunningAverage runningAverage = new FullRunningAverage();
+
+ runningAverage.addDatum(1.0);
+ runningAverage.addDatum(1.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+
+ RunningAverage copy = new FullRunningAverage(runningAverage.getCount(), runningAverage.getAverage()); // uses the (count, average) constructor
+ assertEquals(2, copy.getCount());
+ assertEquals(1.0, copy.getAverage(), EPSILON);
+
+ }
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java Sun Oct 2 07:03:17 2011
@@ -17,9 +17,12 @@
package org.apache.mahout.cf.taste.impl.recommender.slopeone;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
import org.apache.mahout.cf.taste.common.Weighting;
import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
import org.apache.mahout.cf.taste.model.DataModel;
import org.junit.Test;
@@ -27,6 +30,52 @@ import org.junit.Test;
public final class MemoryDiffStorageTest extends TasteTestCase {
@Test
+ public void testRecommendableIDsVariedWeighted() throws Exception { // checks getRecommendableItemIDs() for IDs 1..5 of the varied model
+ DataModel model = getDataModelVaried();
+ MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+ FastIDSet recommendableItemIDs = storage.getRecommendableItemIDs(1); // NOTE(review): argument is presumably a user ID -- confirm against MemoryDiffStorage
+ assertEquals(3, recommendableItemIDs.size());
+ assertTrue(recommendableItemIDs.contains(1));
+ recommendableItemIDs = storage.getRecommendableItemIDs(2);
+ assertEquals(2, recommendableItemIDs.size());
+ assertTrue(recommendableItemIDs.contains(2));
+ assertTrue(recommendableItemIDs.contains(3));
+
+ recommendableItemIDs = storage.getRecommendableItemIDs(3);
+ assertEquals(1, recommendableItemIDs.size());
+ assertTrue(recommendableItemIDs.contains(3));
+
+ recommendableItemIDs = storage.getRecommendableItemIDs(4);
+ assertEquals(0, recommendableItemIDs.size());
+ // the last item has only one recommendation, and so only 4 items are usable
+ recommendableItemIDs = storage.getRecommendableItemIDs(5);
+ assertEquals(0, recommendableItemIDs.size());
+ }
+
+ @Test
+ public void testRecommendableIDsPockedUnweighted() throws Exception { // checks result sizes for IDs 1..4 of the pocked model, unweighted
+ DataModel model = getDataModelPocked();
+ MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
+ FastIDSet recommendableItemIDs = storage.getRecommendableItemIDs(1);
+ assertEquals(0, recommendableItemIDs.size());
+ recommendableItemIDs = storage.getRecommendableItemIDs(2); // the only ID in this test expected to yield a result
+ assertEquals(1, recommendableItemIDs.size());
+ recommendableItemIDs = storage.getRecommendableItemIDs(3);
+ assertEquals(0, recommendableItemIDs.size());
+
+ recommendableItemIDs = storage.getRecommendableItemIDs(4);
+ assertEquals(0, recommendableItemIDs.size());
+
+ }
+
+ @Test (expected=NoSuchUserException.class)
+ public void testUnRecommendableID() throws Exception { // passes only if a NoSuchUserException is thrown
+ DataModel model = getDataModel();
+ MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+ storage.getRecommendableItemIDs(0); // NOTE(review): presumably ID 0 is absent from the default model -- confirm against TasteTestCase.getDataModel()
+ }
+
+ @Test
public void testGetDiff() throws Exception {
DataModel model = getDataModel();
MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
@@ -34,69 +83,133 @@ public final class MemoryDiffStorageTest
assertEquals(0.23333333333333334, average.getAverage(), EPSILON);
assertEquals(3, average.getCount());
}
-
+
@Test
public void testAdd() throws Exception {
DataModel model = getDataModel();
MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
-
+
RunningAverage average1 = storage.getDiff(0, 2);
assertEquals(0.1, average1.getAverage(), EPSILON);
assertEquals(3, average1.getCount());
-
+
RunningAverage average2 = storage.getDiff(1, 2);
assertEquals(0.23333332935969034, average2.getAverage(), EPSILON);
assertEquals(3, average2.getCount());
-
+
storage.addItemPref(1, 2, 0.8f);
-
+
average1 = storage.getDiff(0, 2);
assertEquals(0.25, average1.getAverage(), EPSILON);
assertEquals(4, average1.getCount());
-
+
average2 = storage.getDiff(1, 2);
assertEquals(0.3, average2.getAverage(), EPSILON);
assertEquals(4, average2.getCount());
}
-
+
@Test
public void testUpdate() throws Exception {
DataModel model = getDataModel();
MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
-
+
RunningAverage average = storage.getDiff(1, 2);
assertEquals(0.23333332935969034, average.getAverage(), EPSILON);
assertEquals(3, average.getCount());
-
+
storage.updateItemPref(1, 0.5f);
-
+
average = storage.getDiff(1, 2);
assertEquals(0.06666666666666668, average.getAverage(), EPSILON);
assertEquals(3, average.getCount());
}
-
+
@Test
public void testRemove() throws Exception {
DataModel model = getDataModel();
MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
-
+
RunningAverage average1 = storage.getDiff(0, 2);
assertEquals(0.1, average1.getAverage(), EPSILON);
assertEquals(3, average1.getCount());
-
+
RunningAverage average2 = storage.getDiff(1, 2);
assertEquals(0.23333332935969034, average2.getAverage(), EPSILON);
assertEquals(3, average2.getCount());
-
+
storage.removeItemPref(4, 2, 0.8f);
-
+
average1 = storage.getDiff(0, 2);
assertEquals(0.1, average1.getAverage(), EPSILON);
assertEquals(2, average1.getCount());
-
+
average2 = storage.getDiff(1, 2);
assertEquals(0.1, average2.getAverage(), EPSILON);
assertEquals(2, average2.getCount());
}
+
+ @Test (expected=UnsupportedOperationException.class)
+ public void testUpdateWeighted() throws Exception { // passes only if an UnsupportedOperationException is thrown
+ DataModel model = getDataModelVaried();
+ MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+
+ storage.updateItemPref(2, 0.8f); // the call made on WEIGHTED storage that should raise the expected exception
+ }
+
+ @Test
+ public void testRemovePref() throws Exception { // checks the (0, 1) diff's average, std dev and count before and after removeItemPref()
+ double eps = 0.0001; // tolerance used for every floating-point comparison in this test
+ DataModel model = getDataModelPocked();
+ MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+
+ RunningAverageAndStdDev average = (RunningAverageAndStdDev) storage.getDiff(0, 1); // cast needed to read getStandardDeviation()
+ assertEquals(-0.033333, average.getAverage(), eps);
+ assertEquals(0.32145, average.getStandardDeviation(), eps);
+ assertEquals(3, average.getCount());
+
+ storage.removeItemPref(2, 1, 0.1f);
+ average = (RunningAverageAndStdDev) storage.getDiff(0, 1); // re-read the diff after the removal
+ assertEquals(0.00000001, average.getAverage(), eps);
+ assertEquals(0.44721, average.getStandardDeviation(), eps);
+ assertEquals(2, average.getCount()); // one fewer than before the removal
+ }
+
+ protected static DataModel getDataModelVaried() { // fixture: IDs 1..5 with rows of 1, 2, 3, 4 and 5 values
+ return getDataModel( // NOTE(review): presumably (user IDs, per-user preference rows) -- confirm against TasteTestCase.getDataModel
+ new long[] {1, 2, 3, 4, 5},
+ new Double[][] {
+ {0.2},
+ {0.4, 0.5},
+ {0.7, 0.1, 0.5},
+ {0.7, 0.3, 0.8, 0.1},
+ {0.2, 0.3, 0.6, 0.1, 0.3},
+ });
+ }
+
+ protected static DataModel getDataModelPocked() { // fixture: IDs 1..4 with rows of 2, 1, 2 and 3 values
+ return getDataModel( // NOTE(review): presumably (user IDs, per-user preference rows) -- confirm against TasteTestCase.getDataModel
+ new long[] {1, 2, 3, 4},
+ new Double[][] {
+ {0.1, 0.3},
+ {0.2},
+ {0.4, 0.5},
+ {0.7, 0.3, 0.8},
+ });
+ }
+
+ protected static DataModel getDataModelLarge() { // fixture: IDs 1..7, each with a full row of 7 values
+ return getDataModel( // NOTE(review): presumably (user IDs, per-user preference rows) -- confirm against TasteTestCase.getDataModel
+ new long[] {1, 2, 3, 4, 5, 6, 7},
+ new Double[][] {
+ {0.2, .2, .2, .2, .2, .2, .2},
+ {0.4, 0.5, .3, .3, .3, .3, .3},
+ {0.7, 0.1, 0.5, .2, .7, .8, .9},
+ {0.7, 0.3, 0.8, 0.1, .6, .6, .6},
+ {0.2, 0.3, 0.6, 0.1, 0.3, .4, .4},
+ {0.2, 0.3, 0.6, 0.1, 0.3, .4, .4},
+ {0.2, 0.3, 0.6, 0.1, 0.3, .5, .5},
+ });
+ }
+
}