You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2011/10/02 09:03:17 UTC

svn commit: r1178133 - in /mahout/trunk/core/src: main/java/org/apache/mahout/cf/taste/impl/common/ test/java/org/apache/mahout/cf/taste/impl/common/ test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/

Author: srowen
Date: Sun Oct  2 07:03:17 2011
New Revision: 1178133

URL: http://svn.apache.org/viewvc?rev=1178133&view=rev
Log:
Part of MAHOUT-824 -- some additional tests

Added:
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
    mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java Sun Oct  2 07:03:17 2011
@@ -33,10 +33,14 @@ public class FullRunningAverage implemen
   private double average;
   
   public FullRunningAverage() {
-    count = 0;
-    average = Double.NaN;
+    this(0, Double.NaN);
   }
-  
+
+  public FullRunningAverage(int count, double average) {
+    this.count = count;
+    this.average = average;    
+  }
+
   /**
    * @param datum
    *          new item to add to the running average

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java Sun Oct  2 07:03:17 2011
@@ -25,10 +25,23 @@ package org.apache.mahout.cf.taste.impl.
  */
 public final class FullRunningAverageAndStdDev extends FullRunningAverage implements RunningAverageAndStdDev {
 
-  private double stdDev = Double.NaN;
+  private double stdDev;
   private double mk;
   private double sk;
   
+  public FullRunningAverageAndStdDev() {
+    mk = 0.0;
+    sk = 0.0;
+    recomputeStdDev();
+  }
+  
+  public FullRunningAverageAndStdDev(int count, double average, double mk, double sk) {
+    super(count, average);
+    this.mk = mk;
+    this.sk = sk;
+    recomputeStdDev();
+  }
+
   @Override
   public synchronized double getStandardDeviation() {
     return stdDev;
@@ -70,11 +83,7 @@ public final class FullRunningAverageAnd
   
   private synchronized void recomputeStdDev() {
     int count = getCount();
-    if (count > 1) {
-      stdDev = Math.sqrt(sk / (count - 1));
-    } else {
-      stdDev = Double.NaN;
-    }
+    stdDev = count > 1 ? Math.sqrt(sk / (count - 1)) : Double.NaN;
   }
   
   @Override

Added: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java?rev=1178133&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java Sun Oct  2 07:03:17 2011
@@ -0,0 +1,147 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import com.google.common.collect.Maps;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.Random;
+
+/** <p>Tests {@link FastByIDMap}.</p> */
+public final class FastByIDMapTest extends TasteTestCase {
+
+  @Test
+  public void testPutAndGet() {
+    FastByIDMap<Long> map = new FastByIDMap<Long>();
+    assertNull(map.get(500000L));
+    map.put(500000L, 2L);
+    assertEquals(2L, (long) map.get(500000L));
+  }
+  
+  @Test
+  public void testRemove() {
+    FastByIDMap<Long> map = new FastByIDMap<Long>();
+    map.put(500000L, 2L);
+    map.remove(500000L);
+    assertEquals(0, map.size());
+    assertTrue(map.isEmpty());
+    assertNull(map.get(500000L));
+  }
+  
+  @Test
+  public void testClear() {
+    FastByIDMap<Long> map = new FastByIDMap<Long>();
+    map.put(500000L, 2L);
+    map.clear();
+    assertEquals(0, map.size());
+    assertTrue(map.isEmpty());
+    assertNull(map.get(500000L));
+  }
+  
+  @Test
+  public void testSizeEmpty() {
+    FastByIDMap<Long> map = new FastByIDMap<Long>();
+    assertEquals(0, map.size());
+    assertTrue(map.isEmpty());
+    map.put(500000L, 2L);
+    assertEquals(1, map.size());
+    assertFalse(map.isEmpty());
+    map.remove(500000L);
+    assertEquals(0, map.size());
+    assertTrue(map.isEmpty());
+  }
+  
+  @Test
+  public void testContains() {
+    FastByIDMap<String> map = buildTestFastMap();
+    assertTrue(map.containsKey(500000L));
+    assertTrue(map.containsKey(47L));
+    assertTrue(map.containsKey(2L));
+    assertTrue(map.containsValue("alpha"));
+    assertTrue(map.containsValue("bang"));
+    assertTrue(map.containsValue("beta"));
+    assertFalse(map.containsKey(999));
+    assertFalse(map.containsValue("something"));
+  }
+
+  @Test
+  public void testRehash() {
+    FastByIDMap<String> map = buildTestFastMap();
+    map.remove(500000L);
+    map.rehash();
+    assertNull(map.get(500000L));
+    assertEquals("bang", map.get(47L));
+  }
+  
+  @Test
+  public void testGrow() {
+    FastByIDMap<String> map = new FastByIDMap<String>(1,1);
+    map.put(500000L, "alpha");
+    map.put(47L, "bang");
+    assertNull(map.get(500000L));
+    assertEquals("bang", map.get(47L));
+  }
+   
+  @Test
+  public void testVersusHashMap() {
+    FastByIDMap<String> actual = new FastByIDMap<String>();
+    Map<Long, String> expected = Maps.newHashMapWithExpectedSize(1000000);
+    Random r = RandomUtils.getRandom();
+    for (int i = 0; i < 1000000; i++) {
+      double d = r.nextDouble();
+      Long key = (long) r.nextInt(100);
+      if (d < 0.4) {
+        assertEquals(expected.get(key), actual.get(key));
+      } else {
+        if (d < 0.7) {
+          assertEquals(expected.put(key, "bang"), actual.put(key, "bang"));
+        } else {
+          assertEquals(expected.remove(key), actual.remove(key));
+        }
+        assertEquals(expected.size(), actual.size());
+        assertEquals(expected.isEmpty(), actual.isEmpty());
+      }
+    }
+  }
+  
+  @Test
+  public void testMaxSize() {
+    FastByIDMap<String> map = new FastByIDMap<String>();
+    map.put(4, "bang");
+    assertEquals(1, map.size());
+    map.put(47L, "bang");
+    assertEquals(2, map.size());
+    assertNull(map.get(500000L));
+    map.put(47L, "buzz");
+    assertEquals(2, map.size());
+    assertEquals("buzz", map.get(47L));
+  }
+  
+  
+  private static FastByIDMap<String> buildTestFastMap() {
+    FastByIDMap<String> map = new FastByIDMap<String>();
+    map.put(500000L, "alpha");
+    map.put(47L, "bang");
+    map.put(2L, "beta");
+    return map;
+  }
+  
+}

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java Sun Oct  2 07:03:17 2011
@@ -74,5 +74,34 @@ public final class RunningAverageAndStdD
     assertEquals(1000.0 / Math.sqrt(12.0), average.getStandardDeviation(), SMALL_EPSILON);
 
   }
+  
+  @Test
+  public void testStddev() {
+    
+    RunningAverageAndStdDev runningAverage = new FullRunningAverageAndStdDev();
+
+    assertEquals(0, runningAverage.getCount());
+    assertTrue(Double.isNaN(runningAverage.getAverage()));
+    runningAverage.addDatum(1.0);
+    assertEquals(1, runningAverage.getCount());
+    assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+    assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+    runningAverage.addDatum(1.0);
+    assertEquals(2, runningAverage.getCount());
+    assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+    assertEquals(0.0, runningAverage.getStandardDeviation(), EPSILON);
+
+    runningAverage.addDatum(7.0);
+    assertEquals(3, runningAverage.getCount());
+    assertEquals(3.0, runningAverage.getAverage(), EPSILON); 
+    assertEquals(3.464101552963257, runningAverage.getStandardDeviation(), EPSILON);
+
+    runningAverage.addDatum(5.0);
+    assertEquals(4, runningAverage.getCount());
+    assertEquals(3.5, runningAverage.getAverage(), EPSILON); 
+    assertEquals(3.0, runningAverage.getStandardDeviation(), EPSILON);
+
+  }
+  
 
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java Sun Oct  2 07:03:17 2011
@@ -56,5 +56,20 @@ public final class RunningAverageTest ex
     assertEquals(2, runningAverage.getCount());
     assertEquals(2.0, runningAverage.getAverage(), EPSILON);
   }
