You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:47 UTC
[19/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java
new file mode 100644
index 0000000..08aa5ae
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Produces random recommendations and preference estimates. This is likely only useful as a novelty and for
+ * benchmarking.
+ */
+public final class RandomRecommender extends AbstractRecommender {
+
+  /** Source of randomness for both item selection and generated preference values. */
+  private final Random random = RandomUtils.getRandom();
+  /** Smallest preference value found in the data model when this recommender was constructed. */
+  private final float minPref;
+  /** Largest preference value found in the data model when this recommender was constructed. */
+  private final float maxPref;
+
+  /**
+   * Scans every preference of every user once to learn the range of preference values. Generated
+   * preferences are later drawn from that range. The range is computed only here; {@link #refresh}
+   * does not recompute it.
+   *
+   * <p>If the model contains no preferences at all, the bounds keep their infinite initial values and
+   * the generated preference values are not meaningful numbers.</p>
+   *
+   * @param dataModel model supplying users, items and the observed preference range
+   * @throws TasteException if accessing the data model fails
+   */
+  public RandomRecommender(DataModel dataModel) throws TasteException {
+    super(dataModel);
+    float maxPref = Float.NEGATIVE_INFINITY;
+    float minPref = Float.POSITIVE_INFINITY;
+    LongPrimitiveIterator userIterator = dataModel.getUserIDs();
+    while (userIterator.hasNext()) {
+      long userID = userIterator.next();
+      PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+      for (int i = 0; i < prefs.length(); i++) {
+        float prefValue = prefs.getValue(i);
+        if (prefValue < minPref) {
+          minPref = prefValue;
+        }
+        if (prefValue > maxPref) {
+          maxPref = prefValue;
+        }
+      }
+    }
+    this.minPref = minPref;
+    this.maxPref = maxPref;
+  }
+
+  /**
+   * Picks items uniformly at random (with replacement, so the same item may appear more than once) and
+   * assigns each a random preference value.
+   *
+   * <p>An empty list is returned when nothing can be recommended: the model has no items, or
+   * {@code includeKnownItems} is {@code false} and the user already has a preference for every item.
+   * Without these checks the random draw would fail on an empty model, or the sampling loop would never
+   * terminate.</p>
+   *
+   * @param userID user to recommend for
+   * @param howMany number of recommendations to produce
+   * @param rescorer not consulted by this implementation
+   * @param includeKnownItems whether items the user already has a preference for may be returned
+   * @return {@code howMany} random recommendations, or an empty list if no item is eligible
+   * @throws TasteException if accessing the data model fails
+   */
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+    throws TasteException {
+    DataModel dataModel = getDataModel();
+    int numItems = dataModel.getNumItems();
+    List<RecommendedItem> result = new ArrayList<>(howMany);
+    if (howMany == 0 || numItems <= 0) {
+      // Nothing requested, or nothing to draw from: random.nextInt(0) would throw.
+      return result;
+    }
+    if (!includeKnownItems && dataModel.getItemIDsFromUser(userID).size() >= numItems) {
+      // Every item is already known to the user, so no draw could ever be accepted.
+      return result;
+    }
+    while (result.size() < howMany) {
+      // Jump to a random position in a fresh iteration over all item IDs.
+      LongPrimitiveIterator it = dataModel.getItemIDs();
+      it.skip(random.nextInt(numItems));
+      long itemID = it.next();
+      if (includeKnownItems || dataModel.getPreferenceValue(userID, itemID) == null) {
+        result.add(new GenericRecommendedItem(itemID, randomPref()));
+      }
+    }
+    return result;
+  }
+
+  /**
+   * Ignores both arguments and returns a random value from the observed preference range.
+   */
+  @Override
+  public float estimatePreference(long userID, long itemID) {
+    return randomPref();
+  }
+
+  /** Draws a value between {@code minPref} and {@code maxPref} using {@link Random#nextFloat()}. */
+  private float randomPref() {
+    return minPref + random.nextFloat() * (maxPref - minPref);
+  }
+
+  /** Delegates the refresh to the underlying data model. */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    getDataModel().refresh(alreadyRefreshed);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java
new file mode 100644
index 0000000..623a60b
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveArrayIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.FixedSizeSamplingIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Iterator;
+
+/**
+ * <p>Returns all items that have not been rated by the user <em>(3)</em> and that were preferred by another user
+ * <em>(2)</em> that has preferred at least one item <em>(1)</em> that the current user has preferred too.</p>
+ *
+ * <p>This strategy uses sampling to limit the number of items that are considered, by sampling three different
+ * things, noted above:</p>
+ *
+ * <ol>
+ * <li>The items that the user has preferred</li>
+ * <li>The users who also prefer each of those items</li>
+ * <li>The items those users also prefer</li>
+ * </ol>
+ *
+ * <p>There is a maximum associated with each of these three things; if the number of items or users exceeds
+ * that max, it is sampled so that the expected number of items or users actually used in that part of the
+ * computation is equal to the max.</p>
+ *
+ * <p>Three arguments control these three maxima. Each is a "factor" f, which establishes the max at
+ * f * log2(n), where n is the number of users or items in the data. For example if factor #2 is 5,
+ * which controls the number of users sampled per item, then 5 * log2(# users) is the maximum for this
+ * part of the computation.</p>
+ *
+ * <p>Each can be set to not do any limiting with value {@link #NO_LIMIT_FACTOR}.</p>
+ */
+public class SamplingCandidateItemsStrategy extends AbstractCandidateItemsStrategy {
+
+ private static final Logger log = LoggerFactory.getLogger(SamplingCandidateItemsStrategy.class);
+
+ /**
+ * Default factor used if not otherwise specified, for all limits. (30).
+ */
+ public static final int DEFAULT_FACTOR = 30;
+ /**
+ * Specify this value as a factor to mean no limit.
+ */
+ public static final int NO_LIMIT_FACTOR = Integer.MAX_VALUE;
+ private static final int MAX_LIMIT = Integer.MAX_VALUE;
+ private static final double LOG2 = Math.log(2.0);
+
+ private final int maxItems;
+ private final int maxUsersPerItem;
+ private final int maxItemsPerUser;
+
+ /**
+ * Defaults to using no limit ({@link #NO_LIMIT_FACTOR}) for all factors, except
+ * {@code candidatesPerUserFactor} which defaults to {@link #DEFAULT_FACTOR}.
+ *
+ * @see #SamplingCandidateItemsStrategy(int, int, int, int, int)
+ */
+ public SamplingCandidateItemsStrategy(int numUsers, int numItems) {
+ this(DEFAULT_FACTOR, DEFAULT_FACTOR, DEFAULT_FACTOR, numUsers, numItems);
+ }
+
+ /**
+ * @param itemsFactor factor controlling max items considered for a user
+ * @param usersPerItemFactor factor controlling max users considered for each of those items
+ * @param candidatesPerUserFactor factor controlling max candidate items considered from each of those users
+ * @param numUsers number of users currently in the data
+ * @param numItems number of items in the data
+ */
+ public SamplingCandidateItemsStrategy(int itemsFactor,
+ int usersPerItemFactor,
+ int candidatesPerUserFactor,
+ int numUsers,
+ int numItems) {
+ Preconditions.checkArgument(itemsFactor > 0, "itemsFactor must be greater then 0!");
+ Preconditions.checkArgument(usersPerItemFactor > 0, "usersPerItemFactor must be greater then 0!");
+ Preconditions.checkArgument(candidatesPerUserFactor > 0, "candidatesPerUserFactor must be greater then 0!");
+ Preconditions.checkArgument(numUsers > 0, "numUsers must be greater then 0!");
+ Preconditions.checkArgument(numItems > 0, "numItems must be greater then 0!");
+ maxItems = computeMaxFrom(itemsFactor, numItems);
+ maxUsersPerItem = computeMaxFrom(usersPerItemFactor, numUsers);
+ maxItemsPerUser = computeMaxFrom(candidatesPerUserFactor, numItems);
+ log.debug("maxItems {}, maxUsersPerItem {}, maxItemsPerUser {}", maxItems, maxUsersPerItem, maxItemsPerUser);
+ }
+
+ private static int computeMaxFrom(int factor, int numThings) {
+ if (factor == NO_LIMIT_FACTOR) {
+ return MAX_LIMIT;
+ }
+ long max = (long) (factor * (1.0 + Math.log(numThings) / LOG2));
+ return max > MAX_LIMIT ? MAX_LIMIT : (int) max;
+ }
+
+ @Override
+ protected FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel, boolean includeKnownItems)
+ throws TasteException {
+ LongPrimitiveIterator preferredItemIDsIterator = new LongPrimitiveArrayIterator(preferredItemIDs);
+ if (preferredItemIDs.length > maxItems) {
+ double samplingRate = (double) maxItems / preferredItemIDs.length;
+// log.info("preferredItemIDs.length {}, samplingRate {}", preferredItemIDs.length, samplingRate);
+ preferredItemIDsIterator =
+ new SamplingLongPrimitiveIterator(preferredItemIDsIterator, samplingRate);
+ }
+ FastIDSet possibleItemsIDs = new FastIDSet();
+ while (preferredItemIDsIterator.hasNext()) {
+ long itemID = preferredItemIDsIterator.nextLong();
+ PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+ int prefsLength = prefs.length();
+ if (prefsLength > maxUsersPerItem) {
+ Iterator<Preference> sampledPrefs =
+ new FixedSizeSamplingIterator<>(maxUsersPerItem, prefs.iterator());
+ while (sampledPrefs.hasNext()) {
+ addSomeOf(possibleItemsIDs, dataModel.getItemIDsFromUser(sampledPrefs.next().getUserID()));
+ }
+ } else {
+ for (int i = 0; i < prefsLength; i++) {
+ addSomeOf(possibleItemsIDs, dataModel.getItemIDsFromUser(prefs.getUserID(i)));
+ }
+ }
+ }
+ if (!includeKnownItems) {
+ possibleItemsIDs.removeAll(preferredItemIDs);
+ }
+ return possibleItemsIDs;
+ }
+
+ private void addSomeOf(FastIDSet possibleItemIDs, FastIDSet itemIDs) {
+ if (itemIDs.size() > maxItemsPerUser) {
+ LongPrimitiveIterator it =
+ new SamplingLongPrimitiveIterator(itemIDs.iterator(), (double) maxItemsPerUser / itemIDs.size());
+ while (it.hasNext()) {
+ possibleItemIDs.add(it.nextLong());
+ }
+ } else {
+ possibleItemIDs.addAll(itemIDs);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java
new file mode 100644
index 0000000..c6d417f
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.common.RandomUtils;
+
+/** Simply encapsulates a user and a similarity value. */
+public final class SimilarUser implements Comparable<SimilarUser> {
+
+  private final long userID;
+  private final double similarity;
+
+  /**
+   * @param userID ID of the user
+   * @param similarity similarity value associated with that user
+   */
+  public SimilarUser(long userID, double similarity) {
+    this.userID = userID;
+    this.similarity = similarity;
+  }
+
+  long getUserID() {
+    return userID;
+  }
+
+  double getSimilarity() {
+    return similarity;
+  }
+
+  @Override
+  public int hashCode() {
+    // NOTE(review): assumes RandomUtils.hashDouble hashes the bit pattern of the double, which would keep
+    // it consistent with the Double.compare-based equals below -- confirm against RandomUtils.
+    return (int) userID ^ RandomUtils.hashDouble(similarity);
+  }
+
+  /**
+   * Two instances are equal when they carry the same user ID and the same similarity. Similarities are
+   * compared with {@link Double#compare(double, double)} rather than {@code ==} so that an instance holding
+   * NaN is still equal to itself, as the {@code equals} contract requires.
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof SimilarUser)) {
+      return false;
+    }
+    SimilarUser other = (SimilarUser) o;
+    return userID == other.userID && Double.compare(similarity, other.similarity) == 0;
+  }
+
+  @Override
+  public String toString() {
+    return "SimilarUser[user:" + userID + ", similarity:" + similarity + ']';
+  }
+
+  /**
+   * Defines an ordering from most similar to least similar; ties are broken by ascending user ID.
+   *
+   * <p>{@link Double#compare(double, double)} is used so the ordering stays total and transitive even if a
+   * similarity is NaN. Comparing with {@code <} and {@code >} would treat NaN as equal to every value and
+   * break the {@link Comparable} contract.</p>
+   */
+  @Override
+  public int compareTo(SimilarUser other) {
+    // Arguments are swapped on purpose: larger similarity sorts first.
+    int bySimilarity = Double.compare(other.similarity, similarity);
+    if (bySimilarity != 0) {
+      return bySimilarity;
+    }
+    return Long.compare(userID, other.userID);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java
new file mode 100644
index 0000000..f7b4385
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java
@@ -0,0 +1,211 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.Queue;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.GenericUserSimilarity;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+/**
+ * <p>
+ * A simple class that refactors the "find top N things" logic that is used in several places.
+ * </p>
+ */
+public final class TopItems {
+
+  private static final long[] NO_IDS = new long[0];
+
+  /**
+   * Upper bound on the number of heap slots allocated up front. The heap still grows on demand, so this only
+   * keeps a very large {@code howMany} (including {@link Integer#MAX_VALUE}) from overflowing or exhausting
+   * memory at construction time.
+   */
+  private static final int MAX_INITIAL_CAPACITY = 1 << 16;
+
+  private TopItems() { }
+
+  /**
+   * Initial heap size for keeping {@code howMany} elements plus the one transient extra element. A negative
+   * {@code howMany} yields a non-positive capacity, which {@link PriorityQueue} rejects with an
+   * {@link IllegalArgumentException}.
+   */
+  private static int initialCapacity(int howMany) {
+    return Math.min(howMany, MAX_INITIAL_CAPACITY - 1) + 1;
+  }
+
+  /**
+   * Finds the {@code howMany} items with the highest (optionally rescored) estimated preference.
+   *
+   * <p>Items filtered by the rescorer, items for which the estimator throws {@link NoSuchItemException} and
+   * items whose final value is NaN are skipped.</p>
+   *
+   * @param howMany maximum number of items to return; 0 yields an empty list
+   * @param possibleItemIDs candidate item IDs, must not be null
+   * @param rescorer optional filter/rescorer, may be null
+   * @param estimator computes the preference estimate for an item ID, must not be null
+   * @return the best items, ordered by {@link ByValueRecommendedItemComparator}
+   * @throws TasteException if the estimator fails for a reason other than an unknown item
+   */
+  public static List<RecommendedItem> getTopItems(int howMany,
+                                                  LongPrimitiveIterator possibleItemIDs,
+                                                  IDRescorer rescorer,
+                                                  Estimator<Long> estimator) throws TasteException {
+    Preconditions.checkArgument(possibleItemIDs != null, "possibleItemIDs is null");
+    Preconditions.checkArgument(estimator != null, "estimator is null");
+    if (howMany == 0) {
+      // With nothing to keep, the heap would be emptied after every insert and peek() would return null.
+      return Collections.emptyList();
+    }
+
+    // Reversed comparator: the head of the queue is the weakest of the current top items.
+    Queue<RecommendedItem> topItems = new PriorityQueue<>(initialCapacity(howMany),
+      Collections.reverseOrder(ByValueRecommendedItemComparator.getInstance()));
+    boolean full = false;
+    double lowestTopValue = Double.NEGATIVE_INFINITY;
+    while (possibleItemIDs.hasNext()) {
+      long itemID = possibleItemIDs.next();
+      if (rescorer == null || !rescorer.isFiltered(itemID)) {
+        double preference;
+        try {
+          preference = estimator.estimate(itemID);
+        } catch (NoSuchItemException nsie) {
+          continue;
+        }
+        double rescoredPref = rescorer == null ? preference : rescorer.rescore(itemID, preference);
+        if (!Double.isNaN(rescoredPref) && (!full || rescoredPref > lowestTopValue)) {
+          topItems.add(new GenericRecommendedItem(itemID, (float) rescoredPref));
+          if (full) {
+            topItems.poll();
+          } else if (topItems.size() > howMany) {
+            full = true;
+            topItems.poll();
+          }
+          lowestTopValue = topItems.peek().getValue();
+        }
+      }
+    }
+    int size = topItems.size();
+    if (size == 0) {
+      return Collections.emptyList();
+    }
+    List<RecommendedItem> result = new ArrayList<>(size);
+    result.addAll(topItems);
+    Collections.sort(result, ByValueRecommendedItemComparator.getInstance());
+    return result;
+  }
+
+  /**
+   * Finds the {@code howMany} users with the highest (optionally rescored) estimated similarity.
+   *
+   * <p>Users filtered by the rescorer, users for which the estimator throws {@link NoSuchUserException} and
+   * users whose final value is NaN are skipped.</p>
+   *
+   * @param howMany maximum number of users to return; 0 yields an empty array
+   * @param allUserIDs candidate user IDs
+   * @param rescorer optional filter/rescorer, may be null
+   * @param estimator computes the similarity estimate for a user ID
+   * @return IDs of the best users, in {@link SimilarUser} natural order
+   * @throws TasteException if the estimator fails for a reason other than an unknown user
+   */
+  public static long[] getTopUsers(int howMany,
+                                   LongPrimitiveIterator allUserIDs,
+                                   IDRescorer rescorer,
+                                   Estimator<Long> estimator) throws TasteException {
+    if (howMany == 0) {
+      // Same reasoning as in getTopItems: an always-empty heap would make peek() return null.
+      return NO_IDS;
+    }
+    Queue<SimilarUser> topUsers = new PriorityQueue<>(initialCapacity(howMany), Collections.reverseOrder());
+    boolean full = false;
+    double lowestTopValue = Double.NEGATIVE_INFINITY;
+    while (allUserIDs.hasNext()) {
+      long userID = allUserIDs.next();
+      if (rescorer != null && rescorer.isFiltered(userID)) {
+        continue;
+      }
+      double similarity;
+      try {
+        similarity = estimator.estimate(userID);
+      } catch (NoSuchUserException nsue) {
+        continue;
+      }
+      double rescoredSimilarity = rescorer == null ? similarity : rescorer.rescore(userID, similarity);
+      if (!Double.isNaN(rescoredSimilarity) && (!full || rescoredSimilarity > lowestTopValue)) {
+        topUsers.add(new SimilarUser(userID, rescoredSimilarity));
+        if (full) {
+          topUsers.poll();
+        } else if (topUsers.size() > howMany) {
+          full = true;
+          topUsers.poll();
+        }
+        lowestTopValue = topUsers.peek().getSimilarity();
+      }
+    }
+    int size = topUsers.size();
+    if (size == 0) {
+      return NO_IDS;
+    }
+    List<SimilarUser> sorted = new ArrayList<>(size);
+    sorted.addAll(topUsers);
+    Collections.sort(sorted);
+    long[] result = new long[size];
+    int i = 0;
+    for (SimilarUser similarUser : sorted) {
+      result[i++] = similarUser.getUserID();
+    }
+    return result;
+  }
+
+  /**
+   * <p>
+   * Thanks to tsmorton for suggesting this functionality and writing part of the code.
+   * </p>
+   *
+   * <p>Keeps the {@code howMany} similarities with the highest value, skipping NaN values, and returns them
+   * in their natural order. A {@code howMany} of 0 yields an empty list.</p>
+   *
+   * @see GenericItemSimilarity#GenericItemSimilarity(Iterable, int)
+   * @see GenericItemSimilarity#GenericItemSimilarity(org.apache.mahout.cf.taste.similarity.ItemSimilarity,
+   *      org.apache.mahout.cf.taste.model.DataModel, int)
+   */
+  public static List<GenericItemSimilarity.ItemItemSimilarity> getTopItemItemSimilarities(
+    int howMany, Iterator<GenericItemSimilarity.ItemItemSimilarity> allSimilarities) {
+
+    if (howMany == 0) {
+      return Collections.emptyList();
+    }
+    Queue<GenericItemSimilarity.ItemItemSimilarity> topSimilarities
+      = new PriorityQueue<>(initialCapacity(howMany), Collections.reverseOrder());
+    boolean full = false;
+    double lowestTopValue = Double.NEGATIVE_INFINITY;
+    while (allSimilarities.hasNext()) {
+      GenericItemSimilarity.ItemItemSimilarity similarity = allSimilarities.next();
+      double value = similarity.getValue();
+      if (!Double.isNaN(value) && (!full || value > lowestTopValue)) {
+        topSimilarities.add(similarity);
+        if (full) {
+          topSimilarities.poll();
+        } else if (topSimilarities.size() > howMany) {
+          full = true;
+          topSimilarities.poll();
+        }
+        lowestTopValue = topSimilarities.peek().getValue();
+      }
+    }
+    int size = topSimilarities.size();
+    if (size == 0) {
+      return Collections.emptyList();
+    }
+    List<GenericItemSimilarity.ItemItemSimilarity> result = new ArrayList<>(size);
+    result.addAll(topSimilarities);
+    Collections.sort(result);
+    return result;
+  }
+
+  /**
+   * Keeps the {@code howMany} user-user similarities with the highest value, skipping NaN values, and
+   * returns them in their natural order. A {@code howMany} of 0 yields an empty list.
+   *
+   * @param howMany maximum number of similarities to return
+   * @param allSimilarities all similarities to choose from
+   */
+  public static List<GenericUserSimilarity.UserUserSimilarity> getTopUserUserSimilarities(
+    int howMany, Iterator<GenericUserSimilarity.UserUserSimilarity> allSimilarities) {
+
+    if (howMany == 0) {
+      return Collections.emptyList();
+    }
+    Queue<GenericUserSimilarity.UserUserSimilarity> topSimilarities
+      = new PriorityQueue<>(initialCapacity(howMany), Collections.reverseOrder());
+    boolean full = false;
+    double lowestTopValue = Double.NEGATIVE_INFINITY;
+    while (allSimilarities.hasNext()) {
+      GenericUserSimilarity.UserUserSimilarity similarity = allSimilarities.next();
+      double value = similarity.getValue();
+      if (!Double.isNaN(value) && (!full || value > lowestTopValue)) {
+        topSimilarities.add(similarity);
+        if (full) {
+          topSimilarities.poll();
+        } else if (topSimilarities.size() > howMany) {
+          full = true;
+          topSimilarities.poll();
+        }
+        lowestTopValue = topSimilarities.peek().getValue();
+      }
+    }
+    int size = topSimilarities.size();
+    if (size == 0) {
+      return Collections.emptyList();
+    }
+    List<GenericUserSimilarity.UserUserSimilarity> result = new ArrayList<>(size);
+    result.addAll(topSimilarities);
+    Collections.sort(result);
+    return result;
+  }
+
+  /** Computes a score (preference or similarity estimate) for a single thing. */
+  public interface Estimator<T> {
+    double estimate(T thing) throws TasteException;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
new file mode 100644
index 0000000..0ba5139
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
@@ -0,0 +1,312 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * factorizes the rating matrix using "Alternating-Least-Squares with Weighted-λ-Regularization" as described in
+ * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
+ * "Large-scale Collaborative Filtering for the Netflix Prize"</a>
+ *
+ * also supports the implicit feedback variant of this approach as described in "Collaborative Filtering for Implicit
+ * Feedback Datasets" available at http://research.yahoo.com/pub/2433
+ */
+public class ALSWRFactorizer extends AbstractFactorizer {
+
+  private final DataModel dataModel;
+
+  /** number of features used to compute this factorization */
+  private final int numFeatures;
+  /** parameter to control the regularization */
+  private final double lambda;
+  /** number of iterations */
+  private final int numIterations;
+
+  private final boolean usesImplicitFeedback;
+  /** confidence weighting parameter, only necessary when working with implicit feedback */
+  private final double alpha;
+
+  /** size of the thread pool created for each half-iteration */
+  private final int numTrainingThreads;
+
+  private static final double DEFAULT_ALPHA = 40;
+
+  private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizer.class);
+
+  /**
+   * @param dataModel data to factorize
+   * @param numFeatures number of features per user and per item
+   * @param lambda regularization parameter
+   * @param numIterations number of alternating iterations to run
+   * @param usesImplicitFeedback whether to use the implicit feedback solver
+   * @param alpha confidence weighting parameter for the implicit feedback solver
+   * @param numTrainingThreads number of threads in each training thread pool
+   */
+  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+      boolean usesImplicitFeedback, double alpha, int numTrainingThreads) throws TasteException {
+    super(dataModel);
+    this.dataModel = dataModel;
+    this.numFeatures = numFeatures;
+    this.lambda = lambda;
+    this.numIterations = numIterations;
+    this.usesImplicitFeedback = usesImplicitFeedback;
+    this.alpha = alpha;
+    this.numTrainingThreads = numTrainingThreads;
+  }
+
+  /** Same as the full constructor, using one training thread per available processor. */
+  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+                         boolean usesImplicitFeedback, double alpha) throws TasteException {
+    this(dataModel, numFeatures, lambda, numIterations, usesImplicitFeedback, alpha,
+        Runtime.getRuntime().availableProcessors());
+  }
+
+  /** Explicit-feedback factorizer; alpha is set to its default of 40 and is not used by the explicit solver. */
+  public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations) throws TasteException {
+    this(dataModel, numFeatures, lambda, numIterations, false, DEFAULT_ALPHA);
+  }
+
+  /** Holds the item feature matrix M and the user feature matrix U, each with one row per ID index. */
+  static class Features {
+
+    private final DataModel dataModel;
+    private final int numFeatures;
+
+    private final double[][] M;
+    private final double[][] U;
+
+    /**
+     * Initializes every item row of M with the item's average rating as first feature and random values
+     * in [0, 0.1) for the remaining features; U starts out as all zeros.
+     */
+    Features(ALSWRFactorizer factorizer) throws TasteException {
+      dataModel = factorizer.dataModel;
+      numFeatures = factorizer.numFeatures;
+      Random random = RandomUtils.getRandom();
+      M = new double[dataModel.getNumItems()][numFeatures];
+      LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+      while (itemIDsIterator.hasNext()) {
+        long itemID = itemIDsIterator.nextLong();
+        int itemIDIndex = factorizer.itemIndex(itemID);
+        M[itemIDIndex][0] = averateRating(itemID);
+        for (int feature = 1; feature < numFeatures; feature++) {
+          M[itemIDIndex][feature] = random.nextDouble() * 0.1;
+        }
+      }
+      U = new double[dataModel.getNumUsers()][numFeatures];
+    }
+
+    double[][] getM() {
+      return M;
+    }
+
+    double[][] getU() {
+      return U;
+    }
+
+    Vector getUserFeatureColumn(int index) {
+      return new DenseVector(U[index]);
+    }
+
+    Vector getItemFeatureColumn(int index) {
+      return new DenseVector(M[index]);
+    }
+
+    void setFeatureColumnInU(int idIndex, Vector vector) {
+      setFeatureColumn(U, idIndex, vector);
+    }
+
+    void setFeatureColumnInM(int idIndex, Vector vector) {
+      setFeatureColumn(M, idIndex, vector);
+    }
+
+    /** Copies the first {@code numFeatures} entries of {@code vector} into row {@code idIndex}. */
+    protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
+      for (int feature = 0; feature < numFeatures; feature++) {
+        matrix[idIndex][feature] = vector.get(feature);
+      }
+    }
+
+    /** Average of all preference values recorded for the item. */
+    protected double averateRating(long itemID) throws TasteException {
+      PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+      RunningAverage avg = new FullRunningAverage();
+      for (Preference pref : prefs) {
+        avg.addDatum(pref.getValue());
+      }
+      return avg.getAverage();
+    }
+  }
+
+  /**
+   * Runs {@code numIterations} alternating steps: first the user features are recomputed with the item
+   * features held fixed, then the item features with the user features held fixed. Each step submits one
+   * task per user (or item) to a pool from {@link #createQueue()} and waits for it for at most as many
+   * seconds as there are users (or items).
+   *
+   * <p>If the calling thread is interrupted while waiting, the remaining steps still run as before, and the
+   * thread's interrupt status is restored before this method returns so the caller can react to it. A wait
+   * that times out is reported in the log.</p>
+   *
+   * @return the factorization built from the final user and item feature matrices
+   * @throws TasteException if accessing the data model fails
+   */
+  @Override
+  public Factorization factorize() throws TasteException {
+    log.info("starting to compute the factorization...");
+    final Features features = new Features(this);
+
+    /* feature maps necessary for solving for implicit feedback */
+    OpenIntObjectHashMap<Vector> userY = null;
+    OpenIntObjectHashMap<Vector> itemY = null;
+
+    if (usesImplicitFeedback) {
+      userY = userFeaturesMapping(dataModel.getUserIDs(), dataModel.getNumUsers(), features.getU());
+      itemY = itemFeaturesMapping(dataModel.getItemIDs(), dataModel.getNumItems(), features.getM());
+    }
+
+    /* remembered so the interrupt status can be handed back to the caller instead of being swallowed */
+    boolean interrupted = false;
+    try {
+      for (int iteration = 0; iteration < numIterations; iteration++) {
+        log.info("iteration {}", iteration);
+
+        /* fix M - compute U */
+        ExecutorService queue = createQueue();
+        LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
+        try {
+
+          final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
+              ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, itemY, numTrainingThreads)
+              : null;
+
+          while (userIDsIterator.hasNext()) {
+            final long userID = userIDsIterator.nextLong();
+            final LongPrimitiveIterator itemIDsFromUser = dataModel.getItemIDsFromUser(userID).iterator();
+            final PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
+            queue.execute(new Runnable() {
+              @Override
+              public void run() {
+                List<Vector> featureVectors = new ArrayList<>();
+                while (itemIDsFromUser.hasNext()) {
+                  long itemID = itemIDsFromUser.nextLong();
+                  featureVectors.add(features.getItemFeatureColumn(itemIndex(itemID)));
+                }
+
+                Vector userFeatures = usesImplicitFeedback
+                    ? implicitFeedbackSolver.solve(sparseUserRatingVector(userPrefs))
+                    : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);
+
+                features.setFeatureColumnInU(userIndex(userID), userFeatures);
+              }
+            });
+          }
+        } finally {
+          queue.shutdown();
+          interrupted |= awaitCompletion(queue, dataModel.getNumUsers(), "Error when computing user features");
+        }
+
+        /* fix U - compute M */
+        queue = createQueue();
+        LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+        try {
+
+          final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
+              ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, userY, numTrainingThreads)
+              : null;
+
+          while (itemIDsIterator.hasNext()) {
+            final long itemID = itemIDsIterator.nextLong();
+            final PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
+            queue.execute(new Runnable() {
+              @Override
+              public void run() {
+                List<Vector> featureVectors = new ArrayList<>();
+                for (Preference pref : itemPrefs) {
+                  long userID = pref.getUserID();
+                  featureVectors.add(features.getUserFeatureColumn(userIndex(userID)));
+                }
+
+                Vector itemFeatures = usesImplicitFeedback
+                    ? implicitFeedbackSolver.solve(sparseItemRatingVector(itemPrefs))
+                    : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);
+
+                features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
+              }
+            });
+          }
+        } finally {
+          queue.shutdown();
+          interrupted |= awaitCompletion(queue, dataModel.getNumItems(), "Error when computing item features");
+        }
+      }
+    } finally {
+      if (interrupted) {
+        Thread.currentThread().interrupt();
+      }
+    }
+
+    log.info("finished computation of the factorization...");
+    return createFactorization(features.getU(), features.getM());
+  }
+
+  /**
+   * Waits for an already shut-down pool to finish its submitted tasks.
+   *
+   * @param queue pool to wait for
+   * @param timeoutSeconds maximum time to wait, in seconds
+   * @param interruptedMessage message logged if the wait is interrupted
+   * @return {@code true} if the calling thread was interrupted while waiting
+   */
+  private static boolean awaitCompletion(ExecutorService queue, long timeoutSeconds, String interruptedMessage) {
+    try {
+      if (!queue.awaitTermination(timeoutSeconds, TimeUnit.SECONDS)) {
+        log.warn("Timed out after {} seconds waiting for feature computation to finish", timeoutSeconds);
+      }
+      return false;
+    } catch (InterruptedException e) {
+      log.warn(interruptedMessage, e);
+      return true;
+    }
+  }
+
+  /** Creates the fixed-size thread pool used for one half-iteration. */
+  protected ExecutorService createQueue() {
+    return Executors.newFixedThreadPool(numTrainingThreads);
+  }
+
+  /** Builds a dense vector holding the preference values in the order they appear in {@code prefs}. */
+  protected static Vector ratingVector(PreferenceArray prefs) {
+    double[] ratings = new double[prefs.length()];
+    for (int n = 0; n < prefs.length(); n++) {
+      ratings[n] = prefs.get(n).getValue();
+    }
+    return new DenseVector(ratings, true);
+  }
+
+  //TODO find a way to get rid of the object overhead here
+  /** Maps each item index to a vector built over that item's row of {@code featureMatrix}. */
+  protected OpenIntObjectHashMap<Vector> itemFeaturesMapping(LongPrimitiveIterator itemIDs, int numItems,
+      double[][] featureMatrix) {
+    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(numItems);
+    while (itemIDs.hasNext()) {
+      long itemID = itemIDs.next();
+      int itemIndex = itemIndex(itemID);
+      mapping.put(itemIndex, new DenseVector(featureMatrix[itemIndex], true));
+    }
+
+    return mapping;
+  }
+
+  /** Maps each user index to a vector built over that user's row of {@code featureMatrix}. */
+  protected OpenIntObjectHashMap<Vector> userFeaturesMapping(LongPrimitiveIterator userIDs, int numUsers,
+      double[][] featureMatrix) {
+    OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<>(numUsers);
+
+    while (userIDs.hasNext()) {
+      long userID = userIDs.next();
+      int userIndex = userIndex(userID);
+      mapping.put(userIndex, new DenseVector(featureMatrix[userIndex], true));
+    }
+
+    return mapping;
+  }
+
+  /** Sparse vector of an item's preference values, keyed by user index. */
+  protected Vector sparseItemRatingVector(PreferenceArray prefs) {
+    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
+    for (Preference preference : prefs) {
+      ratings.set(userIndex(preference.getUserID()), preference.getValue());
+    }
+    return ratings;
+  }
+
+  /** Sparse vector of a user's preference values, keyed by item index. */
+  protected Vector sparseUserRatingVector(PreferenceArray prefs) {
+    SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
+    for (Preference preference : prefs) {
+      ratings.set(itemIndex(preference.getItemID()), preference.getValue());
+    }
+    return ratings;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
new file mode 100644
index 0000000..0a39a1d
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
@@ -0,0 +1,94 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * Common superclass for {@link Factorizer} implementations. It owns the translation from user and item
+ * IDs to the integer indexes that are handed to {@link Factorization}.
+ */
+public abstract class AbstractFactorizer implements Factorizer {
+
+  private final DataModel dataModel;
+  private FastByIDMap<Integer> userIDMapping;
+  private FastByIDMap<Integer> itemIDMapping;
+  private final RefreshHelper refreshHelper;
+
+  /**
+   * Builds the initial ID mappings from {@code dataModel} and registers a refresh callback that rebuilds
+   * them, with the data model added as a refresh dependency.
+   *
+   * @param dataModel source of the user and item IDs
+   * @throws TasteException if the data model cannot be read while building the mappings
+   */
+  protected AbstractFactorizer(DataModel dataModel) throws TasteException {
+    this.dataModel = dataModel;
+    buildMappings();
+
+    Callable<Object> rebuildMappings = new Callable<Object>() {
+      @Override
+      public Object call() throws TasteException {
+        buildMappings();
+        return null;
+      }
+    };
+    refreshHelper = new RefreshHelper(rebuildMappings);
+    refreshHelper.addDependency(dataModel);
+  }
+
+  /** Replaces both ID mappings with fresh ones built from the current content of the data model. */
+  private void buildMappings() throws TasteException {
+    int userCount = dataModel.getNumUsers();
+    LongPrimitiveIterator users = dataModel.getUserIDs();
+    userIDMapping = createIDMapping(userCount, users);
+
+    int itemCount = dataModel.getNumItems();
+    LongPrimitiveIterator items = dataModel.getItemIDs();
+    itemIDMapping = createIDMapping(itemCount, items);
+  }
+
+  /**
+   * Wraps the given feature matrices together with the current ID mappings.
+   *
+   * @param userFeatures user feature rows, addressed by user index
+   * @param itemFeatures item feature rows, addressed by item index
+   * @return a new {@link Factorization} sharing this factorizer's ID mappings
+   */
+  protected Factorization createFactorization(double[][] userFeatures, double[][] itemFeatures) {
+    return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+  }
+
+  /**
+   * Looks up the index of a user. An ID that has no index yet is assigned the current size of the
+   * mapping and stored, so this method may modify the mapping.
+   *
+   * @param userID the user to look up
+   * @return the index associated with {@code userID}
+   */
+  protected Integer userIndex(long userID) {
+    Integer known = userIDMapping.get(userID);
+    if (known != null) {
+      return known;
+    }
+    Integer assigned = userIDMapping.size();
+    userIDMapping.put(userID, assigned);
+    return assigned;
+  }
+
+  /**
+   * Looks up the index of an item. An ID that has no index yet is assigned the current size of the
+   * mapping and stored, so this method may modify the mapping.
+   *
+   * @param itemID the item to look up
+   * @return the index associated with {@code itemID}
+   */
+  protected Integer itemIndex(long itemID) {
+    Integer known = itemIDMapping.get(itemID);
+    if (known != null) {
+      return known;
+    }
+    Integer assigned = itemIDMapping.size();
+    itemIDMapping.put(itemID, assigned);
+    return assigned;
+  }
+
+  /** Assigns consecutive indexes, starting at 0, to the IDs in the order the iterator yields them. */
+  private static FastByIDMap<Integer> createIDMapping(int size, LongPrimitiveIterator idIterator) {
+    FastByIDMap<Integer> idToIndex = new FastByIDMap<>(size);
+    for (int nextIndex = 0; idIterator.hasNext(); nextIndex++) {
+      idToIndex.put(idIterator.nextLong(), nextIndex);
+    }
+    return idToIndex;
+  }
+
+  /** Delegates to the {@link RefreshHelper} created in the constructor. */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    refreshHelper.refresh(alreadyRefreshed);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
new file mode 100644
index 0000000..f169a60
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
@@ -0,0 +1,137 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.Arrays;
+import java.util.Map;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+
+/**
+ * a factorization of the rating matrix: two feature matrices plus the mappings that translate user and
+ * item IDs into row indexes of those matrices
+ */
+public class Factorization {
+
+  /** used to find the rows in the user features matrix by userID */
+  private final FastByIDMap<Integer> userIDMapping;
+  /** used to find the rows in the item features matrix by itemID */
+  private final FastByIDMap<Integer> itemIDMapping;
+
+  /** user features matrix */
+  private final double[][] userFeatures;
+  /** item features matrix */
+  private final double[][] itemFeatures;
+
+  /**
+   * @param userIDMapping maps each userID to its row in {@code userFeatures}; must not be null
+   * @param itemIDMapping maps each itemID to its row in {@code itemFeatures}; must not be null
+   * @param userFeatures user features matrix; stored by reference, not copied
+   * @param itemFeatures item features matrix; stored by reference, not copied
+   */
+  public Factorization(FastByIDMap<Integer> userIDMapping, FastByIDMap<Integer> itemIDMapping, double[][] userFeatures,
+      double[][] itemFeatures) {
+    this.userIDMapping = Preconditions.checkNotNull(userIDMapping);
+    this.itemIDMapping = Preconditions.checkNotNull(itemIDMapping);
+    this.userFeatures = userFeatures;
+    this.itemFeatures = itemFeatures;
+  }
+
+  /** @return the user features matrix itself (no copy is made) */
+  public double[][] allUserFeatures() {
+    return userFeatures;
+  }
+
+  /**
+   * @param userID the user to look up
+   * @return the feature row of that user (the internal array, not a copy)
+   * @throws NoSuchUserException if the userID is not part of the user mapping
+   */
+  public double[] getUserFeatures(long userID) throws NoSuchUserException {
+    Integer index = userIDMapping.get(userID);
+    if (index == null) {
+      throw new NoSuchUserException(userID);
+    }
+    return userFeatures[index];
+  }
+
+  /** @return the item features matrix itself (no copy is made) */
+  public double[][] allItemFeatures() {
+    return itemFeatures;
+  }
+
+  /**
+   * @param itemID the item to look up
+   * @return the feature row of that item (the internal array, not a copy)
+   * @throws NoSuchItemException if the itemID is not part of the item mapping
+   */
+  public double[] getItemFeatures(long itemID) throws NoSuchItemException {
+    Integer index = itemIDMapping.get(itemID);
+    if (index == null) {
+      throw new NoSuchItemException(itemID);
+    }
+    return itemFeatures[index];
+  }
+
+  /**
+   * @param userID the user to look up
+   * @return the row index of that user in the user features matrix
+   * @throws NoSuchUserException if the userID is not part of the user mapping
+   */
+  public int userIndex(long userID) throws NoSuchUserException {
+    Integer index = userIDMapping.get(userID);
+    if (index == null) {
+      throw new NoSuchUserException(userID);
+    }
+    return index;
+  }
+
+  /** @return all (userID, row index) pairs of the user mapping */
+  public Iterable<Map.Entry<Long,Integer>> getUserIDMappings() {
+    return userIDMapping.entrySet();
+  }
+
+  /** @return an iterator over the userIDs of the user mapping */
+  public LongPrimitiveIterator getUserIDMappingKeys() {
+    return userIDMapping.keySetIterator();
+  }
+
+  /**
+   * @param itemID the item to look up
+   * @return the row index of that item in the item features matrix
+   * @throws NoSuchItemException if the itemID is not part of the item mapping
+   */
+  public int itemIndex(long itemID) throws NoSuchItemException {
+    Integer index = itemIDMapping.get(itemID);
+    if (index == null) {
+      throw new NoSuchItemException(itemID);
+    }
+    return index;
+  }
+
+  /** @return all (itemID, row index) pairs of the item mapping */
+  public Iterable<Map.Entry<Long,Integer>> getItemIDMappings() {
+    return itemIDMapping.entrySet();
+  }
+
+  /** @return an iterator over the itemIDs of the item mapping */
+  public LongPrimitiveIterator getItemIDMappingKeys() {
+    return itemIDMapping.keySetIterator();
+  }
+
+  /**
+   * @return the length of a feature row, taken from the first user row or, when there are no user rows,
+   *         from the first item row; 0 if both matrices are empty
+   */
+  public int numFeatures() {
+    if (userFeatures.length > 0) {
+      return userFeatures[0].length;
+    }
+    // Without this fallback a factorization holding items but no users would report 0 features, and
+    // code that sizes item rows by numFeatures() would silently drop every item feature.
+    return itemFeatures != null && itemFeatures.length > 0 ? itemFeatures[0].length : 0;
+  }
+
+  /** @return the number of entries in the user mapping */
+  public int numUsers() {
+    return userIDMapping.size();
+  }
+
+  /** @return the number of entries in the item mapping */
+  public int numItems() {
+    return itemIDMapping.size();
+  }
+
+  /** Two factorizations are equal when both mappings are equal and both matrices are deeply equal. */
+  @Override
+  public boolean equals(Object o) {
+    if (o instanceof Factorization) {
+      Factorization other = (Factorization) o;
+      return userIDMapping.equals(other.userIDMapping) && itemIDMapping.equals(other.itemIDMapping)
+          && Arrays.deepEquals(userFeatures, other.userFeatures) && Arrays.deepEquals(itemFeatures, other.itemFeatures);
+    }
+    return false;
+  }
+
+  /** Combines the hash codes of both mappings and the deep hash codes of both matrices. */
+  @Override
+  public int hashCode() {
+    int hashCode = 31 * userIDMapping.hashCode() + itemIDMapping.hashCode();
+    hashCode = 31 * hashCode + Arrays.deepHashCode(userFeatures);
+    hashCode = 31 * hashCode + Arrays.deepHashCode(itemFeatures);
+    return hashCode;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
new file mode 100644
index 0000000..2cabe73
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * Implementation must be able to create a factorization of a rating matrix
+ */
+public interface Factorizer extends Refreshable {
+
+ /**
+  * Computes a factorization of the rating matrix.
+  *
+  * @return the computed {@link Factorization}
+  * @throws TasteException if the factorization cannot be computed
+  */
+ Factorization factorize() throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java
new file mode 100644
index 0000000..08c038a
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Provides a file-based persistent store.
+ *
+ * <p>Binary layout: three ints (number of features, users, items), followed by one record per user and
+ * then one record per item. Each record is the row index (int), the ID (long) and the feature values
+ * (one double per feature).</p>
+ */
+public class FilePersistenceStrategy implements PersistenceStrategy {
+
+  private final File file;
+
+  private static final Logger log = LoggerFactory.getLogger(FilePersistenceStrategy.class);
+
+  /**
+   * @param file the file to use for storage. If the file does not exist it will be created when required.
+   */
+  public FilePersistenceStrategy(File file) {
+    this.file = Preconditions.checkNotNull(file);
+  }
+
+  /**
+   * Reads the factorization stored in the file.
+   *
+   * @return the stored factorization, or null if the file does not exist
+   * @throws IOException if the file cannot be read or its content is not a valid factorization
+   */
+  @Override
+  public Factorization load() throws IOException {
+    if (!file.exists()) {
+      log.info("{} does not yet exist, no factorization found", file.getAbsolutePath());
+      return null;
+    }
+    try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)))){
+      log.info("Reading factorization from {}...", file.getAbsolutePath());
+      return readBinary(in);
+    }
+  }
+
+  /**
+   * Writes the factorization to the file, replacing any previous content.
+   *
+   * @param factorization the factorization to store
+   * @throws IOException if the file cannot be written
+   */
+  @Override
+  public void maybePersist(Factorization factorization) throws IOException {
+    try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)))){
+      log.info("Writing factorization to {}...", file.getAbsolutePath());
+      writeBinary(factorization, out);
+    }
+  }
+
+  /**
+   * Serializes a factorization in the binary layout described on this class.
+   *
+   * @param factorization the factorization to write
+   * @param out destination of the binary data
+   * @throws IOException if writing fails, or if an ID of a mapping has no feature row
+   */
+  protected static void writeBinary(Factorization factorization, DataOutput out) throws IOException {
+    // loop-invariant, so it is read once instead of once per row
+    int numFeatures = factorization.numFeatures();
+
+    out.writeInt(numFeatures);
+    out.writeInt(factorization.numUsers());
+    out.writeInt(factorization.numItems());
+
+    for (Map.Entry<Long,Integer> mappingEntry : factorization.getUserIDMappings()) {
+      long userID = mappingEntry.getKey();
+      out.writeInt(mappingEntry.getValue());
+      out.writeLong(userID);
+      try {
+        double[] userFeatures = factorization.getUserFeatures(userID);
+        for (int feature = 0; feature < numFeatures; feature++) {
+          out.writeDouble(userFeatures[feature]);
+        }
+      } catch (NoSuchUserException e) {
+        throw new IOException("Unable to persist factorization", e);
+      }
+    }
+
+    for (Map.Entry<Long,Integer> entry : factorization.getItemIDMappings()) {
+      long itemID = entry.getKey();
+      out.writeInt(entry.getValue());
+      out.writeLong(itemID);
+      try {
+        double[] itemFeatures = factorization.getItemFeatures(itemID);
+        for (int feature = 0; feature < numFeatures; feature++) {
+          out.writeDouble(itemFeatures[feature]);
+        }
+      } catch (NoSuchItemException e) {
+        throw new IOException("Unable to persist factorization", e);
+      }
+    }
+  }
+
+  /**
+   * Deserializes a factorization from the binary layout described on this class.
+   *
+   * @param in source of the binary data
+   * @return the factorization that was read
+   * @throws IOException if reading fails, or if the data contains a negative count or a row index
+   *         outside of the declared number of rows
+   */
+  public static Factorization readBinary(DataInput in) throws IOException {
+    int numFeatures = readCount(in, "features");
+    int numUsers = readCount(in, "users");
+    int numItems = readCount(in, "items");
+
+    FastByIDMap<Integer> userIDMapping = new FastByIDMap<>(numUsers);
+    double[][] userFeatures = new double[numUsers][numFeatures];
+    readRows(in, "user", userIDMapping, userFeatures, numFeatures);
+
+    FastByIDMap<Integer> itemIDMapping = new FastByIDMap<>(numItems);
+    double[][] itemFeatures = new double[numItems][numFeatures];
+    readRows(in, "item", itemIDMapping, itemFeatures, numFeatures);
+
+    return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+  }
+
+  /** Reads one header count and rejects negative values, which could not be used to size an array. */
+  private static int readCount(DataInput in, String what) throws IOException {
+    int count = in.readInt();
+    if (count < 0) {
+      throw new IOException("Corrupt factorization data: negative number of " + what + ": " + count);
+    }
+    return count;
+  }
+
+  /** Reads {@code features.length} records, checking that every row index addresses an existing row. */
+  private static void readRows(DataInput in, String what, FastByIDMap<Integer> idMapping, double[][] features,
+      int numFeatures) throws IOException {
+    for (int n = 0; n < features.length; n++) {
+      int index = in.readInt();
+      long id = in.readLong();
+      if (index < 0 || index >= features.length) {
+        throw new IOException("Corrupt factorization data: " + what + " index " + index + " of ID " + id
+            + " is outside of [0," + features.length + ')');
+      }
+      idMapping.put(id, index);
+      double[] row = features[index];
+      for (int feature = 0; feature < numFeatures; feature++) {
+        row[feature] = in.readDouble();
+      }
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java
new file mode 100644
index 0000000..0d1aab0
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.io.IOException;
+
+/**
+ * A {@link PersistenceStrategy} which does nothing.
+ */
+public class NoPersistenceStrategy implements PersistenceStrategy {
+
+ /**
+  * Never loads anything.
+  *
+  * @return always null
+  */
+ @Override
+ public Factorization load() throws IOException {
+ return null;
+ }
+
+ /**
+  * Ignores the given factorization; nothing is stored.
+  *
+  * @param factorization not used
+  */
+ @Override
+ public void maybePersist(Factorization factorization) throws IOException {
+ // do nothing.
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java
new file mode 100644
index 0000000..8a6a702
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java
@@ -0,0 +1,340 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Minimalistic implementation of Parallel SGD factorizer based on
+ * <a href="http://www.sze.hu/~gtakacs/download/jmlr_2009.pdf">
+ * "Scalable Collaborative Filtering Approaches for Large Recommender Systems"</a>
+ * and
+ * <a href="http://www.cs.wisc.edu/~brecht/papers/hogwildTR.pdf">
+ * "Hogwild!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent"</a> */
+public class ParallelSGDFactorizer extends AbstractFactorizer {
+
+  private final DataModel dataModel;
+  /** Parameter used to prevent overfitting. */
+  private final double lambda;
+  /** Number of features used to compute this factorization */
+  private final int rank;
+  /** Number of iterations */
+  private final int numEpochs;
+
+  private int numThreads;
+
+  // these next two control decayFactor^steps exponential type of annealing learning rate and decay factor
+  private double mu0 = 0.01;
+  private double decayFactor = 1;
+  // these next two control 1/steps^forget type annealing
+  private int stepOffset = 0;
+  // -1 equals even weighting of all examples, 0 means only use exponential annealing
+  private double forgettingExponent = 0;
+
+  // The following two should be inversely proportional :)
+  private double biasMuRatio = 0.5;
+  private double biasLambdaRatio = 0.1;
+
+  /* TODO: this is not safe as += is not atomic on many processors, can be replaced with AtomicDoubleArray
+   * but it works just fine right now */
+  /** user features */
+  protected volatile double[][] userVectors;
+  /** item features */
+  protected volatile double[][] itemVectors;
+
+  private final PreferenceShuffler shuffler;
+
+  private int epoch = 1;
+  /** place in user vector where the bias is stored */
+  private static final int USER_BIAS_INDEX = 1;
+  /** place in item vector where the bias is stored */
+  private static final int ITEM_BIAS_INDEX = 2;
+  private static final int FEATURE_OFFSET = 3;
+  /** Standard deviation for random initialization of features */
+  private static final double NOISE = 0.02;
+
+  private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizer.class);
+
+  /**
+   * Caches all preferences of a data model in an array and hands them out in random order. Shuffling
+   * works on a copy ({@link #shuffle()}) that only becomes visible to {@link #get(int)} after
+   * {@link #stage()}, so a new permutation can be prepared while the current one is still being read.
+   */
+  protected static class PreferenceShuffler {
+
+    private Preference[] preferences;
+    private Preference[] unstagedPreferences;
+
+    protected final RandomWrapper random = RandomUtils.getRandom();
+
+    /**
+     * @param dataModel the model whose preferences are cached, shuffled and staged
+     * @throws TasteException if the data model cannot be read
+     */
+    public PreferenceShuffler(DataModel dataModel) throws TasteException {
+      cachePreferences(dataModel);
+      shuffle();
+      stage();
+    }
+
+    /** Sums the number of preferences over all users of the data model. */
+    private int countPreferences(DataModel dataModel) throws TasteException {
+      int numPreferences = 0;
+      LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+      while (userIDs.hasNext()) {
+        PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong());
+        numPreferences += preferencesFromUser.length();
+      }
+      return numPreferences;
+    }
+
+    /** Copies every preference of every user into the {@code preferences} array. */
+    private void cachePreferences(DataModel dataModel) throws TasteException {
+      int numPreferences = countPreferences(dataModel);
+      preferences = new Preference[numPreferences];
+
+      LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+      int index = 0;
+      while (userIDs.hasNext()) {
+        long userID = userIDs.nextLong();
+        PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
+        for (Preference preference : preferencesFromUser) {
+          preferences[index++] = preference;
+        }
+      }
+    }
+
+    /** Prepares a new random permutation in a copy of the staged array; the staged array is not touched. */
+    public final void shuffle() {
+      unstagedPreferences = preferences.clone();
+      /* Durstenfeld shuffle */
+      for (int i = unstagedPreferences.length - 1; i > 0; i--) {
+        int rand = random.nextInt(i + 1);
+        swapCachedPreferences(i, rand);
+      }
+    }
+
+    //merge this part into shuffle() will make compiler-optimizer do some real absurd stuff, test on OpenJDK7
+    private void swapCachedPreferences(int x, int y) {
+      Preference p = unstagedPreferences[x];
+
+      unstagedPreferences[x] = unstagedPreferences[y];
+      unstagedPreferences[y] = p;
+    }
+
+    /** Makes the permutation prepared by the last {@link #shuffle()} the one served by {@link #get(int)}. */
+    public final void stage() {
+      preferences = unstagedPreferences;
+    }
+
+    /** @return the preference at position {@code i} of the staged permutation */
+    public Preference get(int i) {
+      return preferences[i];
+    }
+
+    /** @return the number of cached preferences */
+    public int size() {
+      return preferences.length;
+    }
+
+  }
+
+  /**
+   * Creates a factorizer with the default learning parameters.
+   *
+   * @param dataModel the ratings to factorize
+   * @param numFeatures number of latent features; three extra slots are added internally for the global
+   *        average and the two biases
+   * @param lambda regularization parameter
+   * @param numEpochs number of passes over all preferences
+   * @throws TasteException if the data model cannot be read
+   */
+  public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs)
+    throws TasteException {
+    super(dataModel);
+    this.dataModel = dataModel;
+    this.rank = numFeatures + FEATURE_OFFSET;
+    this.lambda = lambda;
+    this.numEpochs = numEpochs;
+
+    shuffler = new PreferenceShuffler(dataModel);
+
+    //max thread num set to n^0.25 as suggested by hogwild! paper
+    int suggestedThreads = (int) Math.pow((double) shuffler.size(), 0.25);
+    // n^0.25 truncates to 0 for a model without preferences; at least one thread is needed because
+    // factorize() divides by numThreads and a fixed thread pool rejects a size of 0
+    numThreads = Math.max(1, Math.min(Runtime.getRuntime().availableProcessors(), suggestedThreads));
+  }
+
+  /**
+   * As {@link #ParallelSGDFactorizer(DataModel, int, double, int)}, with an explicit learning rate schedule.
+   *
+   * @param mu0 initial learning rate
+   * @param decayFactor base of the exponential decay of the learning rate
+   * @param stepOffset offset added to the epoch number in the power-law term of the learning rate
+   * @param forgettingExponent exponent of the power-law term of the learning rate
+   */
+  public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+      double mu0, double decayFactor, int stepOffset, double forgettingExponent) throws TasteException {
+    this(dataModel, numFeatures, lambda, numIterations);
+
+    this.mu0 = mu0;
+    this.decayFactor = decayFactor;
+    this.stepOffset = stepOffset;
+    this.forgettingExponent = forgettingExponent;
+  }
+
+  /** As the constructor with an explicit learning rate schedule, additionally fixing the number of threads. */
+  public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+      double mu0, double decayFactor, int stepOffset, double forgettingExponent, int numThreads) throws TasteException {
+    this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent);
+
+    this.numThreads = numThreads;
+  }
+
+  /**
+   * As the constructor with an explicit learning rate schedule, additionally setting how the bias terms
+   * are learned.
+   *
+   * @param biasMuRatio factor applied to the learning rate when updating the biases
+   * @param biasLambdaRatio factor applied to lambda when regularizing the biases
+   */
+  public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+      double mu0, double decayFactor, int stepOffset, double forgettingExponent,
+      double biasMuRatio, double biasLambdaRatio) throws TasteException {
+    this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent);
+
+    this.biasMuRatio = biasMuRatio;
+    this.biasLambdaRatio = biasLambdaRatio;
+  }
+
+  /** As the constructor with bias ratios, additionally fixing the number of threads. */
+  public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+      double mu0, double decayFactor, int stepOffset, double forgettingExponent,
+      double biasMuRatio, double biasLambdaRatio, int numThreads) throws TasteException {
+    this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent, biasMuRatio,
+         biasLambdaRatio);
+
+    this.numThreads = numThreads;
+  }
+
+  /**
+   * Allocates the user and item vectors and fills them with their start values: slot 0 pairs the global
+   * average (user side) with 1 (item side), the two bias slots start at 0 and are paired with 1 on the
+   * opposite side, and all remaining features are Gaussian noise scaled by {@code NOISE}.
+   *
+   * @throws TasteException if the data model cannot be read
+   */
+  protected void initialize() throws TasteException {
+    RandomWrapper random = RandomUtils.getRandom();
+    userVectors = new double[dataModel.getNumUsers()][rank];
+    itemVectors = new double[dataModel.getNumItems()][rank];
+
+    double globalAverage = getAveragePreference();
+    for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
+      userVectors[userIndex][0] = globalAverage;
+      userVectors[userIndex][USER_BIAS_INDEX] = 0; // will store user bias
+      userVectors[userIndex][ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias
+      for (int feature = FEATURE_OFFSET; feature < rank; feature++) {
+        userVectors[userIndex][feature] = random.nextGaussian() * NOISE;
+      }
+    }
+    for (int itemIndex = 0; itemIndex < itemVectors.length; itemIndex++) {
+      itemVectors[itemIndex][0] = 1; // corresponding user feature contains global average
+      itemVectors[itemIndex][USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias
+      itemVectors[itemIndex][ITEM_BIAS_INDEX] = 0; // will store item bias
+      for (int feature = FEATURE_OFFSET; feature < rank; feature++) {
+        itemVectors[itemIndex][feature] = random.nextGaussian() * NOISE;
+      }
+    }
+  }
+
+  /** Learning rate of epoch {@code i}: mu0 * decayFactor^(i - 1) * (i + stepOffset)^forgettingExponent. */
+  //TODO: needs optimization
+  private double getMu(int i) {
+    return mu0 * Math.pow(decayFactor, i - 1) * Math.pow(i + stepOffset, forgettingExponent);
+  }
+
+  /**
+   * Runs {@code numEpochs} passes of SGD. In every epoch the staged preferences are split into
+   * {@code numThreads} contiguous slices, each processed by its own task without locking; the
+   * permutation for the next epoch is prepared while the tasks are running.
+   *
+   * @return the factorization built from the learned user and item vectors
+   * @throws TasteException if the data model cannot be read or the calling thread is interrupted while
+   *         waiting for the tasks of an epoch
+   */
+  @Override
+  public Factorization factorize() throws TasteException {
+    initialize();
+
+    if (logger.isInfoEnabled()) {
+      logger.info("starting to compute the factorization...");
+    }
+
+    for (epoch = 1; epoch <= numEpochs; epoch++) {
+      shuffler.stage();
+
+      final double mu = getMu(epoch);
+      int subSize = shuffler.size() / numThreads + 1;
+
+      ExecutorService executor=Executors.newFixedThreadPool(numThreads);
+
+      try {
+        for (int t = 0; t < numThreads; t++) {
+          final int iStart = t * subSize;
+          final int iEnd = Math.min((t + 1) * subSize, shuffler.size());
+
+          executor.execute(new Runnable() {
+            @Override
+            public void run() {
+              for (int i = iStart; i < iEnd; i++) {
+                update(shuffler.get(i), mu);
+              }
+            }
+          });
+        }
+      } finally {
+        executor.shutdown();
+        // works on a copy of the staged array, so it can overlap with the running tasks
+        shuffler.shuffle();
+
+        try {
+          // computed as long: the int product can overflow to a negative timeout, which would make
+          // awaitTermination return immediately
+          long timeoutMicros = (long) numEpochs * shuffler.size();
+          boolean terminated = executor.awaitTermination(timeoutMicros, TimeUnit.MICROSECONDS);
+          if (!terminated) {
+            logger.error("subtasks takes forever, return anyway");
+          }
+        } catch (InterruptedException e) {
+          // keep the interrupt visible to callers, the wrapping exception alone would hide it
+          Thread.currentThread().interrupt();
+          throw new TasteException("waiting fof termination interrupted", e);
+        }
+      }
+
+    }
+
+    return createFactorization(userVectors, itemVectors);
+  }
+
+  /** @return the average over the values of all preferences of all users in the data model */
+  double getAveragePreference() throws TasteException {
+    RunningAverage average = new FullRunningAverage();
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+        average.addDatum(pref.getValue());
+      }
+    }
+    return average.getAverage();
+  }
+
+  /** TODO: this is the vanilla sgd by Tacaks 2009, I speculate that using scaling technique proposed in:
+   * Towards Optimal One Pass Large Scale Learning with Averaged Stochastic Gradient Descent section 5, page 6
+   * can be beneficial in term s of both speed and accuracy.
+   *
+   * Tacaks' method doesn't calculate gradient of regularization correctly, which has non-zero elements everywhere of
+   * the matrix. While Tacaks' method can only updates a single row/column, if one user has a lot of recommendation,
+   * her vector will be more affected by regularization using an isolated scaling factor for both user vectors and
+   * item vectors can remove this issue without inducing more update cost it even reduces it a bit by only performing
+   * one addition and one multiplication.
+   *
+   * BAD SIDE1: the scaling factor decreases fast, it has to be scaled up from time to time before dropped to zero or
+   * caused roundoff error
+   * BAD SIDE2: no body experiment on it before, and people generally use very small lambda
+   * so it's impact on accuracy may still be unknown.
+   * BAD SIDE3: don't know how to make it work for L1-regularization or
+   * "pseudorank?" (sum of singular values)-regularization */
+  protected void update(Preference preference, double mu) {
+    int userIndex = userIndex(preference.getUserID());
+    int itemIndex = itemIndex(preference.getItemID());
+
+    double[] userVector = userVectors[userIndex];
+    double[] itemVector = itemVectors[itemIndex];
+
+    double prediction = dot(userVector, itemVector);
+    double err = preference.getValue() - prediction;
+
+    // adjust features
+    for (int k = FEATURE_OFFSET; k < rank; k++) {
+      // both updates use the values from before this step
+      double userFeature = userVector[k];
+      double itemFeature = itemVector[k];
+
+      userVector[k] += mu * (err * itemFeature - lambda * userFeature);
+      itemVector[k] += mu * (err * userFeature - lambda * itemFeature);
+    }
+
+    // adjust user and item bias
+    userVector[USER_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * userVector[USER_BIAS_INDEX]);
+    itemVector[ITEM_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * itemVector[ITEM_BIAS_INDEX]);
+  }
+
+  /** Dot product over all {@code rank} slots, i.e. including the global average and bias slots. */
+  private double dot(double[] userVector, double[] itemVector) {
+    double sum = 0;
+    for (int k = 0; k < rank; k++) {
+      sum += userVector[k] * itemVector[k];
+    }
+    return sum;
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java
new file mode 100644
index 0000000..abf3eca
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.io.IOException;
+
+/**
+ * Provides storage for {@link Factorization}s: implementations load a factorization
+ * from, and write a factorization to, a persistent store.
+ */
+public interface PersistenceStrategy {
+
+ /**
+ * Load a factorization from a persistent store.
+ *
+ * @return a Factorization or null if the persistent store is empty.
+ *
+ * @throws IOException if the persistent store cannot be read
+ */
+ Factorization load() throws IOException;
+
+ /**
+ * Write a factorization to a persistent store unless it already
+ * contains an identical factorization.
+ *
+ * @param factorization the factorization to write
+ *
+ * @throws IOException if the persistent store cannot be written
+ */
+ void maybePersist(Factorization factorization) throws IOException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
new file mode 100644
index 0000000..2c9f0ae
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
@@ -0,0 +1,221 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+/**
+ * Matrix factorization with user and item biases for rating prediction, trained with plain vanilla SGD.
+ *
+ * <p>Each user and item vector holds {@link #FEATURE_OFFSET} reserved slots followed by the latent
+ * features. The reserved slots are initialised so that the dot product of a user vector and an item
+ * vector equals {@code globalAverage + userBias + itemBias + (latent feature products)}:</p>
+ * <ul>
+ *   <li>slot 0: global average in the user vector, constant 1 in the item vector</li>
+ *   <li>slot {@link #USER_BIAS_INDEX}: user bias in the user vector, constant 1 in the item vector</li>
+ *   <li>slot {@link #ITEM_BIAS_INDEX}: constant 1 in the user vector, item bias in the item vector</li>
+ * </ul>
+ */
+public class RatingSGDFactorizer extends AbstractFactorizer {
+
+  /** Number of reserved leading slots in every vector; latent features start at this index */
+  protected static final int FEATURE_OFFSET = 3;
+
+  /** Multiplicative decay factor for learning_rate */
+  protected final double learningRateDecay;
+  /** Learning rate (step size) */
+  protected final double learningRate;
+  /** Parameter used to prevent overfitting. */
+  protected final double preventOverfitting;
+  /** Length of every user/item vector: the requested number of features plus {@link #FEATURE_OFFSET} */
+  protected final int numFeatures;
+  /** Number of iterations */
+  private final int numIterations;
+  /** Standard deviation for random initialization of features */
+  protected final double randomNoise;
+  /** User features */
+  protected double[][] userVectors;
+  /** Item features */
+  protected double[][] itemVectors;
+  protected final DataModel dataModel;
+  /** User ID of every cached preference; entry i pairs with entry i of {@link #cachedItemIDs} */
+  private long[] cachedUserIDs;
+  /** Item ID of every cached preference; entry i pairs with entry i of {@link #cachedUserIDs} */
+  private long[] cachedItemIDs;
+
+  /** Factor applied to the current learning rate when updating the biases */
+  protected double biasLearningRate = 0.5;
+  /** Factor applied to {@link #preventOverfitting} when updating the biases */
+  protected double biasReg = 0.1;
+
+  /** place in user vector where the global average is stored */
+  private static final int GLOBAL_AVERAGE_INDEX = 0;
+  /** place in user vector where the bias is stored */
+  protected static final int USER_BIAS_INDEX = 1;
+  /** place in item vector where the bias is stored */
+  protected static final int ITEM_BIAS_INDEX = 2;
+
+  /**
+   * Creates a factorizer with learning rate 0.01, regularization 0.1, initialization noise 0.01
+   * and no learning rate decay (factor 1.0).
+   *
+   * @param dataModel     source of the preferences to factorize
+   * @param numFeatures   number of latent features per user and item, must not be negative
+   * @param numIterations number of passes over all preferences
+   * @throws TasteException if the superclass constructor fails
+   * @throws IllegalArgumentException if {@code numFeatures} is negative
+   */
+  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
+    this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
+  }
+
+  /**
+   * @param dataModel          source of the preferences to factorize
+   * @param numFeatures        number of latent features per user and item, must not be negative
+   * @param learningRate       initial step size
+   * @param preventOverfitting regularization weight applied to the parameters in every update
+   * @param randomNoise        standard deviation of the Gaussian used to initialise the latent features
+   * @param numIterations      number of passes over all preferences
+   * @param learningRateDecay  factor the step size is multiplied with after each pass
+   * @throws TasteException if the superclass constructor fails
+   * @throws IllegalArgumentException if {@code numFeatures} is negative
+   */
+  public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
+      double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
+    super(dataModel);
+    // Fail fast: a negative value would otherwise only surface in prepareTraining() as an
+    // ArrayIndexOutOfBoundsException or NegativeArraySizeException.
+    if (numFeatures < 0) {
+      throw new IllegalArgumentException("numFeatures must not be negative: " + numFeatures);
+    }
+    this.dataModel = dataModel;
+    this.numFeatures = numFeatures + FEATURE_OFFSET;
+    this.numIterations = numIterations;
+
+    this.learningRate = learningRate;
+    this.learningRateDecay = learningRateDecay;
+    this.preventOverfitting = preventOverfitting;
+    this.randomNoise = randomNoise;
+  }
+
+  /**
+   * Allocates and initialises the user and item vectors, then caches and shuffles the
+   * (user, item) pairs of all preferences.
+   *
+   * @throws TasteException if the data model cannot be read
+   */
+  protected void prepareTraining() throws TasteException {
+    RandomWrapper random = RandomUtils.getRandom();
+    userVectors = new double[dataModel.getNumUsers()][numFeatures];
+    itemVectors = new double[dataModel.getNumItems()][numFeatures];
+
+    double globalAverage = getAveragePreference();
+    for (double[] userVector : userVectors) {
+      userVector[GLOBAL_AVERAGE_INDEX] = globalAverage;
+      userVector[USER_BIAS_INDEX] = 0; // will store user bias
+      userVector[ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias
+      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+        userVector[feature] = random.nextGaussian() * randomNoise;
+      }
+    }
+    for (double[] itemVector : itemVectors) {
+      itemVector[GLOBAL_AVERAGE_INDEX] = 1; // corresponding user feature contains global average
+      itemVector[USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias
+      itemVector[ITEM_BIAS_INDEX] = 0; // will store item bias
+      for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+        itemVector[feature] = random.nextGaussian() * randomNoise;
+      }
+    }
+
+    cachePreferences();
+    shufflePreferences();
+  }
+
+  /** Counts the preferences of all users in the data model. */
+  private int countPreferences() throws TasteException {
+    int numPreferences = 0;
+    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+    while (userIDs.hasNext()) {
+      PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong());
+      numPreferences += preferencesFromUser.length();
+    }
+    return numPreferences;
+  }
+
+  /** Stores the user ID and item ID of every preference in two parallel arrays. */
+  private void cachePreferences() throws TasteException {
+    int numPreferences = countPreferences();
+    cachedUserIDs = new long[numPreferences];
+    cachedItemIDs = new long[numPreferences];
+
+    LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+    int index = 0;
+    while (userIDs.hasNext()) {
+      long userID = userIDs.nextLong();
+      PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
+      for (Preference preference : preferencesFromUser) {
+        cachedUserIDs[index] = userID;
+        cachedItemIDs[index] = preference.getItemID();
+        index++;
+      }
+    }
+  }
+
+  /** Randomly permutes the cached preferences in place. */
+  protected void shufflePreferences() {
+    RandomWrapper random = RandomUtils.getRandom();
+    /* Durstenfeld shuffle */
+    for (int currentPos = cachedUserIDs.length - 1; currentPos > 0; currentPos--) {
+      int swapPos = random.nextInt(currentPos + 1);
+      swapCachedPreferences(currentPos, swapPos);
+    }
+  }
+
+  /** Exchanges the cached (user ID, item ID) pairs stored at the two given positions. */
+  private void swapCachedPreferences(int posA, int posB) {
+    long tmpUserID = cachedUserIDs[posA];
+    long tmpItemID = cachedItemIDs[posA];
+
+    cachedUserIDs[posA] = cachedUserIDs[posB];
+    cachedItemIDs[posA] = cachedItemIDs[posB];
+
+    cachedUserIDs[posB] = tmpUserID;
+    cachedItemIDs[posB] = tmpItemID;
+  }
+
+  /**
+   * Prepares the training state and then runs {@code numIterations} passes over the cached
+   * preferences, multiplying the learning rate by {@link #learningRateDecay} after every pass.
+   *
+   * @return the factorization created from the trained user and item vectors
+   * @throws TasteException if the data model cannot be read
+   */
+  @Override
+  public Factorization factorize() throws TasteException {
+    prepareTraining();
+    double currentLearningRate = learningRate;
+
+    for (int iteration = 0; iteration < numIterations; iteration++) {
+      for (int index = 0; index < cachedUserIDs.length; index++) {
+        long userID = cachedUserIDs[index];
+        long itemID = cachedItemIDs[index];
+        // NOTE(review): if getPreferenceValue returns a boxed Float that can be null (e.g. the
+        // preference disappeared after caching), the unboxing below throws a
+        // NullPointerException - confirm against the DataModel contract.
+        float rating = dataModel.getPreferenceValue(userID, itemID);
+        updateParameters(userID, itemID, rating, currentLearningRate);
+      }
+      currentLearningRate *= learningRateDecay;
+    }
+    return createFactorization(userVectors, itemVectors);
+  }
+
+  /**
+   * Computes the average over the values of all preferences in the data model.
+   *
+   * @return the average as reported by the running average after all values were added
+   * @throws TasteException if the data model cannot be read
+   */
+  double getAveragePreference() throws TasteException {
+    RunningAverage average = new FullRunningAverage();
+    LongPrimitiveIterator it = dataModel.getUserIDs();
+    while (it.hasNext()) {
+      for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+        average.addDatum(pref.getValue());
+      }
+    }
+    return average.getAverage();
+  }
+
+  /**
+   * Performs one SGD step for a single observed rating: first the user and item bias are
+   * adjusted, then every latent feature of the user and the item.
+   *
+   * @param userID              ID of the user who gave the rating
+   * @param itemID              ID of the rated item
+   * @param rating              the observed rating
+   * @param currentLearningRate step size to use for this update
+   */
+  protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
+    int userIndex = userIndex(userID);
+    int itemIndex = itemIndex(itemID);
+
+    double[] userVector = userVectors[userIndex];
+    double[] itemVector = itemVectors[itemIndex];
+    double prediction = predictRating(userIndex, itemIndex);
+    double err = rating - prediction;
+
+    // adjust user bias
+    userVector[USER_BIAS_INDEX] +=
+        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);
+
+    // adjust item bias
+    itemVector[ITEM_BIAS_INDEX] +=
+        biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);
+
+    // adjust features; both deltas use the values read before this step
+    for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+      double userFeature = userVector[feature];
+      double itemFeature = itemVector[feature];
+
+      double deltaUserFeature = err * itemFeature - preventOverfitting * userFeature;
+      userVector[feature] += currentLearningRate * deltaUserFeature;
+
+      double deltaItemFeature = err * userFeature - preventOverfitting * itemFeature;
+      itemVector[feature] += currentLearningRate * deltaItemFeature;
+    }
+  }
+
+  /**
+   * Predicts a rating as the dot product of a user vector and an item vector.
+   *
+   * @param userIndex row of the user in {@link #userVectors} (an index, not a user ID)
+   * @param itemIndex row of the item in {@link #itemVectors} (an index, not an item ID)
+   * @return the dot product over all {@code numFeatures} entries, reserved slots included
+   */
+  private double predictRating(int userIndex, int itemIndex) {
+    double[] userVector = userVectors[userIndex];
+    double[] itemVector = itemVectors[itemIndex];
+    double sum = 0;
+    for (int feature = 0; feature < numFeatures; feature++) {
+      sum += userVector[feature] * itemVector[feature];
+    }
+    return sum;
+  }
+}