You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/03/29 12:59:48 UTC

svn commit: r928711 - in /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl: common/ recommender/slopeone/jdbc/

Author: srowen
Date: Mon Mar 29 10:59:47 2010
New Revision: 928711

URL: http://svn.apache.org/viewvc?rev=928711&view=rev
Log:
Add standard deviation support to JDBC diff storage

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java
      - copied, changed from r928681, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java

Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java (from r928681, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java&r1=928681&r2=928711&rev=928711&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java Mon Mar 29 10:59:47 2010
@@ -21,80 +21,58 @@ import java.io.Serializable;
 
 /**
  * <p>
- * A simple class that can keep track of a running avearage of a series of numbers. One can add to or remove
- * from the series, as well as update a datum in the series. The class does not actually keep track of the
- * series of values, just its running average, so it doesn't even matter if you remove/change a value that
- * wasn't added.
+ * A simple class that represents a fixed value of an average and count. This is useful
+ * when an API needs to return {@link RunningAverage} but is not in a position to accept
+ * updates to it.
  * </p>
  */
-public class FullRunningAverage implements RunningAverage, Serializable {
-  
-  private int count;
-  private double average;
-  
-  public FullRunningAverage() {
-    count = 0;
-    average = Double.NaN;
+public class FixedRunningAverage implements RunningAverage, Serializable {
+
+  private final double average;
+  private final int count;
+
+  public FixedRunningAverage(double average, int count) {
+    this.average = average;
+    this.count = count;
   }
-  
+
   /**
-   * @param datum
-   *          new item to add to the running average
+   * @throws UnsupportedOperationException always; a fixed average cannot accept new data
    */
   @Override
   public synchronized void addDatum(double datum) {
-    if (++count == 1) {
-      average = datum;
-    } else {
-      average = average * (count - 1) / count + datum / count;
-    }
+    throw new UnsupportedOperationException();
   }
-  
+
   /**
-   * @param datum
-   *          item to remove to the running average
-   * @throws IllegalStateException
-   *           if count is 0
+   * @throws UnsupportedOperationException always; a fixed average cannot have data removed
    */
   @Override
   public synchronized void removeDatum(double datum) {
-    if (count == 0) {
-      throw new IllegalStateException();
-    }
-    if (--count == 0) {
-      average = Double.NaN;
-    } else {
-      average = average * (count + 1) / count - datum / count;
-    }
+    throw new UnsupportedOperationException();
   }
-  
+
   /**
-   * @param delta
-   *          amount by which to change a datum in the running average
-   * @throws IllegalStateException
-   *           if count is 0
+   * @throws UnsupportedOperationException always; a fixed average cannot have data changed
    */
   @Override
   public synchronized void changeDatum(double delta) {
-    if (count == 0) {
-      throw new IllegalStateException();
-    }
-    average += delta / count;
+    throw new UnsupportedOperationException();
   }
-  
+
   @Override
   public synchronized int getCount() {
     return count;
   }
-  
+
   @Override
   public synchronized double getAverage() {
     return average;
   }
-  
+
   @Override
   public synchronized String toString() {
     return String.valueOf(average);
   }
-  
-}
+
+}
\ No newline at end of file

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java?rev=928711&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java Mon Mar 29 10:59:47 2010
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * <p>
+ * A simple class that represents a fixed value of an average, count and standard deviation. This is useful
+ * when an API needs to return {@link RunningAverageAndStdDev} but is not in a position to accept
+ * updates to it.
+ * </p>
+ */
+public final class FixedRunningAverageAndStdDev extends FixedRunningAverage implements RunningAverageAndStdDev {
+
+  private final double stdDev;
+
+  public FixedRunningAverageAndStdDev(double average, double stdDev, int count) {
+    super(average, count);
+    this.stdDev = stdDev;
+  }
+
+  @Override
+  public synchronized String toString() {
+    return super.toString() + ',' + stdDev;
+  }
+
+  @Override
+  public double getStandardDeviation() {
+    return stdDev;
+  }
+
+}
\ No newline at end of file

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java?rev=928711&r1=928710&r2=928711&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java Mon Mar 29 10:59:47 2010
@@ -30,6 +30,8 @@ import javax.sql.DataSource;
 import org.apache.mahout.cf.taste.common.Refreshable;
 import org.apache.mahout.cf.taste.common.TasteException;
 import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.FixedRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.FixedRunningAverageAndStdDev;
 import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
 import org.apache.mahout.cf.taste.impl.common.RunningAverage;
 import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent;
@@ -45,7 +47,7 @@ import org.slf4j.LoggerFactory;
  * A {@link DiffStorage} which stores diffs in a database. Database-specific implementations subclass this
  * abstract class. Note that this implementation has a fairly particular dependence on the
  * {@link org.apache.mahout.cf.taste.model.DataModel} used; it needs a {@link JDBCDataModel} attached to the
- * same database since its efficent operation depends on accessing preference data in the database directly.
+ * same database since its efficient operation depends on accessing preference data in the database directly.
  * </p>
  */
 public abstract class AbstractJDBCDiffStorage extends AbstractJDBCComponent implements DiffStorage {
@@ -57,7 +59,8 @@ public abstract class AbstractJDBCDiffSt
   public static final String DEFAULT_ITEM_B_COLUMN = "item_id_b";
   public static final String DEFAULT_COUNT_COLUMN = "count";
   public static final String DEFAULT_AVERAGE_DIFF_COLUMN = "average_diff";
-  
+  public static final String DEFAULT_STDEV_COLUMN = "standard_deviation";
+
   private final DataSource dataSource;
   private final String getDiffSQL;
   private final String getDiffsSQL;
@@ -140,7 +143,7 @@ public abstract class AbstractJDBCDiffSt
       stmt.setLong(4, itemID1);
       log.debug("Executing SQL query: {}", getDiffSQL);
       rs = stmt.executeQuery();
-      return rs.next() ? new FixedRunningAverage(rs.getInt(1), rs.getDouble(2)) : null;
+      return rs.next() ? new FixedRunningAverageAndStdDev(rs.getDouble(2), rs.getDouble(3), rs.getInt(1)) : null;
     } catch (SQLException sqle) {
       log.warn("Exception while retrieving diff", sqle);
       throw new TasteException(sqle);
@@ -175,7 +178,7 @@ public abstract class AbstractJDBCDiffSt
           i++;
           // result[i] is null for these values of i
         }
-        result[i] = new FixedRunningAverage(rs.getInt(1), rs.getDouble(2));
+        result[i] = new FixedRunningAverageAndStdDev(rs.getDouble(2), rs.getDouble(3), rs.getInt(1));
         i++;
       }
     } catch (SQLException sqle) {
@@ -204,7 +207,7 @@ public abstract class AbstractJDBCDiffSt
       if (rs.next()) {
         int count = rs.getInt(1);
         if (count > 0) {
-          return new FixedRunningAverage(count, rs.getDouble(2));
+          return new FixedRunningAverage(rs.getDouble(2), count);
         }
       }
       return null;
@@ -215,7 +218,12 @@ public abstract class AbstractJDBCDiffSt
       IOUtils.quietClose(rs, stmt, conn);
     }
   }