+  
+  @Test
+  public void testCopyConstructor() {
+    RunningAverage runningAverage = new FullRunningAverage();
+
+    runningAverage.addDatum(1.0);
+    runningAverage.addDatum(1.0);
+    assertEquals(2, runningAverage.getCount());
+    assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+
+    RunningAverage copy = new FullRunningAverage(runningAverage.getCount(), runningAverage.getAverage());
+    assertEquals(2, copy.getCount());
+    assertEquals(1.0, copy.getAverage(), EPSILON);
+    
+  }
 
 }

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java?rev=1178133&r1=1178132&r2=1178133&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/MemoryDiffStorageTest.java Sun Oct  2 07:03:17 2011
@@ -17,9 +17,12 @@
 
 package org.apache.mahout.cf.taste.impl.recommender.slopeone;
 
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
 import org.apache.mahout.cf.taste.common.Weighting;
 import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
 import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
 import org.apache.mahout.cf.taste.model.DataModel;
 import org.junit.Test;
 
@@ -27,6 +30,52 @@ import org.junit.Test;
 public final class MemoryDiffStorageTest extends TasteTestCase {
 
   @Test
+  public void testRecommendableIDsVariedWeighted() throws Exception {
+    DataModel model = getDataModelVaried();
+    MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+    FastIDSet recommendableItemIDs = storage.getRecommendableItemIDs(1);
+    assertEquals(3, recommendableItemIDs.size());
+    assertTrue(recommendableItemIDs.contains(1));
+    recommendableItemIDs = storage.getRecommendableItemIDs(2);
+    assertEquals(2, recommendableItemIDs.size());
+    assertTrue(recommendableItemIDs.contains(2));
+    assertTrue(recommendableItemIDs.contains(3));
+    
+    recommendableItemIDs = storage.getRecommendableItemIDs(3);
+    assertEquals(1, recommendableItemIDs.size());
+    assertTrue(recommendableItemIDs.contains(3));
+    
+    recommendableItemIDs = storage.getRecommendableItemIDs(4);
+    assertEquals(0, recommendableItemIDs.size());
+    // the last item has only one recommendation, and so only 4 items are usable
+    recommendableItemIDs = storage.getRecommendableItemIDs(5);
+    assertEquals(0, recommendableItemIDs.size());
+  }
+  
+  @Test
+  public void testRecommendableIDsPockedUnweighted() throws Exception {
+    DataModel model = getDataModelPocked();
+    MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
+    FastIDSet recommendableItemIDs = storage.getRecommendableItemIDs(1);
+    assertEquals(0, recommendableItemIDs.size());
+    recommendableItemIDs = storage.getRecommendableItemIDs(2);
+    assertEquals(1, recommendableItemIDs.size());
+    recommendableItemIDs = storage.getRecommendableItemIDs(3);
+    assertEquals(0, recommendableItemIDs.size());
+    
+    recommendableItemIDs = storage.getRecommendableItemIDs(4);
+    assertEquals(0, recommendableItemIDs.size());
+    
+  }
+  
+  @Test (expected=NoSuchUserException.class)
+  public void testUnRecommendableID() throws Exception {
+    DataModel model = getDataModel();
+    MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+    storage.getRecommendableItemIDs(0);
+  }
+  
+  @Test
   public void testGetDiff() throws Exception {
     DataModel model = getDataModel();
     MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
@@ -34,69 +83,133 @@ public final class MemoryDiffStorageTest
     assertEquals(0.23333333333333334, average.getAverage(), EPSILON);
     assertEquals(3, average.getCount());
   }
-
+  
   @Test
   public void testAdd() throws Exception {
     DataModel model = getDataModel();
     MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
-
+    
     RunningAverage average1 = storage.getDiff(0, 2);
     assertEquals(0.1, average1.getAverage(), EPSILON);
     assertEquals(3, average1.getCount());
-
+    
     RunningAverage average2 = storage.getDiff(1, 2);
     assertEquals(0.23333332935969034, average2.getAverage(), EPSILON);
     assertEquals(3, average2.getCount());
-
+    
     storage.addItemPref(1, 2, 0.8f);
-
+    
     average1 = storage.getDiff(0, 2);
     assertEquals(0.25, average1.getAverage(), EPSILON);
     assertEquals(4, average1.getCount());
-
+    
     average2 = storage.getDiff(1, 2);
     assertEquals(0.3, average2.getAverage(), EPSILON);
     assertEquals(4, average2.getCount());
   }
-
+  
   @Test
   public void testUpdate() throws Exception {
     DataModel model = getDataModel();
     MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
-
+    
     RunningAverage average = storage.getDiff(1, 2);
     assertEquals(0.23333332935969034, average.getAverage(), EPSILON);
     assertEquals(3, average.getCount());
-
+    
     storage.updateItemPref(1, 0.5f);
-
+    
     average = storage.getDiff(1, 2);
     assertEquals(0.06666666666666668, average.getAverage(), EPSILON);
     assertEquals(3, average.getCount());
   }
-
+  
   @Test
   public void testRemove() throws Exception {
     DataModel model = getDataModel();
     MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.UNWEIGHTED, Long.MAX_VALUE);
-
+    
     RunningAverage average1 = storage.getDiff(0, 2);
     assertEquals(0.1, average1.getAverage(), EPSILON);
     assertEquals(3, average1.getCount());
-
+    
     RunningAverage average2 = storage.getDiff(1, 2);
     assertEquals(0.23333332935969034, average2.getAverage(), EPSILON);
     assertEquals(3, average2.getCount());
-
+    
     storage.removeItemPref(4, 2, 0.8f);
-
+    
     average1 = storage.getDiff(0, 2);
     assertEquals(0.1, average1.getAverage(), EPSILON);
     assertEquals(2, average1.getCount());
-
+    
     average2 = storage.getDiff(1, 2);
     assertEquals(0.1, average2.getAverage(), EPSILON);
     assertEquals(2, average2.getCount());
   }
