You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by sr...@apache.org on 2010/03/29 12:59:48 UTC
svn commit: r928711 - in
/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl:
common/ recommender/slopeone/jdbc/
Author: srowen
Date: Mon Mar 29 10:59:47 2010
New Revision: 928711
URL: http://svn.apache.org/viewvc?rev=928711&view=rev
Log:
Add standard deviation support to JDBC diff storage
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java
- copied, changed from r928681, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java (from r928681, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java&r1=928681&r2=928711&rev=928711&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java Mon Mar 29 10:59:47 2010
@@ -21,80 +21,58 @@ import java.io.Serializable;
/**
* <p>
- * A simple class that can keep track of a running avearage of a series of numbers. One can add to or remove
- * from the series, as well as update a datum in the series. The class does not actually keep track of the
- * series of values, just its running average, so it doesn't even matter if you remove/change a value that
- * wasn't added.
+ * A simple class that represents a fixed value of an average and count. This is useful
+ * when an API needs to return {@link RunningAverage} but is not in a position to accept
+ * updates to it.
* </p>
*/
-public class FullRunningAverage implements RunningAverage, Serializable {
-
- private int count;
- private double average;
-
- public FullRunningAverage() {
- count = 0;
- average = Double.NaN;
+public class FixedRunningAverage implements RunningAverage, Serializable {
+
+ private final double average;
+ private final int count;
+
+ public FixedRunningAverage(double average, int count) {
+ this.average = average;
+ this.count = count;
}
-
+
/**
- * @param datum
- * new item to add to the running average
+ * @throws UnsupportedOperationException always; this average is fixed and cannot be updated
*/
@Override
public synchronized void addDatum(double datum) {
- if (++count == 1) {
- average = datum;
- } else {
- average = average * (count - 1) / count + datum / count;
- }
+ throw new UnsupportedOperationException();
}
-
+
/**
- * @param datum
- * item to remove to the running average
- * @throws IllegalStateException
- * if count is 0
+ * @throws UnsupportedOperationException always; this average is fixed and cannot be updated
*/
@Override
public synchronized void removeDatum(double datum) {
- if (count == 0) {
- throw new IllegalStateException();
- }
- if (--count == 0) {
- average = Double.NaN;
- } else {
- average = average * (count + 1) / count - datum / count;
- }
+ throw new UnsupportedOperationException();
}
-
+
/**
- * @param delta
- * amount by which to change a datum in the running average
- * @throws IllegalStateException
- * if count is 0
+ * @throws UnsupportedOperationException always; this average is fixed and cannot be updated
*/
@Override
public synchronized void changeDatum(double delta) {
- if (count == 0) {
- throw new IllegalStateException();
- }
- average += delta / count;
+ throw new UnsupportedOperationException();
}
-
+
@Override
public synchronized int getCount() {
return count;
}
-
+
@Override
public synchronized double getAverage() {
return average;
}
-
+
@Override
public synchronized String toString() {
return String.valueOf(average);
}
-
-}
+
+}
\ No newline at end of file
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java?rev=928711&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java Mon Mar 29 10:59:47 2010
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * <p>
+ * A simple class that represents a fixed value of an average, count and standard deviation. This is useful
+ * when an API needs to return {@link RunningAverageAndStdDev} but is not in a position to accept
+ * updates to it. The inherited mutators always throw {@link UnsupportedOperationException}.
+ * </p>
+ */
+public final class FixedRunningAverageAndStdDev extends FixedRunningAverage implements RunningAverageAndStdDev {
+
+ private final double stdDev;
+
+ public FixedRunningAverageAndStdDev(double average, double stdDev, int count) {
+ super(average, count);
+ this.stdDev = stdDev;
+ }
+
+ @Override
+ public synchronized String toString() {
+ return super.toString() + ',' + stdDev;
+ }
+
+ @Override
+ public double getStandardDeviation() {
+ return stdDev;
+ }
+
+}
\ No newline at end of file
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java?rev=928711&r1=928710&r2=928711&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/AbstractJDBCDiffStorage.java Mon Mar 29 10:59:47 2010
@@ -30,6 +30,8 @@ import javax.sql.DataSource;
import org.apache.mahout.cf.taste.common.Refreshable;
import org.apache.mahout.cf.taste.common.TasteException;
import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.FixedRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.FixedRunningAverageAndStdDev;
import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
import org.apache.mahout.cf.taste.impl.common.RunningAverage;
import org.apache.mahout.cf.taste.impl.common.jdbc.AbstractJDBCComponent;
@@ -45,7 +47,7 @@ import org.slf4j.LoggerFactory;
* A {@link DiffStorage} which stores diffs in a database. Database-specific implementations subclass this
* abstract class. Note that this implementation has a fairly particular dependence on the
* {@link org.apache.mahout.cf.taste.model.DataModel} used; it needs a {@link JDBCDataModel} attached to the
- * same database since its efficent operation depends on accessing preference data in the database directly.
+ * same database since its efficient operation depends on accessing preference data in the database directly.
* </p>
*/
public abstract class AbstractJDBCDiffStorage extends AbstractJDBCComponent implements DiffStorage {
@@ -57,7 +59,8 @@ public abstract class AbstractJDBCDiffSt
public static final String DEFAULT_ITEM_B_COLUMN = "item_id_b";
public static final String DEFAULT_COUNT_COLUMN = "count";
public static final String DEFAULT_AVERAGE_DIFF_COLUMN = "average_diff";
-
+ public static final String DEFAULT_STDEV_COLUMN = "standard_deviation";
+
private final DataSource dataSource;
private final String getDiffSQL;
private final String getDiffsSQL;
@@ -140,7 +143,7 @@ public abstract class AbstractJDBCDiffSt
stmt.setLong(4, itemID1);
log.debug("Executing SQL query: {}", getDiffSQL);
rs = stmt.executeQuery();
- return rs.next() ? new FixedRunningAverage(rs.getInt(1), rs.getDouble(2)) : null;
+ return rs.next() ? new FixedRunningAverageAndStdDev(rs.getDouble(2), rs.getDouble(3), rs.getInt(1)) : null;
} catch (SQLException sqle) {
log.warn("Exception while retrieving diff", sqle);
throw new TasteException(sqle);
@@ -175,7 +178,7 @@ public abstract class AbstractJDBCDiffSt
i++;
// result[i] is null for these values of i
}
- result[i] = new FixedRunningAverage(rs.getInt(1), rs.getDouble(2));
+ result[i] = new FixedRunningAverageAndStdDev(rs.getDouble(2), rs.getDouble(3), rs.getInt(1));
i++;
}
} catch (SQLException sqle) {
@@ -204,7 +207,7 @@ public abstract class AbstractJDBCDiffSt
if (rs.next()) {
int count = rs.getInt(1);
if (count > 0) {
- return new FixedRunningAverage(count, rs.getDouble(2));
+ return new FixedRunningAverage(rs.getDouble(2), count);
}
}
return null;
@@ -215,7 +218,12 @@ public abstract class AbstractJDBCDiffSt
IOUtils.quietClose(rs, stmt, conn);
}
}
-
+
+ /**
+ * Note that this implementation does <em>not</em> update standard deviations. This would
+ * be expensive relative to the value of slightly adjusting these values, which are merely
+ * used as weights. Rebuilding the diffs table will update standard deviations.
+ */
@Override
public void updateItemPref(long itemID, float prefDelta, boolean remove) throws TasteException {
Connection conn = null;
@@ -330,41 +338,4 @@ public abstract class AbstractJDBCDiffSt
public void refresh(Collection<Refreshable> alreadyRefreshed) {
refreshHelper.refresh(alreadyRefreshed);
}
-
- private static class FixedRunningAverage implements RunningAverage {
-
- private final int count;
- private final double average;
-
- private FixedRunningAverage(int count, double average) {
- this.count = count;
- this.average = average;
- }
-
- @Override
- public void addDatum(double datum) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void removeDatum(double datum) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public void changeDatum(double delta) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public int getCount() {
- return count;
- }
-
- @Override
- public double getAverage() {
- return average;
- }
- }
-
}
\ No newline at end of file
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java?rev=928711&r1=928710&r2=928711&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/slopeone/jdbc/MySQLJDBCDiffStorage.java Mon Mar 29 10:59:47 2010
@@ -34,32 +34,36 @@ import org.apache.mahout.cf.taste.impl.m
* <th>item_id_a</th>
* <th>item_id_b</th>
* <th>average_diff</th>
+ * <th>standard_deviation</th>
* <th>count</th>
* </tr>
* <tr>
* <td>123</td>
* <td>234</td>
* <td>0.5</td>
+ * <td>0.12</td>
* <td>5</td>
* </tr>
* <tr>
* <td>123</td>
* <td>789</td>
* <td>-1.33</td>
+ * <td>0.2</td>
* <td>3</td>
* </tr>
* <tr>
* <td>234</td>
* <td>789</td>
* <td>2.1</td>
+ * <td>1.03</td>
* <td>1</td>
* </tr>
* </table>
*
* <p>
* <code>item_id_a</code> and <code>item_id_b</code> should have types compatible with the long primitive
- * type. <code>average_diff</code> must be compatible with <code>float</code> and <code>count</code> must be
- * compatible with <code>int</code>.
+ * type. <code>average_diff</code> and <code>standard_deviation</code> must be compatible with
+ * <code>float</code> and <code>count</code> must be compatible with <code>int</code>.
* </p>
*
* <p>
@@ -73,6 +77,7 @@ import org.apache.mahout.cf.taste.impl.m
* item_id_a BIGINT NOT NULL,
* item_id_b BIGINT NOT NULL,
* average_diff FLOAT NOT NULL,
+ * standard_deviation FLOAT NOT NULL,
* count INT NOT NULL,
* PRIMARY KEY (item_id_a, item_id_b),
* INDEX (item_id_a),
@@ -87,8 +92,14 @@ public final class MySQLJDBCDiffStorage
private static final int DEFAULT_MIN_DIFF_COUNT = 2;
public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel) throws TasteException {
- this(dataModel, DEFAULT_DIFF_TABLE, DEFAULT_ITEM_A_COLUMN, DEFAULT_ITEM_B_COLUMN, DEFAULT_COUNT_COLUMN,
- DEFAULT_AVERAGE_DIFF_COLUMN, DEFAULT_MIN_DIFF_COUNT);
+ this(dataModel,
+ DEFAULT_DIFF_TABLE,
+ DEFAULT_ITEM_A_COLUMN,
+ DEFAULT_ITEM_B_COLUMN,
+ DEFAULT_COUNT_COLUMN,
+ DEFAULT_AVERAGE_DIFF_COLUMN,
+ DEFAULT_STDEV_COLUMN,
+ DEFAULT_MIN_DIFF_COUNT);
}
public MySQLJDBCDiffStorage(AbstractJDBCDataModel dataModel,
@@ -97,14 +108,17 @@ public final class MySQLJDBCDiffStorage
String itemIDBColumn,
String countColumn,
String avgColumn,
+ String stdevColumn,
int minDiffCount) throws TasteException {
super(dataModel,
- // getDiffSQL
- "SELECT " + countColumn + ", " + avgColumn + " FROM " + diffsTable + " WHERE " + itemIDAColumn
+ // getDiffSQL: both halves of the UNION must select count, average and standard deviation
+ "SELECT " + countColumn + ", " + avgColumn + ", " + stdevColumn + " FROM "
+ + diffsTable + " WHERE " + itemIDAColumn
- + "=? AND " + itemIDBColumn + "=? UNION " + "SELECT " + countColumn + ", " + avgColumn + " FROM "
+ + "=? AND " + itemIDBColumn + "=? UNION " + "SELECT " + countColumn + ", " + avgColumn + ", " + stdevColumn + " FROM "
+ diffsTable + " WHERE " + itemIDAColumn + "=? AND " + itemIDBColumn + "=?",
// getDiffsSQL
- "SELECT " + countColumn + ", " + avgColumn + ", " + itemIDAColumn + " FROM " + diffsTable + ", "
+ "SELECT " + countColumn + ", " + avgColumn + ", " + stdevColumn + ", " + itemIDAColumn
+ + " FROM " + diffsTable + ", "
+ dataModel.getPreferenceTable() + " WHERE " + itemIDBColumn + "=? AND " + itemIDAColumn + " = "
+ dataModel.getItemIDColumn() + " AND " + dataModel.getUserIDColumn() + "=? ORDER BY "
+ itemIDAColumn,
@@ -139,17 +153,20 @@ public final class MySQLJDBCDiffStorage
// deleteDiffsSQL
"TRUNCATE " + diffsTable,
// createDiffsSQL
- "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn + ", "
- + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn() + ", prefsB."
- + dataModel.getItemIDColumn() + ',' + " AVG(prefsB." + dataModel.getPreferenceColumn()
- + " - prefsA." + dataModel.getPreferenceColumn() + ")," + " COUNT(1) AS count FROM "
- + dataModel.getPreferenceTable() + " prefsA, " + dataModel.getPreferenceTable()
- + " prefsB WHERE prefsA." + dataModel.getUserIDColumn() + " = prefsB."
- + dataModel.getUserIDColumn() + " AND prefsA." + dataModel.getItemIDColumn() + " < prefsB."
- + dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA." + dataModel.getItemIDColumn()
- + ", prefsB." + dataModel.getItemIDColumn() + " HAVING count >=?",
+ "INSERT INTO " + diffsTable + " (" + itemIDAColumn + ", " + itemIDBColumn + ", " + avgColumn
+ + ", " + stdevColumn + ", " + countColumn + ") SELECT prefsA." + dataModel.getItemIDColumn()
+ + ", prefsB." + dataModel.getItemIDColumn() + ", AVG(prefsB." + dataModel.getPreferenceColumn()
+ + " - prefsA." + dataModel.getPreferenceColumn() + "), STDDEV_POP(prefsB."
+ + dataModel.getPreferenceColumn() + " - prefsA." + dataModel.getPreferenceColumn()
+ + "), COUNT(1) AS count FROM " + dataModel.getPreferenceTable() + " prefsA, "
+ + dataModel.getPreferenceTable() + " prefsB WHERE prefsA." + dataModel.getUserIDColumn()
+ + " = prefsB." + dataModel.getUserIDColumn() + " AND prefsA." + dataModel.getItemIDColumn()
+ + " < prefsB." + dataModel.getItemIDColumn() + ' ' + " GROUP BY prefsA."
+ + dataModel.getItemIDColumn() + ", prefsB." + dataModel.getItemIDColumn()
+ + " HAVING count >= ?",
// diffsExistSQL
- "SELECT COUNT(1) FROM " + diffsTable, minDiffCount);
+ "SELECT COUNT(1) FROM " + diffsTable,
+ minDiffCount);
}
/**