-  
+
+  /**
+   * Note that this implementation does <em>not</em> update standard deviations. This would
+   * be expensive relative to the value of slightly adjusting these values, which are merely
+   * used as weights. Rebuilding the diffs table will update standard deviations.
+   */
   @Override
   public void updateItemPref(long itemID, float prefDelta, boolean remove) throws TasteException {
     Connection conn = null;
@@ -330,41 +338,4 @@ public abstract class AbstractJDBCDiffSt
   public void refresh(Collection<Refreshable> alreadyRefreshed) {
     refreshHelper.refresh(alreadyRefreshed);
   }
-  
-  private static class FixedRunningAverage implements RunningAverage {
-    
-    private final int count;
-    private final double average;
-    
-    private FixedRunningAverage(int count, double average) {
-      this.count = count;
-      this.average = average;
-    }
-    
-    @Override
-    public void addDatum(double datum) {
-      throw new UnsupportedOperationException();
-    }
-    
-    @Override
-    public void removeDatum(double datum) {
-      throw new UnsupportedOperationException();
-    }
-    
-    @Override
-    public void changeDatum(double delta) {
-      throw new UnsupportedOperationException();
-    }
-    
-    @Override
-    public int getCount() {
-      return count;
-    }
-    
-    @Override
-    public double getAverage() {
-      return average;
-    }
-  }
-  
 }
\ No newline at end of file

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java?rev=928711&r1=928710&r2=928711&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java Mon Mar 29 10:59:47 2010
@@ -34,32 +34,36 @@ import org.apache.mahout.cf.taste.impl.m
  * <th>item_id_a</th>
  * <th>item_id_b</th>
  * <th>average_diff</th>
+ * <th>standard_deviation</th>
  * <th>count</th>
  * </tr>
  * <tr>
  * <td>123</td>
  * <td>234</td>
  * <td>0.5</td>
+ * <td>0.12</td>
  * <td>5</td>
  * </tr>
  * <tr>
  * <td>123</td>
  * <td>789</td>
  * <td>-1.33</td>
+ * <td>0.2</td>
  * <td>3</td>
  * </tr>
  * <tr>
  * <td>234</td>
  * <td>789</td>
  * <td>2.1</td>
+ * <td>1.03</td>
  * <td>1</td>
  * </tr>
  * </table>
  * 
  * <p>
  * <code>item_id_a</code> and <code>item_id_b</code> should have types compatible with the long primitive
- * type. <code>average_diff</code> must be compatible with <code>float</code> and <code>count</code> must be
- * compatible with <code>int</code>.
+ * type. <code>average_diff</code> and <code>standard_deviation</code> must be compatible with
+ * <code>float</code> and <code>count</code> must be compatible with <code>int</code>.
  * </p>
  * 
  * <p>
