You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:54:54 UTC
[26/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java
new file mode 100644
index 0000000..da318d5
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.eval;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+/**
+ * Implementations of this interface determine the items that are considered relevant,
+ * and split data into a training and a test subset, for purposes of precision/recall
+ * tests as implemented by implementations of {@link RecommenderIRStatsEvaluator}.
+ */
+public interface RelevantItemsDataSplitter {
+
+ /**
+ * During testing, relevant items are removed from a particular user's preferences,
+ * and a model is built using this user's other preferences and all other users.
+ *
+ * @param userID ID of the user whose relevant items are requested (presumably the user
+ * under evaluation; confirm against implementations)
+ * @param at Maximum number of items to be removed
+ * @param relevanceThreshold Minimum strength of preference for an item to be considered
+ * relevant
+ * @param dataModel model supplied by the caller; presumably the source of the user's
+ * preferences (confirm against implementations)
+ * @return IDs of relevant items
+ * @throws TasteException declared for implementations; the failure conditions are
+ * implementation-specific
+ */
+ FastIDSet getRelevantItemsIDs(long userID,
+ int at,
+ double relevanceThreshold,
+ DataModel dataModel) throws TasteException;
+
+ /**
+ * Adds a single user and all their preferences to the training model.
+ *
+ * @param userID ID of user whose preferences we are trying to predict
+ * @param relevantItemIDs IDs of items considered relevant to that user
+ * @param trainingUsers the database of training preferences to which we will
+ * append the ones for otherUserID.
+ * @param otherUserID for whom we are adding preferences to the training model
+ * @param dataModel model supplied by the caller; presumably the source of
+ * otherUserID's preferences (confirm against implementations)
+ * @throws TasteException declared for implementations; the failure conditions are
+ * implementation-specific
+ */
+ void processOtherUser(long userID,
+ FastIDSet relevantItemIDs,
+ FastByIDMap<PreferenceArray> trainingUsers,
+ long otherUserID,
+ DataModel dataModel) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java
new file mode 100644
index 0000000..e70a675
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import com.google.common.primitives.Longs;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.mahout.math.Varint;
+
+/**
+ * A {@link WritableComparable} encapsulating two entity IDs. Instances order by the first ID and
+ * break ties on the second one.
+ */
+public final class EntityEntityWritable implements WritableComparable<EntityEntityWritable>, Cloneable {
+
+  private long aID;
+  private long bID;
+
+  /** Creates an instance whose IDs stay at zero until {@link #readFields(DataInput)} is called. */
+  public EntityEntityWritable() {
+    // nothing to initialize
+  }
+
+  /**
+   * @param aID first ID of the pair
+   * @param bID second ID of the pair
+   */
+  public EntityEntityWritable(long aID, long bID) {
+    this.aID = aID;
+    this.bID = bID;
+  }
+
+  /** @return the first ID of the pair */
+  long getAID() {
+    return aID;
+  }
+
+  /** @return the second ID of the pair */
+  long getBID() {
+    return bID;
+  }
+
+  /** Writes the first ID followed by the second one, each via {@link Varint#writeSignedVarLong}. */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    Varint.writeSignedVarLong(aID, out);
+    Varint.writeSignedVarLong(bID, out);
+  }
+
+  /** Reads both IDs back in the order in which {@link #write(DataOutput)} emitted them. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    aID = Varint.readSignedVarLong(in);
+    bID = Varint.readSignedVarLong(in);
+  }
+
+  /** Compares by first ID; the second ID is only consulted when the first IDs are equal. */
+  @Override
+  public int compareTo(EntityEntityWritable that) {
+    int byFirstID = Long.compare(aID, that.aID);
+    if (byFirstID != 0) {
+      return byFirstID;
+    }
+    return Long.compare(bID, that.bID);
+  }
+
+  @Override
+  public int hashCode() {
+    int firstHash = Longs.hashCode(aID);
+    int secondHash = Longs.hashCode(bID);
+    return firstHash + 31 * secondHash;
+  }
+
+  /** Two instances are equal when both of their IDs match. */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof EntityEntityWritable)) {
+      return false;
+    }
+    EntityEntityWritable other = (EntityEntityWritable) o;
+    return aID == other.aID && bID == other.bID;
+  }
+
+  /** @return both IDs separated by a tab character */
+  @Override
+  public String toString() {
+    return new StringBuilder().append(aID).append("\t").append(bID).toString();
+  }
+
+  /** @return a new instance carrying the same two IDs */
+  @Override
+  public EntityEntityWritable clone() {
+    return new EntityEntityWritable(aID, bID);
+  }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java
new file mode 100644
index 0000000..2aab63c
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.VarLongWritable;
+
+/** A {@link org.apache.hadoop.io.Writable} encapsulating an item ID and a preference value. */
+public final class EntityPrefWritable extends VarLongWritable implements Cloneable {
+
+ private float prefValue;
+
+ public EntityPrefWritable() {
+ // do nothing
+ }
+
+ public EntityPrefWritable(long itemID, float prefValue) {
+ super(itemID);
+ this.prefValue = prefValue;
+ }
+
+ public EntityPrefWritable(EntityPrefWritable other) {
+ this(other.get(), other.getPrefValue());
+ }
+
+ public long getID() {
+ return get();
+ }
+
+ public float getPrefValue() {
+ return prefValue;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ super.write(out);
+ out.writeFloat(prefValue);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ super.readFields(in);
+ prefValue = in.readFloat();
+ }
+
+ @Override
+ public int hashCode() {
+ return super.hashCode() ^ RandomUtils.hashFloat(prefValue);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof EntityPrefWritable)) {
+ return false;
+ }
+ EntityPrefWritable other = (EntityPrefWritable) o;
+ return get() == other.get() && prefValue == other.getPrefValue();
+ }
+
+ @Override
+ public String toString() {
+ return get() + "\t" + prefValue;
+ }
+
+ @Override
+ public EntityPrefWritable clone() {
+ return new EntityPrefWritable(get(), prefValue);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java
new file mode 100644
index 0000000..3de272d
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java
@@ -0,0 +1,81 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Mutable variant of {@link RecommendedItem}
+ */
+public class MutableRecommendedItem implements RecommendedItem {
+
+ private long itemID;
+ private float value;
+
+ public MutableRecommendedItem() {}
+
+ public MutableRecommendedItem(long itemID, float value) {
+ this.itemID = itemID;
+ this.value = value;
+ }
+
+ @Override
+ public long getItemID() {
+ return itemID;
+ }
+
+ @Override
+ public float getValue() {
+ return value;
+ }
+
+ public void setItemID(long itemID) {
+ this.itemID = itemID;
+ }
+
+ public void set(long itemID, float value) {
+ this.itemID = itemID;
+ this.value = value;
+ }
+
+ public void capToMaxValue(float maxValue) {
+ if (value > maxValue) {
+ value = maxValue;
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "MutableRecommendedItem[item:" + itemID + ", value:" + value + ']';
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) itemID ^ RandomUtils.hashFloat(value);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof MutableRecommendedItem)) {
+ return false;
+ }
+ RecommendedItem other = (RecommendedItem) o;
+ return itemID == other.getItemID() && value == other.getValue();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java
new file mode 100644
index 0000000..bc832aa
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.cf.taste.impl.recommender.GenericRecommendedItem;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.math.Varint;
+
+/**
+ * A {@link Writable} which encapsulates a list of {@link RecommendedItem}s. This is the mapper (and reducer)
+ * output, and represents items recommended to a user. The first item is the one whose estimated preference is
+ * highest.
+ */
+public final class RecommendedItemsWritable implements Writable {
+
+  private List<RecommendedItem> recommended;
+
+  /** Creates an instance without a list; one is supplied by {@link #set} or {@link #readFields}. */
+  public RecommendedItemsWritable() {
+    // nothing to initialize
+  }
+
+  /** @param recommended list to wrap; it is stored as-is, not copied */
+  public RecommendedItemsWritable(List<RecommendedItem> recommended) {
+    this.recommended = recommended;
+  }
+
+  /** @return the wrapped list itself, not a copy */
+  public List<RecommendedItem> getRecommendedItems() {
+    return recommended;
+  }
+
+  /** Replaces the wrapped list; the given list is stored as-is, not copied. */
+  public void set(List<RecommendedItem> recommended) {
+    this.recommended = recommended;
+  }
+
+  /**
+   * Writes the number of items as an int, followed by each item's ID (via
+   * {@link Varint#writeSignedVarLong}) and value (as a float).
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(recommended.size());
+    for (RecommendedItem entry : recommended) {
+      long id = entry.getItemID();
+      float score = entry.getValue();
+      Varint.writeSignedVarLong(id, out);
+      out.writeFloat(score);
+    }
+  }
+
+  /** Replaces the wrapped list with a fresh one read in the format produced by {@link #write}. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int count = in.readInt();
+    recommended = new ArrayList<>(count);
+    for (int remaining = count; remaining > 0; remaining--) {
+      long id = Varint.readSignedVarLong(in);
+      float score = in.readFloat();
+      recommended.add(new GenericRecommendedItem(id, score));
+    }
+  }
+
+  /** @return the items rendered as {@code [id:value,id:value,...]} */
+  @Override
+  public String toString() {
+    StringBuilder text = new StringBuilder(200);
+    text.append('[');
+    String separator = "";
+    for (RecommendedItem entry : recommended) {
+      text.append(separator);
+      text.append(entry.getItemID()).append(':').append(entry.getValue());
+      separator = ",";
+    }
+    return text.append(']').toString();
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java
new file mode 100644
index 0000000..e3fab29
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import com.google.common.primitives.Longs;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+
+import java.util.regex.Pattern;
+
+/**
+ * Some helper methods for the hadoop-related stuff in org.apache.mahout.cf.taste
+ */
+public final class TasteHadoopUtils {
+
+ public static final int USER_ID_POS = 0;
+ public static final int ITEM_ID_POS = 1;
+
+ /** Standard delimiter of textual preference data */
+ private static final Pattern PREFERENCE_TOKEN_DELIMITER = Pattern.compile("[\t,]");
+
+ private TasteHadoopUtils() {}
+
+ /**
+ * Splits a preference data line into string tokens
+ */
+ public static String[] splitPrefTokens(CharSequence line) {
+ return PREFERENCE_TOKEN_DELIMITER.split(line);
+ }
+
+ /**
+ * Maps a long to an int with range of 0 to Integer.MAX_VALUE-1
+ */
+ public static int idToIndex(long id) {
+ return 0x7FFFFFFF & Longs.hashCode(id) % 0x7FFFFFFE;
+ }
+
+ public static int readID(String token, boolean usesLongIDs) {
+ return usesLongIDs ? idToIndex(Long.parseLong(token)) : Integer.parseInt(token);
+ }
+
+ /**
+ * Reads a binary mapping file
+ */
+ public static OpenIntLongHashMap readIDIndexMap(String idIndexPathStr, Configuration conf) {
+ OpenIntLongHashMap indexIDMap = new OpenIntLongHashMap();
+ Path itemIDIndexPath = new Path(idIndexPathStr);
+ for (Pair<VarIntWritable,VarLongWritable> record
+ : new SequenceFileDirIterable<VarIntWritable,VarLongWritable>(itemIDIndexPath,
+ PathType.LIST,
+ PathFilters.partFilter(),
+ null,
+ true,
+ conf)) {
+ indexIDMap.put(record.getFirst().get(), record.getSecond().get());
+ }
+ return indexIDMap;
+ }
+
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java
new file mode 100644
index 0000000..fdb552e
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.cf.taste.hadoop.item.RecommenderJob;
+import org.apache.mahout.math.VarLongWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+/**
+ * Base mapper turning text lines of the form {@code id,id[,preference]} (tab or comma separated) into a
+ * key ID mapped to either a bare ID (boolean data) or an {@link EntityPrefWritable}.
+ */
+public abstract class ToEntityPrefsMapper extends
+    Mapper<LongWritable,Text, VarLongWritable,VarLongWritable> {
+
+  public static final String TRANSPOSE_USER_ITEM = ToEntityPrefsMapper.class + "transposeUserItem";
+  public static final String RATING_SHIFT = ToEntityPrefsMapper.class + "shiftRatings";
+
+  private static final Pattern DELIMITER = Pattern.compile("[\t,]");
+
+  private final boolean itemKey;
+  private boolean booleanData;
+  private boolean transpose;
+  private float ratingShift;
+
+  /** @param itemKey whether the second ID of each line, rather than the first, becomes the output key */
+  ToEntityPrefsMapper(boolean itemKey) {
+    this.itemKey = itemKey;
+  }
+
+  /** Reads the boolean-data flag, the transpose flag and the rating shift from the job configuration. */
+  @Override
+  protected void setup(Context context) {
+    Configuration conf = context.getConfiguration();
+    booleanData = conf.getBoolean(RecommenderJob.BOOLEAN_DATA, false);
+    transpose = conf.getBoolean(TRANSPOSE_USER_ITEM, false);
+    ratingShift = Float.parseFloat(conf.get(RATING_SHIFT, "0.0"));
+  }
+
+  /**
+   * Emits one record per input line. The two leading tokens are parsed as IDs; a third token, if
+   * present, is the preference (the rating shift is added to it), otherwise the preference is 1.0.
+   */
+  @Override
+  public void map(LongWritable key,
+                  Text value,
+                  Context context) throws IOException, InterruptedException {
+    String[] fields = DELIMITER.split(value.toString());
+    long firstID = Long.parseLong(fields[0]);
+    long secondID = Long.parseLong(fields[1]);
+
+    // The two IDs trade places when exactly one of "key by item" and "transpose" applies;
+    // when both apply, the two swaps cancel each other out.
+    boolean swapped = itemKey != transpose;
+    long keyID = swapped ? secondID : firstID;
+    long valueID = swapped ? firstID : secondID;
+
+    if (!booleanData) {
+      float prefValue = fields.length > 2 ? Float.parseFloat(fields[2]) + ratingShift : 1.0f;
+      context.write(new VarLongWritable(keyID), new EntityPrefWritable(valueID, prefValue));
+      return;
+    }
+    context.write(new VarLongWritable(keyID), new VarLongWritable(valueID));
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java
new file mode 100644
index 0000000..f5f9574
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+/**
+ * <h1>Input</h1>
+ *
+ * <p>
+ * Intended for use with {@link org.apache.hadoop.mapreduce.lib.input.TextInputFormat};
+ * accepts key / line pairs as
+ * {@link org.apache.hadoop.io.LongWritable}/{@link org.apache.hadoop.io.Text} pairs. The key is not
+ * used by the mapping logic in {@link ToEntityPrefsMapper} (NOTE(review): with TextInputFormat the key
+ * is presumably the position of the line in the file rather than a line number; confirm).
+ * </p>
+ *
+ * <p>
+ * Each line is assumed to be of the form {@code userID,itemID,preference}, or {@code userID,itemID}.
+ * Fields may be separated by a tab instead of a comma.
+ * </p>
+ *
+ * <h1>Output</h1>
+ *
+ * <p>
+ * Outputs the user ID as a {@link org.apache.mahout.math.VarLongWritable} mapped to the item ID and preference as a
+ * {@link EntityPrefWritable}. When the job is configured for boolean data, the item ID alone is emitted
+ * as a {@link org.apache.mahout.math.VarLongWritable} instead.
+ * </p>
+ */
+public final class ToItemPrefsMapper extends ToEntityPrefsMapper {
+
+ public ToItemPrefsMapper() {
+ // itemKey = false: the first ID of each line is the output key unless transposing is configured
+ super(false);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java
new file mode 100644
index 0000000..8f563b0
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+/**
+ * A fixed-capacity {@link PriorityQueue} of {@link MutableRecommendedItem}s ordered by ascending value,
+ * used to collect the highest-valued items.
+ *
+ * <p>The queue relies on sentinel entries (see {@link #getSentinelObject()}). They carry
+ * {@code Long.MIN_VALUE} as item ID and are dropped again by {@link #getTopItems()}.</p>
+ */
+public class TopItemsQueue extends PriorityQueue<MutableRecommendedItem> {
+
+  private static final long SENTINEL_ID = Long.MIN_VALUE;
+
+  /**
+   * Value of a sentinel entry. It has to rank below every real value so that any real item can displace
+   * a sentinel. {@code Float.MIN_VALUE} does not qualify: it is the smallest <em>positive</em> float, so
+   * items valued zero or below would rank beneath the sentinels.
+   */
+  private static final float SENTINEL_VALUE = Float.NEGATIVE_INFINITY;
+
+  private final int maxSize;
+
+  /** @param maxSize maximum number of items the queue holds */
+  public TopItemsQueue(int maxSize) {
+    super(maxSize);
+    this.maxSize = maxSize;
+  }
+
+  /**
+   * Empties the queue and returns its non-sentinel entries.
+   *
+   * @return the remaining items in reverse pop order, i.e. highest value first
+   */
+  public List<RecommendedItem> getTopItems() {
+    List<RecommendedItem> recommendedItems = new ArrayList<>(maxSize);
+    while (size() > 0) {
+      MutableRecommendedItem topItem = pop();
+      // filter out "sentinel" objects necessary for maintaining an efficient priority queue
+      if (topItem.getItemID() != SENTINEL_ID) {
+        recommendedItems.add(topItem);
+      }
+    }
+    // entries were popped lowest value first
+    Collections.reverse(recommendedItems);
+    return recommendedItems;
+  }
+
+  /** Orders items by value only; item IDs play no role. */
+  @Override
+  protected boolean lessThan(MutableRecommendedItem one, MutableRecommendedItem two) {
+    return one.getValue() < two.getValue();
+  }
+
+  /** @return a fresh placeholder entry that ranks below every item with a finite value */
+  @Override
+  protected MutableRecommendedItem getSentinelObject() {
+    return new MutableRecommendedItem(SENTINEL_ID, SENTINEL_VALUE);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
new file mode 100644
index 0000000..4bb95ae
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalFileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+/**
+ * Static helpers for reading rows of feature matrices from sequence files and for solving for a
+ * single feature vector from explicit ratings.
+ */
+final class ALS {
+
+  private ALS() {}
+
+  /**
+   * Returns the first vector found in the part files below {@code dir}.
+   *
+   * @return the first vector, or {@code null} if there is none
+   */
+  static Vector readFirstRow(Path dir, Configuration conf) throws IOException {
+    Iterator<VectorWritable> rows = new SequenceFileDirValueIterator<>(dir, PathType.LIST,
+        PathFilters.partFilter(), null, true, conf);
+    if (!rows.hasNext()) {
+      return null;
+    }
+    return rows.next().get();
+  }
+
+  /**
+   * Reads every (row index, vector) record from the cached files reported by
+   * {@link HadoopUtil#getCachedFiles(Configuration)}.
+   *
+   * @param numEntities initial capacity of the returned map; ignored unless positive
+   * @return map from row index to row vector
+   * @throws IllegalStateException if no row was read at all
+   */
+  public static OpenIntObjectHashMap<Vector> readMatrixByRowsFromDistributedCache(int numEntities,
+      Configuration conf) throws IOException {
+
+    // holders that the reader fills anew for every record
+    IntWritable indexHolder = new IntWritable();
+    VectorWritable rowHolder = new VectorWritable();
+
+    OpenIntObjectHashMap<Vector> rowsByIndex;
+    if (numEntities > 0) {
+      rowsByIndex = new OpenIntObjectHashMap<>(numEntities);
+    } else {
+      rowsByIndex = new OpenIntObjectHashMap<>();
+    }
+
+    Path[] cachedFiles = HadoopUtil.getCachedFiles(conf);
+    LocalFileSystem localFs = FileSystem.getLocal(conf);
+
+    for (Path cachedFile : cachedFiles) {
+      try (SequenceFile.Reader reader =
+               new SequenceFile.Reader(localFs.getConf(), SequenceFile.Reader.file(cachedFile))) {
+        while (reader.next(indexHolder, rowHolder)) {
+          rowsByIndex.put(indexHolder.get(), rowHolder.get());
+        }
+      }
+    }
+
+    Preconditions.checkState(!rowsByIndex.isEmpty(), "Feature matrix is empty");
+    return rowsByIndex;
+  }
+
+  /**
+   * Reads every (row index, vector) record from the part files below {@code dir}.
+   *
+   * @return map from row index to row vector; empty if no record was found
+   */
+  public static OpenIntObjectHashMap<Vector> readMatrixByRows(Path dir, Configuration conf) {
+    OpenIntObjectHashMap<Vector> rowsByIndex = new OpenIntObjectHashMap<>();
+    SequenceFileDirIterable<IntWritable,VectorWritable> records =
+        new SequenceFileDirIterable<>(dir, PathType.LIST, PathFilters.partFilter(), conf);
+    for (Pair<IntWritable,VectorWritable> record : records) {
+      rowsByIndex.put(record.getFirst().get(), record.getSecond().get());
+    }
+    return rowsByIndex;
+  }
+
+  /**
+   * Looks up the feature vector of every index with a non-zero rating and hands the vectors, together
+   * with the ratings, to {@link AlternatingLeastSquaresSolver#solve}.
+   *
+   * @param ratingsWritable wrapper around the ratings vector
+   * @param uOrM feature vectors by index
+   * @param lambda passed through to the solver
+   * @param numFeatures passed through to the solver
+   * @return the solver's result
+   */
+  public static Vector solveExplicit(VectorWritable ratingsWritable, OpenIntObjectHashMap<Vector> uOrM,
+      double lambda, int numFeatures) {
+    Vector ratings = ratingsWritable.get();
+
+    List<Vector> ratedFeatures = new ArrayList<>(ratings.getNumNondefaultElements());
+    for (Vector.Element rating : ratings.nonZeroes()) {
+      ratedFeatures.add(uOrM.get(rating.index()));
+    }
+
+    return AlternatingLeastSquaresSolver.solve(ratedFeatures, ratings, lambda, numFeatures);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java
new file mode 100644
index 0000000..b061a63
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * <p>Splits a recommendation dataset into a training set and a probe set by random assignment.</p>
+ *
+ * <p>A first job marks every input line for the training set or for the probe set (or drops it) based on a
+ * random draw compared against the configured percentages. Two further jobs then write the marked lines to
+ * {@code trainingSet} and {@code probeSet} beneath the output path.</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the dataset</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--trainingPercentage (double): percentage of the data to use as training set (optional, default 0.9)</li>
+ * <li>--probePercentage (double): percentage of the data to use as probe set (optional, default 0.1)</li>
+ * </ol>
+ */
+public class DatasetSplitter extends AbstractJob {
+
+  /* configuration keys used to hand settings to the mappers */
+  private static final String TRAINING_PERCENTAGE = DatasetSplitter.class.getName() + ".trainingPercentage";
+  private static final String PROBE_PERCENTAGE = DatasetSplitter.class.getName() + ".probePercentage";
+  private static final String PART_TO_USE = DatasetSplitter.class.getName() + ".partToUse";
+
+  /* markers written as keys by the marking job and matched by the writing jobs */
+  private static final Text INTO_TRAINING_SET = new Text("T");
+  private static final Text INTO_PROBE_SET = new Text("P");
+
+  private static final double DEFAULT_TRAINING_PERCENTAGE = 0.9;
+  private static final double DEFAULT_PROBE_PERCENTAGE = 0.1;
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new DatasetSplitter(), args);
+  }
+
+  /**
+   * Runs the marking job followed by the two jobs that write the training and the probe set.
+   *
+   * @return 0 on success, -1 if the arguments could not be parsed or any job failed
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption("trainingPercentage", "t", "percentage of the data to use as training set (default: "
+        + DEFAULT_TRAINING_PERCENTAGE + ')', String.valueOf(DEFAULT_TRAINING_PERCENTAGE));
+    addOption("probePercentage", "p", "percentage of the data to use as probe set (default: "
+        + DEFAULT_PROBE_PERCENTAGE + ')', String.valueOf(DEFAULT_PROBE_PERCENTAGE));
+
+    if (parseArguments(args) == null) {
+      return -1;
+    }
+
+    double trainingShare = Double.parseDouble(getOption("trainingPercentage"));
+    double probeShare = Double.parseDouble(getOption("probePercentage"));
+
+    Path markedPrefs = new Path(getOption("tempDir"), "markedPreferences");
+    Path trainingSetPath = new Path(getOutputPath(), "trainingSet");
+    Path probeSetPath = new Path(getOutputPath(), "probeSet");
+
+    // pass 1: tag each preference with the set it belongs to
+    Job marking = prepareJob(getInputPath(), markedPrefs, TextInputFormat.class, MarkPreferencesMapper.class,
+        Text.class, Text.class, SequenceFileOutputFormat.class);
+    marking.getConfiguration().set(TRAINING_PERCENTAGE, String.valueOf(trainingShare));
+    marking.getConfiguration().set(PROBE_PERCENTAGE, String.valueOf(probeShare));
+    if (!marking.waitForCompletion(true)) {
+      return -1;
+    }
+
+    // pass 2 and 3: filter the tagged preferences into the two result directories
+    if (!writePart(markedPrefs, trainingSetPath, INTO_TRAINING_SET)) {
+      return -1;
+    }
+    if (!writePart(markedPrefs, probeSetPath, INTO_PROBE_SET)) {
+      return -1;
+    }
+
+    return 0;
+  }
+
+  /**
+   * Runs a job that copies all preferences tagged with {@code marker} from {@code markedPrefs} to
+   * {@code destination} as text.
+   *
+   * @return whether the job completed successfully
+   */
+  private boolean writePart(Path markedPrefs, Path destination, Text marker) throws Exception {
+    Job writePrefs = prepareJob(markedPrefs, destination, SequenceFileInputFormat.class,
+        WritePrefsMapper.class, NullWritable.class, Text.class, TextOutputFormat.class);
+    writePrefs.getConfiguration().set(PART_TO_USE, marker.toString());
+    return writePrefs.waitForCompletion(true);
+  }
+
+  /** Tags each input line with the training marker, the probe marker, or drops it, based on a random draw. */
+  static class MarkPreferencesMapper extends Mapper<LongWritable,Text,Text,Text> {
+
+    private Random random;
+    /* draws up to this bound go to the training set */
+    private double trainingBound;
+    /* draws above the training bound and up to this bound go to the probe set */
+    private double probeBound;
+
+    @Override
+    protected void setup(Context ctx) throws IOException, InterruptedException {
+      random = RandomUtils.getRandom();
+      trainingBound = Double.parseDouble(ctx.getConfiguration().get(TRAINING_PERCENTAGE));
+      probeBound = trainingBound + Double.parseDouble(ctx.getConfiguration().get(PROBE_PERCENTAGE));
+    }
+
+    @Override
+    protected void map(LongWritable key, Text text, Context ctx) throws IOException, InterruptedException {
+      double draw = random.nextDouble();
+
+      Text marker = null;
+      if (draw <= trainingBound) {
+        marker = INTO_TRAINING_SET;
+      } else if (draw <= probeBound) {
+        marker = INTO_PROBE_SET;
+      }
+
+      // draws beyond both bounds belong to neither set and are not emitted
+      if (marker != null) {
+        ctx.write(marker, text);
+      }
+    }
+  }
+
+  /** Emits only those tagged preferences whose tag equals the configured part. */
+  static class WritePrefsMapper extends Mapper<Text,Text,NullWritable,Text> {
+
+    private String partToUse;
+
+    @Override
+    protected void setup(Context ctx) throws IOException, InterruptedException {
+      partToUse = ctx.getConfiguration().get(PART_TO_USE);
+    }
+
+    @Override
+    protected void map(Text key, Text text, Context ctx) throws IOException, InterruptedException {
+      if (!partToUse.equals(key.toString())) {
+        return;
+      }
+      ctx.write(NullWritable.get(), text);
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java
new file mode 100644
index 0000000..4e6aaf5
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java
@@ -0,0 +1,166 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+/**
+ * <p>Measures the root-mean-squared error of a rating matrix factorization against a test set.</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--output (path): path where output should go</li>
+ * <li>--pairs (path): path containing the test ratings, each line must be userID,itemID,rating</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * </ol>
+ */
+public class FactorizationEvaluator extends AbstractJob {
+
+ private static final String USER_FEATURES_PATH = RecommenderJob.class.getName() + ".userFeatures";
+ private static final String ITEM_FEATURES_PATH = RecommenderJob.class.getName() + ".itemFeatures";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new FactorizationEvaluator(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOption("userFeatures", null, "path to the user feature matrix", true);
+ addOption("itemFeatures", null, "path to the item feature matrix", true);
+ addOption("usesLongIDs", null, "input contains long IDs that need to be translated");
+ addOutputOption();
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Path errors = getTempPath("errors");
+
+ Job predictRatings = prepareJob(getInputPath(), errors, TextInputFormat.class, PredictRatingsMapper.class,
+ DoubleWritable.class, NullWritable.class, SequenceFileOutputFormat.class);
+
+ Configuration conf = predictRatings.getConfiguration();
+ conf.set(USER_FEATURES_PATH, getOption("userFeatures"));
+ conf.set(ITEM_FEATURES_PATH, getOption("itemFeatures"));
+
+ boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs"));
+ if (usesLongIDs) {
+ conf.set(ParallelALSFactorizationJob.USES_LONG_IDS, String.valueOf(true));
+ }
+
+
+ boolean succeeded = predictRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ FileSystem fs = FileSystem.get(getOutputPath().toUri(), getConf());
+ FSDataOutputStream outputStream = fs.create(getOutputPath("rmse.txt"));
+ try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(outputStream, Charsets.UTF_8))){
+ double rmse = computeRmse(errors);
+ writer.write(String.valueOf(rmse));
+ }
+ return 0;
+ }
+
+ private double computeRmse(Path errors) {
+ RunningAverage average = new FullRunningAverage();
+ for (Pair<DoubleWritable,NullWritable> entry
+ : new SequenceFileDirIterable<DoubleWritable, NullWritable>(errors, PathType.LIST, PathFilters.logsCRCFilter(),
+ getConf())) {
+ DoubleWritable error = entry.getFirst();
+ average.addDatum(error.get() * error.get());
+ }
+
+ return Math.sqrt(average.getAverage());
+ }
+
+ public static class PredictRatingsMapper extends Mapper<LongWritable,Text,DoubleWritable,NullWritable> {
+
+ private OpenIntObjectHashMap<Vector> U;
+ private OpenIntObjectHashMap<Vector> M;
+
+ private boolean usesLongIDs;
+
+ private final DoubleWritable error = new DoubleWritable();
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ Configuration conf = ctx.getConfiguration();
+
+ Path pathToU = new Path(conf.get(USER_FEATURES_PATH));
+ Path pathToM = new Path(conf.get(ITEM_FEATURES_PATH));
+
+ U = ALS.readMatrixByRows(pathToU, conf);
+ M = ALS.readMatrixByRows(pathToM, conf);
+
+ usesLongIDs = conf.getBoolean(ParallelALSFactorizationJob.USES_LONG_IDS, false);
+ }
+
+ @Override
+ protected void map(LongWritable key, Text value, Context ctx) throws IOException, InterruptedException {
+
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
+
+ int userID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.USER_ID_POS], usesLongIDs);
+ int itemID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.ITEM_ID_POS], usesLongIDs);
+ double rating = Double.parseDouble(tokens[2]);
+
+ if (U.containsKey(userID) && M.containsKey(itemID)) {
+ double estimate = U.get(userID).dot(M.get(itemID));
+ error.set(rating - estimate);
+ ctx.write(error, NullWritable.get());
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java
new file mode 100644
index 0000000..d93e3a4
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
+import org.apache.hadoop.util.ReflectionUtils;
+
+import java.io.IOException;
+
+/**
+ * Multithreaded Mapper for {@link SharingMapper}s. Calls setupSharedInstance() once in the controlling thread
+ * and only afterwards hands control to the multithreaded execution of the superclass.
+ *
+ * @param <K1> input key type
+ * @param <V1> input value type
+ * @param <K2> output key type
+ * @param <V2> output value type
+ */
+public class MultithreadedSharingMapper<K1, V1, K2, V2> extends MultithreadedMapper<K1, V1, K2, V2> {
+
+  @Override
+  public void run(Context ctx) throws IOException, InterruptedException {
+    Class<Mapper<K1, V1, K2, V2>> delegateClass =
+        MultithreadedSharingMapper.getMapperClass((JobContext) ctx);
+    Preconditions.checkNotNull(delegateClass, "Could not find Multithreaded Mapper class.");
+
+    // create one instance of the configured mapper; it is only used to set up the shared state
+    Mapper<K1, V1, K2, V2> delegate = ReflectionUtils.newInstance(delegateClass, ctx.getConfiguration());
+
+    // anything that is not a SharingMapper is rejected by the null check below
+    SharingMapper<K1, V1, K2, V2, ?> sharingMapper =
+        delegate instanceof SharingMapper ? (SharingMapper<K1, V1, K2, V2, ?>) delegate : null;
+    Preconditions.checkNotNull(sharingMapper, "Could not instantiate SharingMapper. Class was: %s",
+        delegate.getClass().getName());
+
+    // shared state is prepared exactly once, before any worker thread exists
+    sharingMapper.setupSharedInstance(ctx);
+
+    // now let the superclass run the mappers in its thread pool
+    super.run(ctx);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
new file mode 100644
index 0000000..2ce9b61
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
@@ -0,0 +1,414 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
+import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
+import org.apache.mahout.common.mapreduce.TransposeMapper;
+import org.apache.mahout.common.mapreduce.VectorSumCombiner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>MapReduce implementation of the two factorization algorithms described in
+ *
+ * <p>"Large-scale Parallel Collaborative Filtering for the Netflix Prize" available at
+ * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf.</p>
+ *
+ * "<p>Collaborative Filtering for Implicit Feedback Datasets" available at
+ * http://research.yahoo.com/pub/2433</p>
+ *
+ * </p>
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the dataset</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--lambda (double): regularization parameter to avoid overfitting</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * <li>--numThreadsPerSolver (int): threads to use per solver mapper, (default: 1)</li>
+ * </ol>
+ */
+public class ParallelALSFactorizationJob extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class);
+
+ static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
+ static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
+ static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha";
+ static final String NUM_ENTITIES = ParallelALSFactorizationJob.class.getName() + ".numEntities";
+
+ static final String USES_LONG_IDS = ParallelALSFactorizationJob.class.getName() + ".usesLongIDs";
+ static final String TOKEN_POS = ParallelALSFactorizationJob.class.getName() + ".tokenPos";
+
+ private boolean implicitFeedback;
+ private int numIterations;
+ private int numFeatures;
+ private double lambda;
+ private double alpha;
+ private int numThreadsPerSolver;
+
+ enum Stats { NUM_USERS }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new ParallelALSFactorizationJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("lambda", null, "regularization parameter", true);
+ addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false));
+ addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40));
+ addOption("numFeatures", null, "dimension of the feature space", true);
+ addOption("numIterations", null, "number of iterations", true);
+ addOption("numThreadsPerSolver", null, "threads per solver mapper", String.valueOf(1));
+ addOption("usesLongIDs", null, "input contains long IDs that need to be translated");
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ numFeatures = Integer.parseInt(getOption("numFeatures"));
+ numIterations = Integer.parseInt(getOption("numIterations"));
+ lambda = Double.parseDouble(getOption("lambda"));
+ alpha = Double.parseDouble(getOption("alpha"));
+ implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback"));
+
+ numThreadsPerSolver = Integer.parseInt(getOption("numThreadsPerSolver"));
+ boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs", String.valueOf(false)));
+
+ /*
+ * compute the factorization A = U M'
+ *
+ * where A (users x items) is the matrix of known ratings
+ * U (users x features) is the representation of users in the feature space
+ * M (items x features) is the representation of items in the feature space
+ */
+
+ if (usesLongIDs) {
+ Job mapUsers = prepareJob(getInputPath(), getOutputPath("userIDIndex"), TextInputFormat.class,
+ MapLongIDsMapper.class, VarIntWritable.class, VarLongWritable.class, IDMapReducer.class,
+ VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class);
+ mapUsers.getConfiguration().set(TOKEN_POS, String.valueOf(TasteHadoopUtils.USER_ID_POS));
+ mapUsers.waitForCompletion(true);
+
+ Job mapItems = prepareJob(getInputPath(), getOutputPath("itemIDIndex"), TextInputFormat.class,
+ MapLongIDsMapper.class, VarIntWritable.class, VarLongWritable.class, IDMapReducer.class,
+ VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class);
+ mapItems.getConfiguration().set(TOKEN_POS, String.valueOf(TasteHadoopUtils.ITEM_ID_POS));
+ mapItems.waitForCompletion(true);
+ }
+
+ /* create A' */
+ Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(),
+ TextInputFormat.class, ItemRatingVectorsMapper.class, IntWritable.class,
+ VectorWritable.class, VectorSumReducer.class, IntWritable.class,
+ VectorWritable.class, SequenceFileOutputFormat.class);
+ itemRatings.setCombinerClass(VectorSumCombiner.class);
+ itemRatings.getConfiguration().set(USES_LONG_IDS, String.valueOf(usesLongIDs));
+ boolean succeeded = itemRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ /* create A */
+ Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(),
+ TransposeMapper.class, IntWritable.class, VectorWritable.class, MergeUserVectorsReducer.class,
+ IntWritable.class, VectorWritable.class);
+ userRatings.setCombinerClass(MergeVectorsCombiner.class);
+ succeeded = userRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ //TODO this could be fiddled into one of the upper jobs
+ Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"),
+ AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
+ IntWritable.class, VectorWritable.class);
+ averageItemRatings.setCombinerClass(MergeVectorsCombiner.class);
+ succeeded = averageItemRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ Vector averageRatings = ALS.readFirstRow(getTempPath("averageRatings"), getConf());
+
+ int numItems = averageRatings.getNumNondefaultElements();
+ int numUsers = (int) userRatings.getCounters().findCounter(Stats.NUM_USERS).getValue();
+
+ log.info("Found {} users and {} items", numUsers, numItems);
+
+ /* create an initial M */
+ initializeM(averageRatings);
+
+ for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
+ /* broadcast M, read A row-wise, recompute U row-wise */
+ log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations);
+ runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1), currentIteration, "U",
+ numItems);
+ /* broadcast U, read A' row-wise, recompute M row-wise */
+ log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations);
+ runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration), currentIteration, "M",
+ numUsers);
+ }
+
+ return 0;
+ }
+
+ private void initializeM(Vector averageRatings) throws IOException {
+ Random random = RandomUtils.getRandom();
+
+ FileSystem fs = FileSystem.get(pathToM(-1).toUri(), getConf());
+ try (SequenceFile.Writer writer =
+ new SequenceFile.Writer(fs, getConf(), new Path(pathToM(-1), "part-m-00000"),
+ IntWritable.class, VectorWritable.class)) {
+ IntWritable index = new IntWritable();
+ VectorWritable featureVector = new VectorWritable();
+
+ for (Vector.Element e : averageRatings.nonZeroes()) {
+ Vector row = new DenseVector(numFeatures);
+ row.setQuick(0, e.get());
+ for (int m = 1; m < numFeatures; m++) {
+ row.setQuick(m, random.nextDouble());
+ }
+ index.set(e.index());
+ featureVector.set(row);
+ writer.append(index, featureVector);
+ }
+ }
+ }
+
+ static class VectorSumReducer
+ extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable();
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx)
+ throws IOException, InterruptedException {
+ Vector sum = Vectors.sum(values.iterator());
+ result.set(new SequentialAccessSparseVector(sum));
+ ctx.write(key, result);
+ }
+ }
+
+ static class MergeUserVectorsReducer extends
+ Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable();
+
+ @Override
+ public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx)
+ throws IOException, InterruptedException {
+ Vector merged = VectorWritable.merge(vectors.iterator()).get();
+ result.set(new SequentialAccessSparseVector(merged));
+ ctx.write(key, result);
+ ctx.getCounter(Stats.NUM_USERS).increment(1);
+ }
+ }
+
+ static class ItemRatingVectorsMapper extends Mapper<LongWritable,Text,IntWritable,VectorWritable> {
+
+ private final IntWritable itemIDWritable = new IntWritable();
+ private final VectorWritable ratingsWritable = new VectorWritable(true);
+ private final Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+
+ private boolean usesLongIDs;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ usesLongIDs = ctx.getConfiguration().getBoolean(USES_LONG_IDS, false);
+ }
+
+ @Override
+ protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
+ int userID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.USER_ID_POS], usesLongIDs);
+ int itemID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.ITEM_ID_POS], usesLongIDs);
+ float rating = Float.parseFloat(tokens[2]);
+
+ ratings.setQuick(userID, rating);
+
+ itemIDWritable.set(itemID);
+ ratingsWritable.set(ratings);
+
+ ctx.write(itemIDWritable, ratingsWritable);
+
+ // prepare instance for reuse
+ ratings.setQuick(userID, 0.0d);
+ }
+ }
+
+ private void runSolver(Path ratings, Path output, Path pathToUorM, int currentIteration, String matrixName,
+ int numEntities) throws ClassNotFoundException, IOException, InterruptedException {
+
+ // necessary for local execution in the same JVM only
+ SharingMapper.reset();
+
+ Class<? extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable>> solverMapperClassInternal;
+ String name;
+
+ if (implicitFeedback) {
+ solverMapperClassInternal = SolveImplicitFeedbackMapper.class;
+ name = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations + "), "
+ + '(' + numThreadsPerSolver + " threads, " + numFeatures + " features, implicit feedback)";
+ } else {
+ solverMapperClassInternal = SolveExplicitFeedbackMapper.class;
+ name = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations + "), "
+ + '(' + numThreadsPerSolver + " threads, " + numFeatures + " features, explicit feedback)";
+ }
+
+ Job solverForUorI = prepareJob(ratings, output, SequenceFileInputFormat.class, MultithreadedSharingMapper.class,
+ IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, name);
+ Configuration solverConf = solverForUorI.getConfiguration();
+ solverConf.set(LAMBDA, String.valueOf(lambda));
+ solverConf.set(ALPHA, String.valueOf(alpha));
+ solverConf.setInt(NUM_FEATURES, numFeatures);
+ solverConf.set(NUM_ENTITIES, String.valueOf(numEntities));
+
+ FileSystem fs = FileSystem.get(pathToUorM.toUri(), solverConf);
+ FileStatus[] parts = fs.listStatus(pathToUorM, PathFilters.partFilter());
+ for (FileStatus part : parts) {
+ if (log.isDebugEnabled()) {
+ log.debug("Adding {} to distributed cache", part.getPath().toString());
+ }
+ DistributedCache.addCacheFile(part.getPath().toUri(), solverConf);
+ }
+
+ MultithreadedMapper.setMapperClass(solverForUorI, solverMapperClassInternal);
+ MultithreadedMapper.setNumberOfThreads(solverForUorI, numThreadsPerSolver);
+
+ boolean succeeded = solverForUorI.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ static class AverageRatingMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private final IntWritable firstIndex = new IntWritable(0);
+ private final Vector featureVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+ private final VectorWritable featureVectorWritable = new VectorWritable();
+
+ @Override
+ protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
+ RunningAverage avg = new FullRunningAverage();
+ for (Vector.Element e : v.get().nonZeroes()) {
+ avg.addDatum(e.get());
+ }
+
+ featureVector.setQuick(r.get(), avg.getAverage());
+ featureVectorWritable.set(featureVector);
+ ctx.write(firstIndex, featureVectorWritable);
+
+ // prepare instance for reuse
+ featureVector.setQuick(r.get(), 0.0d);
+ }
+ }
+
+ static class MapLongIDsMapper extends Mapper<LongWritable,Text,VarIntWritable,VarLongWritable> {
+
+ private int tokenPos;
+ private final VarIntWritable index = new VarIntWritable();
+ private final VarLongWritable idWritable = new VarLongWritable();
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ tokenPos = ctx.getConfiguration().getInt(TOKEN_POS, -1);
+ Preconditions.checkState(tokenPos >= 0);
+ }
+
+ @Override
+ protected void map(LongWritable key, Text line, Context ctx) throws IOException, InterruptedException {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
+
+ long id = Long.parseLong(tokens[tokenPos]);
+
+ index.set(TasteHadoopUtils.idToIndex(id));
+ idWritable.set(id);
+ ctx.write(index, idWritable);
+ }
+ }
+
+ static class IDMapReducer extends Reducer<VarIntWritable,VarLongWritable,VarIntWritable,VarLongWritable> {
+ @Override
+ protected void reduce(VarIntWritable index, Iterable<VarLongWritable> ids, Context ctx)
+ throws IOException, InterruptedException {
+ ctx.write(index, ids.iterator().next());
+ }
+ }
+
+ private Path pathToM(int iteration) {
+ return iteration == numIterations - 1 ? getOutputPath("M") : getTempPath("M-" + iteration);
+ }
+
+ private Path pathToU(int iteration) {
+ return iteration == numIterations - 1 ? getOutputPath("U") : getTempPath("U-" + iteration);
+ }
+
+ private Path pathToItemRatings() {
+ return getTempPath("itemRatings");
+ }
+
+ private Path pathToUserRatings() {
+ return getOutputPath("userRatings");
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java
new file mode 100644
index 0000000..6e7ea81
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.mahout.cf.taste.hadoop.MutableRecommendedItem;
+import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.hadoop.TopItemsQueue;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.IntObjectProcedure;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.apache.mahout.math.set.OpenIntHashSet;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * A mapper meant for multithreaded execution: the feature matrices U and M are loaded into memory as one
+ * shared (U, M) pair and recommendations are computed from them for every incoming user.
+ * Can be executed by a {@link MultithreadedSharingMapper}.
+ */
+public class PredictionMapper extends SharingMapper<IntWritable,VectorWritable,LongWritable,RecommendedItemsWritable,
+    Pair<OpenIntObjectHashMap<Vector>,OpenIntObjectHashMap<Vector>>> {
+
+  private int recommendationsPerUser;
+  private float maxRating;
+
+  private boolean usesLongIDs;
+  private OpenIntLongHashMap userIDIndex;
+  private OpenIntLongHashMap itemIDIndex;
+
+  private final LongWritable userIDWritable = new LongWritable();
+  private final RecommendedItemsWritable recommendations = new RecommendedItemsWritable();
+
+  /**
+   * Reads U and M row-wise from the paths stored in the configuration.
+   *
+   * @return pair whose first element is U and whose second element is M
+   */
+  @Override
+  Pair<OpenIntObjectHashMap<Vector>, OpenIntObjectHashMap<Vector>> createSharedInstance(Context ctx) {
+    Configuration conf = ctx.getConfiguration();
+    Path userFeaturesPath = new Path(conf.get(RecommenderJob.USER_FEATURES_PATH));
+    Path itemFeaturesPath = new Path(conf.get(RecommenderJob.ITEM_FEATURES_PATH));
+
+    // arguments are evaluated left to right, so U is still read before M
+    return new Pair<>(ALS.readMatrixByRows(userFeaturesPath, conf), ALS.readMatrixByRows(itemFeaturesPath, conf));
+  }
+
+  /**
+   * Pulls the number of recommendations, the maximum rating and the long-ID flag from the configuration;
+   * the two ID index maps are only read when long IDs are in use.
+   */
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    Configuration conf = ctx.getConfiguration();
+    recommendationsPerUser = conf.getInt(RecommenderJob.NUM_RECOMMENDATIONS,
+        RecommenderJob.DEFAULT_NUM_RECOMMENDATIONS);
+    maxRating = Float.parseFloat(conf.get(RecommenderJob.MAX_RATING));
+    usesLongIDs = conf.getBoolean(ParallelALSFactorizationJob.USES_LONG_IDS, false);
+
+    if (!usesLongIDs) {
+      return;
+    }
+    userIDIndex = TasteHadoopUtils.readIDIndexMap(conf.get(RecommenderJob.USER_INDEX_PATH), conf);
+    itemIDIndex = TasteHadoopUtils.readIDIndexMap(conf.get(RecommenderJob.ITEM_INDEX_PATH), conf);
+  }
+
+  /**
+   * Scores every item of M that the user has not rated yet with the dot product of user and item features,
+   * keeps the best {@code recommendationsPerUser} of them and writes them out keyed by the user ID.
+   * Nothing is written when no item qualifies.
+   */
+  @Override
+  protected void map(IntWritable userIndexWritable, VectorWritable ratingsWritable, Context ctx)
+      throws IOException, InterruptedException {
+
+    Pair<OpenIntObjectHashMap<Vector>, OpenIntObjectHashMap<Vector>> sharedMatrices = getSharedInstance();
+    OpenIntObjectHashMap<Vector> userFeatureMatrix = sharedMatrices.getFirst();
+    OpenIntObjectHashMap<Vector> itemFeatureMatrix = sharedMatrices.getSecond();
+
+    Vector ratings = ratingsWritable.get();
+    int userIndex = userIndexWritable.get();
+
+    // collect the indexes of all items this user already has a rating for
+    final OpenIntHashSet ratedItems = new OpenIntHashSet(ratings.getNumNondefaultElements());
+    for (Vector.Element rating : ratings.nonZeroes()) {
+      ratedItems.add(rating.index());
+    }
+
+    final TopItemsQueue bestItems = new TopItemsQueue(recommendationsPerUser);
+    final Vector userFeatures = userFeatureMatrix.get(userIndex);
+
+    itemFeatureMatrix.forEachPair(new IntObjectProcedure<Vector>() {
+      @Override
+      public boolean apply(int itemID, Vector itemFeatures) {
+        if (ratedItems.contains(itemID)) {
+          return true;
+        }
+
+        double predictedRating = userFeatures.dot(itemFeatures);
+        MutableRecommendedItem queueTop = bestItems.top();
+        if (predictedRating > queueTop.getValue()) {
+          // overwrite the queue's top entry in place and let the queue reorder itself
+          queueTop.set(itemID, (float) predictedRating);
+          bestItems.updateTop();
+        }
+        return true;
+      }
+    });
+
+    List<RecommendedItem> recommendedItems = bestItems.getTopItems();
+    if (recommendedItems.isEmpty()) {
+      return;
+    }
+
+    // cap predictions to maxRating
+    for (RecommendedItem recommended : recommendedItems) {
+      ((MutableRecommendedItem) recommended).capToMaxValue(maxRating);
+    }
+
+    if (!usesLongIDs) {
+      userIDWritable.set(userIndex);
+    } else {
+      userIDWritable.set(userIDIndex.get(userIndex));
+
+      // remap item IDs
+      for (RecommendedItem recommended : recommendedItems) {
+        long originalItemID = itemIDIndex.get((int) recommended.getItemID());
+        ((MutableRecommendedItem) recommended).setItemID(originalItemID);
+      }
+    }
+
+    recommendations.set(recommendedItems);
+    ctx.write(userIDWritable, recommendations);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java
new file mode 100644
index 0000000..679d227
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable;
+import org.apache.mahout.common.AbstractJob;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * <p>Computes the top-N recommendations per user from a decomposition of the rating matrix</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing the vectorized user ratings</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * <li>--numRecommendations (int): maximum number of recommendations per user (default: 10)</li>
+ * <li>--maxRating (float): maximum rating of an item</li>
+ * <li>--numThreads (int): threads to use per mapper, (default: 1)</li>
+ * <li>--usesLongIDs (boolean): input contains long IDs that need to be translated</li>
+ * <li>--userIDIndex (path): index for user long IDs (necessary if usesLongIDs is true)</li>
+ * <li>--itemIDIndex (path): index for item long IDs (necessary if usesLongIDs is true)</li>
+ * </ol>
+ */
+public class RecommenderJob extends AbstractJob {
+
+  // configuration keys under which the parsed options are handed to the mapper
+  static final String NUM_RECOMMENDATIONS = RecommenderJob.class.getName() + ".numRecommendations";
+  static final String USER_FEATURES_PATH = RecommenderJob.class.getName() + ".userFeatures";
+  static final String ITEM_FEATURES_PATH = RecommenderJob.class.getName() + ".itemFeatures";
+  static final String MAX_RATING = RecommenderJob.class.getName() + ".maxRating";
+  static final String USER_INDEX_PATH = RecommenderJob.class.getName() + ".userIndex";
+  static final String ITEM_INDEX_PATH = RecommenderJob.class.getName() + ".itemIndex";
+
+  static final int DEFAULT_NUM_RECOMMENDATIONS = 10;
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new RecommenderJob(), args);
+  }
+
+  /**
+   * Declares and parses the command line options, copies them into the job configuration and runs the
+   * prediction job.
+   *
+   * @param args command line arguments
+   * @return 0 if the job completed successfully, -1 if the arguments could not be parsed or the job failed
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOption("userFeatures", null, "path to the user feature matrix", true);
+    addOption("itemFeatures", null, "path to the item feature matrix", true);
+    addOption("numRecommendations", null, "number of recommendations per user",
+        String.valueOf(DEFAULT_NUM_RECOMMENDATIONS));
+    addOption("maxRating", null, "maximum rating available", true);
+    addOption("numThreads", null, "threads per mapper", String.valueOf(1));
+    addOption("usesLongIDs", null, "input contains long IDs that need to be translated");
+    addOption("userIDIndex", null, "index for user long IDs (necessary if usesLongIDs is true)");
+    // this description used to repeat the user text verbatim; the option holds the item index
+    addOption("itemIDIndex", null, "index for item long IDs (necessary if usesLongIDs is true)");
+    addOutputOption();
+
+    Map<String,List<String>> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+
+    // NOTE(review): PredictionMapper is declared with LongWritable output keys while IntWritable is passed
+    // here as the key class - confirm this mismatch is intended.
+    Job prediction = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class,
+        MultithreadedSharingMapper.class, IntWritable.class, RecommendedItemsWritable.class, TextOutputFormat.class);
+    Configuration conf = prediction.getConfiguration();
+
+    int numThreads = Integer.parseInt(getOption("numThreads"));
+
+    conf.setInt(NUM_RECOMMENDATIONS, Integer.parseInt(getOption("numRecommendations")));
+    conf.set(USER_FEATURES_PATH, getOption("userFeatures"));
+    conf.set(ITEM_FEATURES_PATH, getOption("itemFeatures"));
+    conf.set(MAX_RATING, getOption("maxRating"));
+
+    // the index paths are only forwarded when long IDs have to be translated back
+    boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs"));
+    if (usesLongIDs) {
+      conf.set(ParallelALSFactorizationJob.USES_LONG_IDS, String.valueOf(true));
+      conf.set(USER_INDEX_PATH, getOption("userIDIndex"));
+      conf.set(ITEM_INDEX_PATH, getOption("itemIDIndex"));
+    }
+
+    MultithreadedMapper.setMapperClass(prediction, PredictionMapper.class);
+    MultithreadedMapper.setNumberOfThreads(prediction, numThreads);
+
+    boolean succeeded = prediction.waitForCompletion(true);
+    return succeeded ? 0 : -1;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
new file mode 100644
index 0000000..9925807
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.mapreduce.Mapper;
+
+import java.io.IOException;
+
+/**
+ * Mapper class to be used by {@link MultithreadedSharingMapper}. A single shared instance of type {@code S}
+ * is created through {@code createSharedInstance()}, kept in a static field and handed out by
+ * {@code getSharedInstance()}; {@code reset()} clears it again.
+ *
+ * Suitable for mappers that need large, read-only in-memory data to operate.
+ *
+ * @param <K1> input key type
+ * @param <V1> input value type
+ * @param <K2> output key type
+ * @param <V2> output value type
+ * @param <S> type of the shared instance
+ */
+public abstract class SharingMapper<K1,V1,K2,V2,S> extends Mapper<K1,V1,K2,V2> {
+
+  // static, hence one slot shared by every SharingMapper subclass loaded in this JVM
+  private static Object SHARED_INSTANCE;
+
+  /**
+   * Called before the multithreaded execution
+   *
+   * @param context mapper's context
+   * @return the instance to be shared
+   * @throws IOException if the shared data cannot be created
+   */
+  abstract S createSharedInstance(Context context) throws IOException;
+
+  /**
+   * Creates the shared instance unless one is already present.
+   *
+   * @param context mapper's context, passed on to {@code createSharedInstance()}
+   */
+  final void setupSharedInstance(Context context) throws IOException {
+    if (SHARED_INSTANCE == null) {
+      SHARED_INSTANCE = createSharedInstance(context);
+    }
+  }
+
+  /**
+   * @return the shared instance, or {@code null} if none has been set up (or it has been reset)
+   */
+  @SuppressWarnings("unchecked")
+  final S getSharedInstance() {
+    // The only non-null value ever stored comes from createSharedInstance(), which returns S. Because the
+    // slot is static, this holds only while reset() is called between uses by mappers with a different S.
+    return (S) SHARED_INSTANCE;
+  }
+
+  /** Drops the shared instance so that the next {@code setupSharedInstance()} call creates a new one. */
+  static void reset() {
+    SHARED_INSTANCE = null;
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
new file mode 100644
index 0000000..2569918
--- /dev/null
+++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+import java.io.IOException;
+
+/**
+ * Solving mapper that can be safely executed using multiple threads. For every input record it calls
+ * {@code ALS.solveExplicit} with the record's ratings and the shared matrix, and writes the result under
+ * the unchanged input key.
+ */
+public class SolveExplicitFeedbackMapper
+    extends SharingMapper<IntWritable,VectorWritable,IntWritable,VectorWritable,OpenIntObjectHashMap<Vector>> {
+
+  private double lambda;
+  private int numFeatures;
+  // output value holder, overwritten on every map() call
+  private final VectorWritable uiOrmj = new VectorWritable();
+
+  /**
+   * Reads the shared matrix from the distributed cache, using the entity count stored in the configuration.
+   */
+  @Override
+  OpenIntObjectHashMap<Vector> createSharedInstance(Context ctx) throws IOException {
+    Configuration conf = ctx.getConfiguration();
+    int numEntities = Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES));
+    return ALS.readMatrixByRowsFromDistributedCache(numEntities, conf);
+  }
+
+  /**
+   * Reads lambda and the number of features from the configuration.
+   *
+   * @throws IllegalArgumentException if the configured number of features is missing or not positive
+   */
+  @Override
+  protected void setup(Context ctx) throws IOException, InterruptedException {
+    // takes the parameterized inherited Context (as map() does) rather than the raw Mapper.Context
+    Configuration conf = ctx.getConfiguration();
+    lambda = Double.parseDouble(conf.get(ParallelALSFactorizationJob.LAMBDA));
+    numFeatures = conf.getInt(ParallelALSFactorizationJob.NUM_FEATURES, -1);
+    Preconditions.checkArgument(numFeatures > 0, "numFeatures must be greater then 0!");
+  }
+
+  /**
+   * Solves for the feature vector belonging to {@code userOrItemID} and emits it.
+   */
+  @Override
+  protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
+      throws IOException, InterruptedException {
+    OpenIntObjectHashMap<Vector> uOrM = getSharedInstance();
+    uiOrmj.set(ALS.solveExplicit(ratingsWritable, uOrM, lambda, numFeatures));
+    ctx.write(userOrItemID, uiOrmj);
+  }
+
+}