+  
+  @Test (expected=UnsupportedOperationException.class)
+  public void testUpdateWeighted() throws Exception {
+    DataModel model = getDataModelVaried();
+    MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+    
+    storage.updateItemPref(2, 0.8f);
+  }
+  
+  @Test
+  public void testRemovePref() throws Exception {
+    double eps = 0.0001;
+    DataModel model = getDataModelPocked();
+    MemoryDiffStorage storage = new MemoryDiffStorage(model, Weighting.WEIGHTED, Long.MAX_VALUE);
+    
+    RunningAverageAndStdDev average = (RunningAverageAndStdDev) storage.getDiff(0, 1);
+    assertEquals(-0.033333, average.getAverage(), eps);
+    assertEquals(0.32145, average.getStandardDeviation(), eps);
+    assertEquals(3, average.getCount());
+
+    storage.removeItemPref(2, 1, 0.1f);
+    average = (RunningAverageAndStdDev) storage.getDiff(0, 1);
+    assertEquals(0.00000001, average.getAverage(), eps);
+    assertEquals(0.44721, average.getStandardDeviation(), eps);
+    assertEquals(2, average.getCount());
+  }
+  
+  protected static DataModel getDataModelVaried() {
+    return getDataModel(
+        new long[] {1, 2, 3, 4, 5},
+        new Double[][] {
+            {0.2},
+            {0.4, 0.5},
+            {0.7, 0.1, 0.5},
+            {0.7, 0.3, 0.8, 0.1},
+            {0.2, 0.3, 0.6, 0.1, 0.3},
+        });
+  }
+  
+  protected static DataModel getDataModelPocked() {
+    return getDataModel(
+        new long[] {1, 2, 3, 4},
+        new Double[][] {
+            {0.1, 0.3},
+            {0.2},
+            {0.4, 0.5},
+            {0.7, 0.3, 0.8},
+        });
+  }
+  
+  protected static DataModel getDataModelLarge() {
+    return getDataModel(
+        new long[] {1, 2, 3, 4, 5, 6, 7},
+        new Double[][] {
+            {0.2, .2, .2, .2, .2, .2, .2},
+            {0.4, 0.5, .3, .3, .3, .3, .3},
+            {0.7, 0.1, 0.5, .2, .7, .8, .9},
+            {0.7, 0.3, 0.8, 0.1, .6, .6, .6},
+            {0.2, 0.3, 0.6, 0.1, 0.3, .4, .4},
+            {0.2, 0.3, 0.6, 0.1, 0.3, .4, .4},
+            {0.2, 0.3, 0.6, 0.1, 0.3, .5, .5},
+        });
+  }
+  
 
 }