@@ -73,6 +77,7 @@ import org.apache.mahout.cf.taste.impl.m
  *   item_id_a BIGINT NOT NULL,
  *   item_id_b BIGINT NOT NULL,
  *   average_diff FLOAT NOT NULL,
+ *   standard_deviation FLOAT NOT NULL,
  *   count INT NOT NULL,
  *   PRIMARY KEY (item_id_a, item_id_b),
  *   INDEX (item_id_a),
@@ -87,8 +92,14 @@ public final class MySQLJDBCDiffStorage 
   private static final int DEFAULT_MIN_DIFF_COUNT = 2;
   
   public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel) throws TasteException {
-    this(dataModel, DEFAULT_DIFF_TABLE, DEFAULT_ITEM_A_COLUMN, DEFAULT_ITEM_B_COLUMN, DEFAULT_COUNT_COLUMN,
-        DEFAULT_AVERAGE_DIFF_COLUMN, DEFAULT_MIN_DIFF_COUNT);
+    this(dataModel,
+         DEFAULT_DIFF_TABLE,
+         DEFAULT_ITEM_A_COLUMN,
+         DEFAULT_ITEM_B_COLUMN,
+         DEFAULT_COUNT_COLUMN,
+         DEFAULT_AVERAGE_DIFF_COLUMN,
+         DEFAULT_STDEV_COLUMN,
+         DEFAULT_MIN_DIFF_COUNT);
   }
   
   public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel,
@@ -97,14 +108,17 @@ public final class MySQLJDBCDiffStorage 
                               String itemIDBColumn,
                               String countColumn,
                               String avgColumn,
+                              String stdevColumn,
                               int minDiffCount) throws TasteException {
     super(dataModel,
-    // getDiffSQL
-        "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable + " WHERE " + itemIDAColumn
+        // getDiffSQL
+        "SELECT " + countColumn + ", " + avgColumn + ", " + stdevColumn + " FROM "
+            + diffsTable + " WHERE " + itemIDAColumn
             + "=? AND " + itemIDBColumn + "=? UNION " + "SELECT " + countColumn + ", " + avgColumn + " FROM "
             + diffsTable + " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=?",
         // getDiffsSQL
-        "SELECT " + countColumn + ", " + avgColumn + ", " + itemIDAColumn + " FROM " + diffsTable + ", "
+        "SELECT " + countColumn + ", " + avgColumn + ", " + stdevColumn + ", " + itemIDAColumn 
+            + " FROM " + diffsTable + ", "
             + dataModel.getPreferenceTable() + " WHERE " + itemIDBColumn + "=? AND " + itemIDAColumn + " = "
             + dataModel.getItemIDColumn() + " AND " + dataModel.getUserIDColumn() + "=? ORDER BY "
             + itemIDAColumn,
@@ -139,17 +153,20 @@ public final class MySQLJDBCDiffStorage 
         // deleteDiffsSQL
         "TRUNCATE " + diffsTable,
         // createDiffsSQL
-        "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn + ", "
-            + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn() + ", prefsB."
-            + dataModel.getItemIDColumn() + ',' + " AVG(prefsB." + dataModel.getPreferenceColumn()
-            + " - prefsA." + dataModel.getPreferenceColumn() + ")," + " COUNT(1) AS count FROM "
-            + dataModel.getPreferenceTable() + " prefsA, " + dataModel.getPreferenceTable()
-            + " prefsB WHERE prefsA." + dataModel.getUserIDColumn() + " = prefsB."
-            + dataModel.getUserIDColumn() + " AND prefsA." + dataModel.getItemIDColumn() + " < prefsB."
-            + dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA." + dataModel.getItemIDColumn()
-            + ", prefsB." + dataModel.getItemIDColumn() + " HAVING count >=?",
+        "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn
+            + ", " + stdevColumn + ", " + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn()
+            + ", prefsB." + dataModel.getItemIDColumn() + ", AVG(prefsB." + dataModel.getPreferenceColumn()
+            + " - prefsA." + dataModel.getPreferenceColumn() + "), STDDEV_POP(prefsB."
+            + dataModel.getPreferenceColumn() + " - prefsA." + dataModel.getPreferenceColumn()
+            + "), COUNT(1) AS count FROM " + dataModel.getPreferenceTable() + " prefsA, "
+            + dataModel.getPreferenceTable() + " prefsB WHERE prefsA." + dataModel.getUserIDColumn()
+            + " = prefsB." + dataModel.getUserIDColumn() + " AND prefsA." + dataModel.getItemIDColumn()
+            + " < prefsB." + dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA."
+            + dataModel.getItemIDColumn() + ", prefsB." + dataModel.getItemIDColumn()
+            + " HAVING count >= ?",
         // diffsExistSQL
-        "SELECT COUNT(1) FROM " + diffsTable, minDiffCount);
+        "SELECT COUNT(1) FROM " + diffsTable,
+        minDiffCount);
   }
   
   /**