You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ss...@apache.org on 2011/03/23 23:33:58 UTC
svn commit: r1084789 - in /mahout/trunk:
core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/
core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ examples/bin/
src/conf/
Author: ssc
Date: Wed Mar 23 22:33:57 2011
New Revision: 1084789
URL: http://svn.apache.org/viewvc?rev=1084789&view=rev
Log:
MAHOUT-542 MapReduce implementation of ALS-WR
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritable.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/IndexedVarIntWritable.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJob.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/TaggedVarIntWritable.java
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/VectorWithIndexWritable.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritableTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJobTest.java
mahout/trunk/examples/bin/factorize-movielens-1M.sh
mahout/trunk/src/conf/evaluateFactorization.props
mahout/trunk/src/conf/evaluateFactorizationParallel.props
mahout/trunk/src/conf/parallelALS.props
mahout/trunk/src/conf/predictFromFactorization.props
mahout/trunk/src/conf/splitDataset.props
Modified:
mahout/trunk/src/conf/driver.classes.props
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritable.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritable.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritable.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,135 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Hadoop {@link Writable} bundling an ID index with an optional rating and an optional feature vector.
+ * A missing rating or vector is held as {@code null}; on the wire each optional part is preceded by a
+ * boolean presence flag.
+ */
+public class FeatureVectorWithRatingWritable implements Writable, Cloneable {
+
+  private int idIndex;
+  /* null if no rating is attached */
+  private Float rating;
+  /* null if no feature vector is attached */
+  private Vector vector;
+
+  /** Creates an instance that carries neither a rating nor a feature vector. */
+  public FeatureVectorWithRatingWritable() {
+  }
+
+  /** Creates an instance carrying both parts; either of {@code rating} and {@code featureVector} may be null. */
+  public FeatureVectorWithRatingWritable(int idIndex, Float rating, Vector featureVector) {
+    this.idIndex = idIndex;
+    this.rating = rating;
+    this.vector = featureVector;
+  }
+
+  /** Creates an instance carrying only a feature vector, the rating stays unset. */
+  public FeatureVectorWithRatingWritable(int idIndex, Vector featureVector) {
+    this.idIndex = idIndex;
+    this.vector = featureVector;
+  }
+
+  /** Creates an instance carrying only a rating, the feature vector stays unset. */
+  public FeatureVectorWithRatingWritable(int idIndex, float rating) {
+    this.idIndex = idIndex;
+    this.rating = rating;
+  }
+
+  /** @return true if a feature vector is attached */
+  public boolean containsFeatureVector() {
+    return vector != null;
+  }
+
+  /** @return true if a rating is attached */
+  public boolean containsRating() {
+    return rating != null;
+  }
+
+  /**
+   * Serializes in this order: the ID index as unsigned varint, a presence flag followed by the rating
+   * (only if present), a presence flag followed by the vector (only if present).
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    Varint.writeUnsignedVarInt(idIndex, out);
+    boolean hasRating = containsRating();
+    out.writeBoolean(hasRating);
+    if (hasRating) {
+      out.writeFloat(rating);
+    }
+    boolean hasFeatureVector = containsFeatureVector();
+    out.writeBoolean(hasFeatureVector);
+    if (hasFeatureVector) {
+      VectorWritable.writeVector(out, vector);
+    }
+  }
+
+  /** Reads the format produced by {@link #write(DataOutput)}, discarding any previously held state. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    /* reset the optional parts first, so nothing from a previous record survives */
+    rating = null;
+    vector = null;
+    idIndex = Varint.readUnsignedVarInt(in);
+    if (in.readBoolean()) {
+      rating = in.readFloat();
+    }
+    if (in.readBoolean()) {
+      VectorWritable vw = new VectorWritable();
+      vw.readFields(in);
+      vector = vw.get();
+    }
+  }
+
+  public int getIDIndex() {
+    return idIndex;
+  }
+
+  /** @return the rating or null if none is attached */
+  public Float getRating() {
+    return rating;
+  }
+
+  /** @return the feature vector or null if none is attached */
+  public Vector getFeatureVector() {
+    return vector;
+  }
+
+  /** Two instances are equal if ID index, rating and feature vector are all equal (null matches only null). */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof FeatureVectorWithRatingWritable)) {
+      return false;
+    }
+    FeatureVectorWithRatingWritable other = (FeatureVectorWithRatingWritable) o;
+    if (idIndex != other.idIndex) {
+      return false;
+    }
+    if (rating != null ? !rating.equals(other.rating) : other.rating != null) {
+      return false;
+    }
+    if (vector != null ? !vector.equals(other.vector) : other.vector != null) {
+      return false;
+    }
+    return true;
+  }
+
+  @Override
+  public int hashCode() {
+    int result = 31 * idIndex + (rating != null ? rating.hashCode() : 0);
+    result = 31 * result + (vector != null ? vector.hashCode() : 0);
+    return result;
+  }
+
+  /**
+   * Creates a shallow copy: the feature vector instance is shared with the original, not duplicated.
+   * The copy is obtained via {@code super.clone()} so that a subclass receives an instance of its own
+   * runtime type.
+   */
+  @Override
+  protected FeatureVectorWithRatingWritable clone() {
+    try {
+      return (FeatureVectorWithRatingWritable) super.clone();
+    } catch (CloneNotSupportedException e) {
+      /* cannot happen as this class implements Cloneable */
+      throw new IllegalStateException(e);
+    }
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/IndexedVarIntWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/IndexedVarIntWritable.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/IndexedVarIntWritable.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/IndexedVarIntWritable.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,115 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.mahout.math.Varint;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.io.Serializable;
+
+/**
+ * Key consisting of a {@code value} and an {@code index}, both written as signed varints.
+ * {@link #compareTo(IndexedVarIntWritable)}, {@link #equals(Object)} and {@link #hashCode()} look at the
+ * value only; the {@link SecondarySortComparator} registered for this class additionally orders by index.
+ */
+public class IndexedVarIntWritable implements WritableComparable<IndexedVarIntWritable> {
+
+  static {
+    WritableComparator.define(IndexedVarIntWritable.class, new SecondarySortComparator());
+  }
+
+  private int value;
+  private int index;
+
+  public IndexedVarIntWritable() {
+  }
+
+  public IndexedVarIntWritable(int value, int index) {
+    this.value = value;
+    this.index = index;
+  }
+
+  public int getValue() {
+    return value;
+  }
+
+  /** Orders by value only, the index is ignored. */
+  @Override
+  public int compareTo(IndexedVarIntWritable other) {
+    if (value < other.value) {
+      return -1;
+    }
+    return value > other.value ? 1 : 0;
+  }
+
+  /** Writes the value first, then the index. */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    Varint.writeSignedVarInt(value, out);
+    Varint.writeSignedVarInt(index, out);
+  }
+
+  /** Reads the value first, then the index. */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    value = Varint.readSignedVarInt(in);
+    index = Varint.readSignedVarInt(in);
+  }
+
+  /** Equality is decided by the value alone. */
+  @Override
+  public boolean equals(Object o) {
+    return o instanceof IndexedVarIntWritable && value == ((IndexedVarIntWritable) o).value;
+  }
+
+  @Override
+  public int hashCode() {
+    return value;
+  }
+
+  /** Orders keys by value and breaks ties between equal values by index. */
+  public static class SecondarySortComparator extends WritableComparator implements Serializable {
+
+    protected SecondarySortComparator() {
+      super(IndexedVarIntWritable.class, true);
+    }
+
+    @Override
+    public int compare(WritableComparable a, WritableComparable b) {
+      IndexedVarIntWritable left = (IndexedVarIntWritable) a;
+      IndexedVarIntWritable right = (IndexedVarIntWritable) b;
+
+      int byValue = compare(left.value, right.value);
+      return byValue != 0 ? byValue : compare(left.index, right.index);
+    }
+
+    /** @return -1, 0 or 1 if {@code a} is less than, equal to or greater than {@code b} */
+    protected static int compare(int a, int b) {
+      if (a < b) {
+        return -1;
+      }
+      return a == b ? 0 : 1;
+    }
+  }
+
+  /** Delegates to {@code compareTo()}, so keys that share the same value compare as equal. */
+  public static class GroupingComparator extends WritableComparator implements Serializable {
+
+    protected GroupingComparator() {
+      super(IndexedVarIntWritable.class, true);
+    }
+
+    @Override
+    public int compare(WritableComparable a, WritableComparable b) {
+      return a.compareTo(b);
+    }
+  }
+
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,322 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.partition.HashPartitioner;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.math.*;
+import org.apache.mahout.math.als.AlternateLeastSquaresSolver;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+/**
+ * <p>MapReduce implementation of the factorization algorithm described in "Large-scale Parallel Collaborative Filtering for the Netflix Prize"
+ * available at http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf.</p>
+ *
+ * <p>Implements a parallel algorithm that uses "Alternating-Least-Squares with Weighted-λ-Regularization" to factorize the
+ * preference-matrix </p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the dataset</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--lambda (double): regularization parameter to avoid overfitting</li>
+ * <li>--numFeatures (int): number of features to use for decomposition </li>
+ * <li>--numIterations (int): number of iterations to run</li>
+ * </ol>
+ */
+public class ParallelALSFactorizationJob extends AbstractJob {
+
+  /* configuration keys used to hand parameters over to the mappers and reducers */
+  static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
+  static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
+  static final String MAP_TRANSPOSED = ParallelALSFactorizationJob.class.getName() + ".mapTransposed";
+
+  /* names of the two half-steps of an iteration, they become part of the temp directory names */
+  static final String STEP_ONE = "fixMcomputeU";
+  static final String STEP_TWO = "fixUcomputeM";
+
+  private String tempDir;
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new ParallelALSFactorizationJob(), args);
+  }
+
+  /**
+   * Runs the whole pipeline: converts the input into item-keyed and user-keyed ratings, initializes M,
+   * performs the requested number of iterations and finally writes U and M as matrices to the output path.
+   *
+   * @return 0 on success, -1 if the arguments could not be parsed or if any of the jobs failed
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addInputOption();
+    addOutputOption();
+    addOption("lambda", "l", "regularization parameter to avoid overfitting", true);
+    addOption("numFeatures", "f", "number of features to use for decomposition", true);
+    addOption("numIterations", "i", "number of iterations to run", true);
+
+    Map<String,String> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+
+    int numFeatures = Integer.parseInt(parsedArgs.get("--numFeatures"));
+    int numIterations = Integer.parseInt(parsedArgs.get("--numIterations"));
+    double lambda = Double.parseDouble(parsedArgs.get("--lambda"));
+    tempDir = parsedArgs.get("--tempDir");
+
+    /* every job reads the output of an earlier one, so stop at the first failure */
+    Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(),
+        TextInputFormat.class, PrefsToRatingsMapper.class, VarIntWritable.class,
+        FeatureVectorWithRatingWritable.class, Reducer.class, VarIntWritable.class,
+        FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
+    if (!itemRatings.waitForCompletion(true)) {
+      return -1;
+    }
+
+    Job userRatings = prepareJob(getInputPath(), pathToUserRatings(),
+        TextInputFormat.class, PrefsToRatingsMapper.class, VarIntWritable.class,
+        FeatureVectorWithRatingWritable.class, Reducer.class, VarIntWritable.class,
+        FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
+    userRatings.getConfiguration().setBoolean(MAP_TRANSPOSED, true);
+    if (!userRatings.waitForCompletion(true)) {
+      return -1;
+    }
+
+    /* the initial M is stored as "iteration -1", which is what the first iteration reads */
+    Job initializeM = prepareJob(getInputPath(), pathToM(-1), TextInputFormat.class, ItemIDRatingMapper.class,
+        VarLongWritable.class, FloatWritable.class, InitializeMReducer.class, VarIntWritable.class,
+        FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
+    initializeM.getConfiguration().setInt(NUM_FEATURES, numFeatures);
+    if (!initializeM.waitForCompletion(true)) {
+      return -1;
+    }
+
+    for (int n = 0; n < numIterations; n++) {
+      if (!iterate(n, numFeatures, lambda)) {
+        return -1;
+      }
+    }
+
+    Job uAsMatrix = prepareJob(pathToU(numIterations - 1), new Path(getOutputPath(), "U"),
+        SequenceFileInputFormat.class, ToMatrixMapper.class, IntWritable.class, VectorWritable.class, Reducer.class,
+        IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    if (!uAsMatrix.waitForCompletion(true)) {
+      return -1;
+    }
+
+    Job mAsMatrix = prepareJob(pathToM(numIterations - 1), new Path(getOutputPath(), "M"),
+        SequenceFileInputFormat.class, ToMatrixMapper.class, IntWritable.class, VectorWritable.class, Reducer.class,
+        IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class);
+    if (!mAsMatrix.waitForCompletion(true)) {
+      return -1;
+    }
+
+    return 0;
+  }
+
+  /** Re-keys each record as {@link IntWritable} and emits only its feature vector as {@link VectorWritable}. */
+  static class ToMatrixMapper
+      extends Mapper<VarIntWritable,FeatureVectorWithRatingWritable,IntWritable,VectorWritable> {
+    @Override
+    protected void map(VarIntWritable key, FeatureVectorWithRatingWritable value, Context ctx)
+        throws IOException, InterruptedException {
+      ctx.write(new IntWritable(key.get()), new VectorWritable(value.getFeatureVector()));
+    }
+  }
+
+  /**
+   * Performs one iteration: computes U of this iteration from M of the previous one, then M of this
+   * iteration from the U just computed.
+   *
+   * @return true if all jobs of the iteration succeeded
+   */
+  private boolean iterate(int currentIteration, int numFeatures, double lambda)
+      throws IOException, ClassNotFoundException, InterruptedException {
+    /* fix M, compute U */
+    if (!joinAndSolve(pathToM(currentIteration - 1), pathToItemRatings(), pathToU(currentIteration), numFeatures,
+        lambda, currentIteration, STEP_ONE)) {
+      return false;
+    }
+    /* fix U, compute M */
+    return joinAndSolve(pathToU(currentIteration), pathToUserRatings(), pathToM(currentIteration), numFeatures,
+        lambda, currentIteration, STEP_TWO);
+  }
+
+  /**
+   * Runs two jobs: the first joins the feature vectors with the ratings, the second feeds the joined data
+   * to the solver and writes the newly computed feature vectors to {@code outputPath}.
+   *
+   * @return true if both jobs succeeded, the second job is not started if the first one failed
+   */
+  private boolean joinAndSolve(Path featureMatrix, Path ratingMatrix, Path outputPath, int numFeatures,
+      double lambda, int currentIteration, String step)
+      throws IOException, ClassNotFoundException, InterruptedException {
+
+    /* both inputs are handed to the join job as one comma-separated path */
+    Path joinPath = new Path(ratingMatrix.toString() + "," + featureMatrix.toString());
+    Path featureVectorWithRatingPath = joinAndSolvePath(currentIteration, step);
+
+    Job joinToFeatureVectorWithRating = prepareJob(joinPath, featureVectorWithRatingPath,
+        SequenceFileInputFormat.class, Mapper.class, VarIntWritable.class, FeatureVectorWithRatingWritable.class,
+        JoinFeatureVectorAndRatingsReducer.class, IndexedVarIntWritable.class,
+        FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
+    if (!joinToFeatureVectorWithRating.waitForCompletion(true)) {
+      return false;
+    }
+
+    Job solve = prepareJob(featureVectorWithRatingPath, outputPath, SequenceFileInputFormat.class, Mapper.class,
+        IndexedVarIntWritable.class, FeatureVectorWithRatingWritable.class, SolvingReducer.class,
+        VarIntWritable.class, FeatureVectorWithRatingWritable.class, SequenceFileOutputFormat.class);
+    Configuration solveConf = solve.getConfiguration();
+    solve.setPartitionerClass(HashPartitioner.class);
+    /* group by the value of the key only, so a single reduce call sees all records for that value */
+    solve.setGroupingComparatorClass(IndexedVarIntWritable.GroupingComparator.class);
+    solveConf.setInt(NUM_FEATURES, numFeatures);
+    solveConf.set(LAMBDA, String.valueOf(lambda));
+    return solve.waitForCompletion(true);
+  }
+
+  /**
+   * Parses a preference line and emits the rating keyed by the index of the second token (first token
+   * if {@link #MAP_TRANSPOSED} is set); the index of the other ID travels in the value.
+   */
+  static class PrefsToRatingsMapper
+      extends Mapper<LongWritable,Text,VarIntWritable,FeatureVectorWithRatingWritable> {
+
+    private boolean transpose;
+
+    @Override
+    protected void setup(Context ctx) throws IOException, InterruptedException {
+      transpose = ctx.getConfiguration().getBoolean(MAP_TRANSPOSED, false);
+    }
+
+    @Override
+    protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
+      String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
+      int keyIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[transpose ? 0 : 1]));
+      int valueIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[transpose ? 1 : 0]));
+      float rating = Float.parseFloat(tokens[2]);
+      ctx.write(new VarIntWritable(keyIDIndex), new FeatureVectorWithRatingWritable(valueIDIndex, rating));
+    }
+  }
+
+  /**
+   * Joins the feature vector of an ID with all ratings that arrive under the same ID. For each rating it
+   * emits a key built from the rating's ID index and the joined ID, with the ID, the rating and the feature
+   * vector as value.
+   */
+  static class JoinFeatureVectorAndRatingsReducer
+      extends Reducer<VarIntWritable,FeatureVectorWithRatingWritable,IndexedVarIntWritable,FeatureVectorWithRatingWritable> {
+
+    /**
+     * @throws IllegalStateException if no feature vector or no rating was found for {@code id}
+     */
+    @Override
+    protected void reduce(VarIntWritable id, Iterable<FeatureVectorWithRatingWritable> values, Context ctx)
+        throws IOException, InterruptedException {
+      Vector featureVector = null;
+      Map<Integer,Float> ratings = new HashMap<Integer,Float>();
+      for (FeatureVectorWithRatingWritable value : values) {
+        /* values without a feature vector are the ratings */
+        if (value.getFeatureVector() == null) {
+          ratings.put(value.getIDIndex(), new Float(value.getRating()));
+        } else {
+          featureVector = value.getFeatureVector().clone();
+        }
+      }
+
+      if (featureVector == null || ratings.isEmpty()) {
+        throw new IllegalStateException("Unable to join data for " + id);
+      }
+      for (Map.Entry<Integer,Float> rating : ratings.entrySet()) {
+        ctx.write(new IndexedVarIntWritable(rating.getKey(), id.get()),
+            new FeatureVectorWithRatingWritable(id.get(), rating.getValue(), featureVector));
+      }
+    }
+  }
+
+  /**
+   * Collects the ratings and feature vectors of one group, passes them to the
+   * {@link AlternateLeastSquaresSolver} and emits the resulting vector under the value of the group's key.
+   */
+  static class SolvingReducer
+      extends Reducer<IndexedVarIntWritable,FeatureVectorWithRatingWritable,VarIntWritable,FeatureVectorWithRatingWritable> {
+
+    private int numFeatures;
+    private double lambda;
+    private AlternateLeastSquaresSolver solver;
+
+    /**
+     * @throws IllegalStateException if {@link #NUM_FEATURES} is missing from the configuration or less than 1
+     */
+    @Override
+    protected void setup(Context ctx) throws IOException, InterruptedException {
+      super.setup(ctx);
+      lambda = Double.parseDouble(ctx.getConfiguration().get(LAMBDA));
+      numFeatures = ctx.getConfiguration().getInt(NUM_FEATURES, -1);
+      if (numFeatures < 1) {
+        throw new IllegalStateException("numFeatures was not set correctly!");
+      }
+      solver = new AlternateLeastSquaresSolver();
+    }
+
+    @Override
+    protected void reduce(IndexedVarIntWritable key, Iterable<FeatureVectorWithRatingWritable> values, Context ctx)
+        throws IOException, InterruptedException {
+      List<Vector> UorMColumns = new ArrayList<Vector>();
+      Vector ratingVector = new RandomAccessSparseVector(Integer.MAX_VALUE);
+      /* the n-th rating belongs to the n-th feature vector, both are stored in order of arrival */
+      int n = 0;
+      for (FeatureVectorWithRatingWritable value : values) {
+        ratingVector.setQuick(n++, value.getRating());
+        UorMColumns.add(value.getFeatureVector());
+      }
+      Vector uiOrmj = solver.solve(UorMColumns, new SequentialAccessSparseVector(ratingVector), lambda, numFeatures);
+      ctx.write(new VarIntWritable(key.getValue()), new FeatureVectorWithRatingWritable(key.getValue(), uiOrmj));
+    }
+  }
+
+  /** Emits the rating (third token) of a preference line keyed by the ID in the second token. */
+  static class ItemIDRatingMapper extends Mapper<LongWritable,Text,VarLongWritable,FloatWritable> {
+    @Override
+    protected void map(LongWritable key, Text value, Context ctx) throws IOException, InterruptedException {
+      String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
+      ctx.write(new VarLongWritable(Long.parseLong(tokens[1])), new FloatWritable(Float.parseFloat(tokens[2])));
+    }
+  }
+
+  /**
+   * Creates the initial column of M for an item: the first entry is the average of the item's ratings,
+   * all remaining entries are random numbers.
+   */
+  static class InitializeMReducer
+      extends Reducer<VarLongWritable,FloatWritable,VarIntWritable,FeatureVectorWithRatingWritable> {
+
+    private int numFeatures;
+    private static final Random random = new Random();
+
+    /**
+     * @throws IllegalStateException if {@link #NUM_FEATURES} is missing from the configuration or less than 1
+     */
+    @Override
+    protected void setup(Context ctx) throws IOException, InterruptedException {
+      super.setup(ctx);
+      numFeatures = ctx.getConfiguration().getInt(NUM_FEATURES, -1);
+      if (numFeatures < 1) {
+        throw new IllegalStateException("numFeatures was not set correctly!");
+      }
+    }
+
+    @Override
+    protected void reduce(VarLongWritable itemID, Iterable<FloatWritable> ratings, Context ctx)
+        throws IOException, InterruptedException {
+
+      RunningAverage averageRating = new FullRunningAverage();
+      for (FloatWritable rating : ratings) {
+        averageRating.addDatum(rating.get());
+      }
+
+      int itemIDIndex = TasteHadoopUtils.idToIndex(itemID.get());
+      Vector columnOfM = new DenseVector(numFeatures);
+
+      columnOfM.setQuick(0, averageRating.getAverage());
+      for (int n = 1; n < numFeatures; n++) {
+        columnOfM.setQuick(n, random.nextDouble());
+      }
+
+      ctx.write(new VarIntWritable(itemIDIndex), new FeatureVectorWithRatingWritable(itemIDIndex, columnOfM));
+    }
+  }
+
+  /* all intermediate data lives below the temp directory */
+
+  private Path joinAndSolvePath(int currentIteration, String step) {
+    return new Path(tempDir, "joinAndSolve-" + currentIteration + "-" + step);
+  }
+
+  private Path pathToM(int iteration) {
+    return new Path(tempDir, "M-" + iteration);
+  }
+
+  private Path pathToU(int iteration) {
+    return new Path(tempDir, "U-" + iteration);
+  }
+
+  private Path pathToItemRatings() {
+    return new Path(tempDir, "itemsAsFeatureWithRatingWritable");
+  }
+
+  private Path pathToUserRatings() {
+    return new Path(tempDir, "usersAsFeatureWithRatingWritable");
+  }
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJob.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJob.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJob.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,183 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.mapreduce.lib.partition.HashPartitioner;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * <p>Compute predictions for user,item pairs using an existing matrix factorization</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--output (path): path where output should go</li>
+ * <li>--pairs (path): path containing the test ratings, each line must be userID,itemID</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * </ol>
+ */
+public class PredictionJob extends AbstractJob {
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new PredictionJob(), args);
+  }
+
+  /**
+   * Converts the pairs and both feature matrices to a common key/value format, joins the pairs with the
+   * item features, joins that result with the user features and writes one line per predicted rating.
+   *
+   * @return 0 on success, -1 if the arguments could not be parsed or if any of the jobs failed
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+
+    addOption("pairs", "p", "path containing the test ratings, each line must be: userID,itemID", true);
+    addOption("userFeatures", "u", "path to the user feature matrix", true);
+    addOption("itemFeatures", "i", "path to the item feature matrix", true);
+    addOutputOption();
+
+    Map<String,String> parsedArgs = parseArguments(args);
+    if (parsedArgs == null) {
+      return -1;
+    }
+
+    Path pairs = new Path(parsedArgs.get("--pairs"));
+    Path userFeatures = new Path(parsedArgs.get("--userFeatures"));
+    Path itemFeatures = new Path(parsedArgs.get("--itemFeatures"));
+
+    Path tempDirPath = new Path(parsedArgs.get("--tempDir"));
+
+    Path convertedPairs = new Path(tempDirPath, "convertedPairs");
+    Path convertedUserFeatures = new Path(tempDirPath, "convertedUserFeatures");
+    Path convertedItemFeatures = new Path(tempDirPath, "convertedItemFeatures");
+
+    Path pairsJoinedWithItemFeatures = new Path(tempDirPath, "pairsJoinedWithItemFeatures");
+
+    /* the join jobs read what the conversion jobs write, so stop at the first failure */
+    Job convertPairs = prepareJob(pairs, convertedPairs, TextInputFormat.class, PairsMapper.class,
+        TaggedVarIntWritable.class, VectorWithIndexWritable.class, Reducer.class, TaggedVarIntWritable.class,
+        VectorWithIndexWritable.class, SequenceFileOutputFormat.class);
+    if (!convertPairs.waitForCompletion(true)) {
+      return -1;
+    }
+
+    Job convertUserFeatures = prepareJob(userFeatures, convertedUserFeatures, SequenceFileInputFormat.class,
+        FeaturesMapper.class, TaggedVarIntWritable.class, VectorWithIndexWritable.class, Reducer.class,
+        TaggedVarIntWritable.class, VectorWithIndexWritable.class, SequenceFileOutputFormat.class);
+    if (!convertUserFeatures.waitForCompletion(true)) {
+      return -1;
+    }
+
+    Job convertItemFeatures = prepareJob(itemFeatures, convertedItemFeatures, SequenceFileInputFormat.class,
+        FeaturesMapper.class, TaggedVarIntWritable.class, VectorWithIndexWritable.class, Reducer.class,
+        TaggedVarIntWritable.class, VectorWithIndexWritable.class, SequenceFileOutputFormat.class);
+    if (!convertItemFeatures.waitForCompletion(true)) {
+      return -1;
+    }
+
+    /* both inputs of a join are handed over as one comma-separated path */
+    Job joinPairsWithItemFeatures = prepareJob(new Path(convertedPairs + "," + convertedItemFeatures),
+        pairsJoinedWithItemFeatures, SequenceFileInputFormat.class, Mapper.class, TaggedVarIntWritable.class,
+        VectorWithIndexWritable.class, JoinProbesWithItemFeaturesReducer.class, TaggedVarIntWritable.class,
+        VectorWithIndexWritable.class, SequenceFileOutputFormat.class);
+    joinPairsWithItemFeatures.setPartitionerClass(HashPartitioner.class);
+    joinPairsWithItemFeatures.setGroupingComparatorClass(TaggedVarIntWritable.GroupingComparator.class);
+    if (!joinPairsWithItemFeatures.waitForCompletion(true)) {
+      return -1;
+    }
+
+    Job predictRatings = prepareJob(new Path(pairsJoinedWithItemFeatures + "," + convertedUserFeatures),
+        getOutputPath(), SequenceFileInputFormat.class, Mapper.class, TaggedVarIntWritable.class,
+        VectorWithIndexWritable.class, PredictRatingReducer.class, Text.class, NullWritable.class,
+        TextOutputFormat.class);
+    predictRatings.setPartitionerClass(HashPartitioner.class);
+    predictRatings.setGroupingComparatorClass(TaggedVarIntWritable.GroupingComparator.class);
+    if (!predictRatings.waitForCompletion(true)) {
+      return -1;
+    }
+
+    return 0;
+  }
+
+  /** Emits each pair keyed by the (untagged) item index with the user index as value. */
+  public static class PairsMapper extends Mapper<LongWritable,Text,TaggedVarIntWritable,VectorWithIndexWritable> {
+    @Override
+    protected void map(LongWritable key, Text value, Context ctx) throws IOException, InterruptedException {
+      String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
+      int userIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[0]));
+      int itemIDIndex = TasteHadoopUtils.idToIndex(Long.parseLong(tokens[1]));
+      ctx.write(new TaggedVarIntWritable(itemIDIndex, false), new VectorWithIndexWritable(userIDIndex));
+    }
+  }
+
+  /** Emits each feature vector under a tagged key carrying the ID of the matrix row. */
+  public static class FeaturesMapper
+      extends Mapper<IntWritable,VectorWritable,TaggedVarIntWritable,VectorWithIndexWritable> {
+    @Override
+    protected void map(IntWritable id, VectorWritable features, Context ctx)
+        throws IOException, InterruptedException {
+      ctx.write(new TaggedVarIntWritable(id.get(), true), new VectorWithIndexWritable(features.get()));
+    }
+  }
+
+  /**
+   * Attaches the feature vector of an item to every pair that references the item. The output is keyed by
+   * the (untagged) user index and carries the item index together with the item features.
+   */
+  public static class JoinProbesWithItemFeaturesReducer
+      extends Reducer<TaggedVarIntWritable,VectorWithIndexWritable,TaggedVarIntWritable,VectorWithIndexWritable> {
+    @Override
+    protected void reduce(TaggedVarIntWritable key, Iterable<VectorWithIndexWritable> values, Context ctx)
+        throws IOException, InterruptedException {
+      int itemIDIndex = key.get();
+      Vector itemFeatures = null;
+      /* the first value of a group has to carry the feature vector, otherwise the whole group is dropped */
+      for (VectorWithIndexWritable vectorWithIndexWritable : values) {
+        if (itemFeatures == null && vectorWithIndexWritable.getVector() != null) {
+          itemFeatures = vectorWithIndexWritable.getVector();
+        } else if (itemFeatures == null && vectorWithIndexWritable.getVector() == null) {
+          /* no feature vector is found for that item */
+          return;
+        } else {
+          int userIDIndex = vectorWithIndexWritable.getIDIndex();
+          ctx.write(new TaggedVarIntWritable(userIDIndex, false),
+              new VectorWithIndexWritable(itemIDIndex, itemFeatures));
+        }
+      }
+    }
+  }
+
+  /**
+   * Computes the predictions for a user as dot product of the user's feature vector with the feature vector
+   * of each joined item and writes them as lines of the form {@code userIDIndex,itemIDIndex,prediction}.
+   */
+  public static class PredictRatingReducer
+      extends Reducer<TaggedVarIntWritable,VectorWithIndexWritable,Text,NullWritable> {
+    @Override
+    protected void reduce(TaggedVarIntWritable key, Iterable<VectorWithIndexWritable> values, Context ctx)
+        throws IOException, InterruptedException {
+      Vector userFeatures = null;
+      int userIDIndex = key.get();
+      /* the first value with a vector is taken as the user's features, the group is dropped if it has none */
+      for (VectorWithIndexWritable vectorWithIndexWritable : values) {
+        if (userFeatures == null && vectorWithIndexWritable.getVector() != null) {
+          userFeatures = vectorWithIndexWritable.getVector();
+        } else if (userFeatures == null && vectorWithIndexWritable.getVector() == null) {
+          /* no feature vector is found for that user */
+          return;
+        } else {
+          int itemIDIndex = vectorWithIndexWritable.getIDIndex();
+          Vector itemFeatures = vectorWithIndexWritable.getVector();
+          double estimatedPrediction = userFeatures.dot(itemFeatures);
+          ctx.write(new Text(userIDIndex + "," + itemIDIndex + "," + estimatedPrediction), NullWritable.get());
+        }
+      }
+    }
+  }
+
+}
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/TaggedVarIntWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/TaggedVarIntWritable.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/TaggedVarIntWritable.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/TaggedVarIntWritable.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.mahout.math.Varint;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.io.Serializable;
+
+ /**
+  * An int key that additionally carries a boolean tag.
+  *
+  * {@link #equals(Object)}, {@link #hashCode()} and {@link #compareTo(TaggedVarIntWritable)} look at the
+  * int value only. The tag is consulted solely by {@link SecondarySortComparator}, which is registered
+  * as the comparator for this class and places tagged keys before untagged keys of the same value.
+  */
+ public class TaggedVarIntWritable implements WritableComparable<TaggedVarIntWritable> {
+
+   static {
+     WritableComparator.define(TaggedVarIntWritable.class, new SecondarySortComparator());
+   }
+
+   private int value;
+   private boolean tagged;
+
+   /** No-arg constructor; fields are filled by {@link #readFields(DataInput)}. */
+   public TaggedVarIntWritable() {
+   }
+
+   public TaggedVarIntWritable(int value, boolean tagged) {
+     this.value = value;
+     this.tagged = tagged;
+   }
+
+   /** @return the int value (the tag is not exposed) */
+   public int get() {
+     return value;
+   }
+
+   /** Orders by the int value only, ignoring the tag. */
+   @Override
+   public int compareTo(TaggedVarIntWritable other) {
+     if (value < other.value) {
+       return -1;
+     }
+     if (value > other.value) {
+       return 1;
+     }
+     return 0;
+   }
+
+   /** Writes the tag as a boolean followed by the value as a signed varint. */
+   @Override
+   public void write(DataOutput out) throws IOException {
+     out.writeBoolean(tagged);
+     Varint.writeSignedVarInt(value, out);
+   }
+
+   /** Reads the fields in the order {@link #write(DataOutput)} emitted them. */
+   @Override
+   public void readFields(DataInput in) throws IOException {
+     tagged = in.readBoolean();
+     value = Varint.readSignedVarInt(in);
+   }
+
+   @Override
+   public int hashCode() {
+     return value;
+   }
+
+   /** Two instances are equal when their int values match; the tag does not take part. */
+   @Override
+   public boolean equals(Object o) {
+     if (!(o instanceof TaggedVarIntWritable)) {
+       return false;
+     }
+     return value == ((TaggedVarIntWritable) o).value;
+   }
+
+   /** Sorts by value first; among equal values a tagged key precedes an untagged one. */
+   public static class SecondarySortComparator extends WritableComparator implements Serializable {
+
+     protected SecondarySortComparator() {
+       super(TaggedVarIntWritable.class, true);
+     }
+
+     @Override
+     public int compare(WritableComparable a, WritableComparable b) {
+       TaggedVarIntWritable left = (TaggedVarIntWritable) a;
+       TaggedVarIntWritable right = (TaggedVarIntWritable) b;
+
+       int byValue = compare(left.value, right.value);
+       if (byValue != 0 || left.tagged == right.tagged) {
+         return byValue;
+       }
+       /* same value, different tags: the tagged key goes first */
+       return left.tagged ? -1 : 1;
+     }
+
+     protected static int compare(int a, int b) {
+       if (a < b) {
+         return -1;
+       }
+       return (a > b) ? 1 : 0;
+     }
+   }
+
+   /** Groups keys by delegating to the key's own {@code compareTo}, i.e. by value only. */
+   public static class GroupingComparator extends WritableComparator implements Serializable {
+
+     protected GroupingComparator() {
+       super(TaggedVarIntWritable.class, true);
+     }
+
+     @Override
+     public int compare(WritableComparable a, WritableComparable b) {
+       int order = a.compareTo(b);
+       return order;
+     }
+   }
+ }
Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/VectorWithIndexWritable.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/VectorWithIndexWritable.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/VectorWithIndexWritable.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/hadoop/als/VectorWithIndexWritable.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+ /**
+  * A Writable holding an optional {@link Vector} and an optional int index.
+  *
+  * On the wire each part is preceded by a boolean presence flag: first the vector
+  * (serialized through {@link VectorWritable}), then the index (signed varint).
+  */
+ public class VectorWithIndexWritable implements Writable {
+
+   private Vector vector;
+   private Integer idIndex;
+
+   /** No-arg constructor; both parts are absent until {@link #readFields(DataInput)} is called. */
+   public VectorWithIndexWritable() {
+   }
+
+   /** Creates an instance with a vector and no index. */
+   public VectorWithIndexWritable(Vector vector) {
+     this(null, vector);
+   }
+
+   /** Creates an instance with an index and no vector. */
+   public VectorWithIndexWritable(int idIndex) {
+     this(Integer.valueOf(idIndex), null);
+   }
+
+   public VectorWithIndexWritable(Integer idIndex, Vector vector) {
+     this.vector = vector;
+     this.idIndex = idIndex;
+   }
+
+   /** @return the vector, or null if none is set */
+   public Vector getVector() {
+     return vector;
+   }
+
+   /**
+    * @return the index; unboxes the stored Integer, so this throws a NullPointerException
+    *         when no index is set (check {@link #hasIndex()} first)
+    */
+   public int getIDIndex() {
+     return idIndex;
+   }
+
+   public boolean hasVector() {
+     return vector != null;
+   }
+
+   public boolean hasIndex() {
+     return idIndex != null;
+   }
+
+   @Override
+   public void write(DataOutput out) throws IOException {
+     boolean vectorPresent = hasVector();
+     out.writeBoolean(vectorPresent);
+     if (vectorPresent) {
+       VectorWritable vectorWritable = new VectorWritable(vector);
+       vectorWritable.write(out);
+     }
+     boolean indexPresent = hasIndex();
+     out.writeBoolean(indexPresent);
+     if (indexPresent) {
+       Varint.writeSignedVarInt(idIndex, out);
+     }
+   }
+
+   @Override
+   public void readFields(DataInput in) throws IOException {
+     /* clear state from a previous use of this instance before reading the new one */
+     vector = null;
+     idIndex = null;
+     boolean vectorPresent = in.readBoolean();
+     if (vectorPresent) {
+       VectorWritable vectorWritable = new VectorWritable();
+       vectorWritable.readFields(in);
+       vector = vectorWritable.get();
+     }
+     boolean indexPresent = in.readBoolean();
+     if (indexPresent) {
+       idIndex = Varint.readSignedVarInt(in);
+     }
+   }
+ }
Added: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritableTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritableTest.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritableTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/FeatureVectorWithRatingWritableTest.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+
+public class FeatureVectorWithRatingWritableTest extends MahoutTestCase {
+
+ @Test
+ public void rating() throws Exception {
+
+ FeatureVectorWithRatingWritable rating = new FeatureVectorWithRatingWritable(1, 3f);
+
+ assertTrue(rating.containsRating());
+ assertFalse(rating.containsFeatureVector());
+ assertEquals(1, rating.getIDIndex());
+ assertEquals(3f, rating.getRating(), 0f);
+ assertNull(rating.getFeatureVector());
+
+ FeatureVectorWithRatingWritable clonedRating = recreate(FeatureVectorWithRatingWritable.class, rating);
+
+ assertEquals(rating, clonedRating);
+ assertTrue(clonedRating.containsRating());
+ assertFalse(clonedRating.containsFeatureVector());
+ assertEquals(1, clonedRating.getIDIndex());
+ assertEquals(3f, clonedRating.getRating(), 0f);
+ assertNull(clonedRating.getFeatureVector());
+ }
+
+ @Test
+ public void featureVector() throws Exception {
+
+ Vector v = new DenseVector(new double[] { 1.5, 2.3, 0.9 });
+
+ FeatureVectorWithRatingWritable featureVector = new FeatureVectorWithRatingWritable(7, v);
+
+ assertFalse(featureVector.containsRating());
+ assertTrue(featureVector.containsFeatureVector());
+ assertEquals(7, featureVector.getIDIndex());
+ assertNull(featureVector.getRating());
+ assertEquals(v, featureVector.getFeatureVector());
+
+ FeatureVectorWithRatingWritable clonedFeatureVector =
+ recreate(FeatureVectorWithRatingWritable.class, featureVector);
+
+ assertEquals(featureVector, clonedFeatureVector);
+ assertFalse(clonedFeatureVector.containsRating());
+ assertTrue(clonedFeatureVector.containsFeatureVector());
+ assertEquals(7, clonedFeatureVector.getIDIndex());
+ assertNull(clonedFeatureVector.getRating());
+ assertEquals(v, clonedFeatureVector.getFeatureVector());
+ }
+
+ static <T extends Writable> T recreate(Class<T> tClass, T original)
+ throws IOException, IllegalAccessException, InstantiationException {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ original.write(new DataOutputStream(out));
+
+ T clone = tClass.newInstance();
+ clone.readFields(new DataInputStream(new ByteArrayInputStream(out.toByteArray())));
+ return clone;
+ }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,286 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.als.AlternateLeastSquaresSolver;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.easymock.IArgumentMatcher;
+import org.easymock.classextension.EasyMock;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Iterator;
+
+public class ParallelALSFactorizationJobTest extends TasteTestCase {
+
+ private static final Logger logger = LoggerFactory.getLogger(ParallelALSFactorizationJobTest.class);
+
+ @Test
+ public void prefsToRatingsMapper() throws Exception {
+ Mapper<LongWritable,Text,VarIntWritable,FeatureVectorWithRatingWritable>.Context ctx =
+ EasyMock.createMock(Mapper.Context.class);
+ ctx.write(new VarIntWritable(TasteHadoopUtils.idToIndex(456L)),
+ new FeatureVectorWithRatingWritable(TasteHadoopUtils.idToIndex(123L), 2.35f));
+ EasyMock.replay(ctx);
+
+ new ParallelALSFactorizationJob.PrefsToRatingsMapper().map(null, new Text("123,456,2.35"), ctx);
+ EasyMock.verify(ctx);
+ }
+
+ @Test
+ public void prefsToRatingsMapperTranspose() throws Exception {
+ Mapper<LongWritable,Text,VarIntWritable,FeatureVectorWithRatingWritable>.Context ctx =
+ EasyMock.createMock(Mapper.Context.class);
+ ctx.write(new VarIntWritable(TasteHadoopUtils.idToIndex(123L)),
+ new FeatureVectorWithRatingWritable(TasteHadoopUtils.idToIndex(456L), 2.35f));
+ EasyMock.replay(ctx);
+
+ ParallelALSFactorizationJob.PrefsToRatingsMapper mapper = new ParallelALSFactorizationJob.PrefsToRatingsMapper();
+ setField(mapper, "transpose", true);
+ mapper.map(null, new Text("123,456,2.35"), ctx);
+ EasyMock.verify(ctx);
+ }
+
+ @Test
+ public void initializeMReducer() throws Exception {
+ Reducer<VarLongWritable,FloatWritable,VarIntWritable,FeatureVectorWithRatingWritable>.Context ctx =
+ EasyMock.createMock(Reducer.Context.class);
+ ctx.write(EasyMock.eq(new VarIntWritable(TasteHadoopUtils.idToIndex(123L))), matchInitializedFeatureVector(3d, 3));
+ EasyMock.replay(ctx);
+
+ ParallelALSFactorizationJob.InitializeMReducer reducer = new ParallelALSFactorizationJob.InitializeMReducer();
+ setField(reducer, "numFeatures", 3);
+ reducer.reduce(new VarLongWritable(123L), Arrays.asList(new FloatWritable(4f), new FloatWritable(2f)), ctx);
+ EasyMock.verify(ctx);
+ }
+
+ static FeatureVectorWithRatingWritable matchInitializedFeatureVector(final double average, final int numFeatures) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof FeatureVectorWithRatingWritable) {
+ Vector v = ((FeatureVectorWithRatingWritable) argument).getFeatureVector();
+ if (v.get(0) != average) {
+ return false;
+ }
+ for (int n = 1; n < numFeatures; n++) {
+ if (v.get(n) < 0 || v.get(n) > 1) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ @Test
+ public void itemIDRatingMapper() throws Exception {
+ Mapper<LongWritable,Text,VarLongWritable,FloatWritable>.Context ctx = EasyMock.createMock(Mapper.Context.class);
+ ctx.write(new VarLongWritable(456L), new FloatWritable(2.35f));
+ EasyMock.replay(ctx);
+ new ParallelALSFactorizationJob.ItemIDRatingMapper().map(null, new Text("123,456,2.35"), ctx);
+ EasyMock.verify(ctx);
+ }
+
+ @Test
+ public void joinFeatureVectorAndRatingsReducer() throws Exception {
+ Vector vector = new DenseVector(new double[] { 4.5, 1.2 });
+ Reducer<VarIntWritable,FeatureVectorWithRatingWritable,IndexedVarIntWritable,FeatureVectorWithRatingWritable>.Context ctx =
+ EasyMock.createMock(Reducer.Context.class);
+ ctx.write(new IndexedVarIntWritable(456, 123), new FeatureVectorWithRatingWritable(123, 2.35f, vector));
+ EasyMock.replay(ctx);
+ new ParallelALSFactorizationJob.JoinFeatureVectorAndRatingsReducer().reduce(new VarIntWritable(123),
+ Arrays.asList(new FeatureVectorWithRatingWritable(456, vector),
+ new FeatureVectorWithRatingWritable(456, 2.35f)), ctx);
+ EasyMock.verify(ctx);
+ }
+
+
+ @Test
+ public void solvingReducer() throws Exception {
+
+ AlternateLeastSquaresSolver solver = new AlternateLeastSquaresSolver();
+
+ int numFeatures = 2;
+ double lambda = 0.01;
+ Vector ratings = new DenseVector(new double[] { 2, 1 });
+ Vector col1 = new DenseVector(new double[] { 1, 2 });
+ Vector col2 = new DenseVector(new double[] { 3, 4 });
+
+ Vector result = solver.solve(Arrays.asList(col1, col2), ratings, lambda, numFeatures);
+ Vector.Element[] elems = new Vector.Element[result.size()];
+ for (int n = 0; n < result.size(); n++) {
+ elems[n] = result.getElement(n);
+ }
+
+ Reducer<IndexedVarIntWritable,FeatureVectorWithRatingWritable,VarIntWritable,FeatureVectorWithRatingWritable>.Context ctx =
+ EasyMock.createMock(Reducer.Context.class);
+ ctx.write(EasyMock.eq(new VarIntWritable(123)), matchFeatureVector(elems));
+ EasyMock.replay(ctx);
+
+ ParallelALSFactorizationJob.SolvingReducer reducer = new ParallelALSFactorizationJob.SolvingReducer();
+ setField(reducer, "numFeatures", numFeatures);
+ setField(reducer, "lambda", lambda);
+ setField(reducer, "solver", solver);
+
+ reducer.reduce(new IndexedVarIntWritable(123, 1), Arrays.asList(
+ new FeatureVectorWithRatingWritable(456, new Float(ratings.get(0)), col1),
+ new FeatureVectorWithRatingWritable(789, new Float(ratings.get(1)), col2)), ctx);
+
+ EasyMock.verify(ctx);
+ }
+
+ static FeatureVectorWithRatingWritable matchFeatureVector(final Vector.Element... elements) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof FeatureVectorWithRatingWritable) {
+ Vector v = ((FeatureVectorWithRatingWritable) argument).getFeatureVector();
+ return MathHelper.consistsOf(v, elements);
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+
+ /**
+ * small integration test that runs the full job
+ *
+ * <pre>
+ *
+ * user-item-matrix
+ *
+ * burger hotdog berries icecream
+ * dog 5 5 2 -
+ * rabbit 2 - 3 5
+ * cow - 5 - 3
+ * donkey 3 - - 5
+ *
+ * </pre>
+ */
+ @Test
+ public void completeJobToyExample() throws Exception {
+
+ File inputFile = getTestTempFile("prefs.txt");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ Double na = Double.NaN;
+ Matrix preferences = new SparseRowMatrix(new int[] { 4, 4 }, new Vector[] {
+ new DenseVector(new double[] { 5d, 5d, 2d, na }),
+ new DenseVector(new double[] { 2d, na, 3d, 5d }),
+ new DenseVector(new double[] { na, 5d, na, 3d }),
+ new DenseVector(new double[] { 3d, na, na, 5d }) });
+
+ StringBuilder prefsAsText = new StringBuilder();
+ String separator = "";
+ Iterator<MatrixSlice> sliceIterator = preferences.iterateAll();
+ while (sliceIterator.hasNext()) {
+ MatrixSlice slice = sliceIterator.next();
+ Iterator<Vector.Element> elementIterator = slice.vector().iterateNonZero();
+ while (elementIterator.hasNext()) {
+ Vector.Element e = elementIterator.next();
+ if (!Double.isNaN(e.get())) {
+ prefsAsText.append(separator).append(slice.index()).append(",").append(e.index()).append(",").append(e.get());
+ separator = "\n";
+ }
+ }
+ }
+ logger.info("Input matrix:\n" + prefsAsText);
+ writeLines(inputFile, prefsAsText.toString());
+
+ ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
+
+ Configuration conf = new Configuration();
+ conf.set("mapred.input.dir", inputFile.getAbsolutePath());
+ conf.set("mapred.output.dir", outputDir.getAbsolutePath());
+ conf.setBoolean("mapred.output.compress", false);
+
+ int numFeatures = 3;
+ int numIterations = 5;
+ double lambda = 0.065;
+
+ alsFactorization.setConf(conf);
+ alsFactorization.run(new String[] { "--tempDir", tmpDir.getAbsolutePath(), "--lambda", String.valueOf(lambda),
+ "--numFeatures", String.valueOf(numFeatures), "--numIterations", String.valueOf(numIterations) });
+
+ Path inputPath = new Path(inputFile.getAbsolutePath());
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+ Matrix u = MathHelper.readEntries(fs, conf, new Path(outputDir.getAbsolutePath(), "U/part-r-00000"),
+ preferences.numRows(), numFeatures);
+ Matrix m = MathHelper.readEntries(fs, conf, new Path(outputDir.getAbsolutePath(), "M/part-r-00000"),
+ preferences.numCols(), numFeatures);
+
+ FullRunningAverage avg = new FullRunningAverage();
+ sliceIterator = preferences.iterateAll();
+ while (sliceIterator.hasNext()) {
+ MatrixSlice slice = sliceIterator.next();
+ Iterator<Vector.Element> elementIterator = slice.vector().iterateNonZero();
+ while (elementIterator.hasNext()) {
+ Vector.Element e = elementIterator.next();
+ if (!Double.isNaN(e.get())) {
+ double pref = e.get();
+ double estimate = u.getRow(slice.index()).dot(m.getRow(e.index()));
+ double err = pref - estimate;
+ avg.addDatum(err * err);
+ logger.info("Comparing preference of user [" + slice.index() + "] towards item [" + e.index() + "], " +
+ "was [" + pref + "] estimate is [" + estimate + "]");
+ }
+ }
+ }
+ double rmse = Math.sqrt(avg.getAverage());
+ logger.info("RMSE: " + rmse);
+
+ assertTrue(rmse < 0.2d);
+ }
+
+}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJobTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJobTest.java?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJobTest.java (added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/hadoop/als/PredictionJobTest.java Wed Mar 23 22:33:57 2011
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.junit.Test;
+
+import java.io.File;
+
+ /**
+  * Small integration test for {@link PredictionJob}: writes user and item feature matrices plus a
+  * list of (user,item) pairs, runs the job and checks the predicted preferences it writes.
+  */
+ public class PredictionJobTest extends TasteTestCase {
+
+   @Test
+   public void smallIntegration() throws Exception {
+
+     File pairs = getTestTempFile("pairs.txt");
+     File userFeatures = getTestTempFile("userFeatures.seq");
+     File itemFeatures = getTestTempFile("itemFeatures.seq");
+     File tempDir = getTestTempDir("temp");
+     File outputDir = getTestTempDir("out");
+     /* presumably removed so that the job can create its output directory itself - TODO confirm */
+     outputDir.delete();
+
+     Configuration conf = new Configuration();
+     Path inputPath = new Path(pairs.getAbsolutePath());
+     FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+     /* one row of three features per user (rows 0..2) */
+     MathHelper.writeEntries(new double[][]{
+         new double[] { 1.5, -2, 0.3 },
+         new double[] { -0.7, 2, 0.6 },
+         new double[] { -1, 2.5, 3 } }, fs, conf, new Path(userFeatures.getAbsolutePath()));
+
+     /* one row of three features per item (rows 0..2) */
+     MathHelper.writeEntries(new double [][] {
+         new double[] { 2.3, 0.5, 0 },
+         new double[] { 4.7, -1, 0.2 },
+         new double[] { 0.6, 2, 1.3 } }, fs, conf, new Path(itemFeatures.getAbsolutePath()));
+
+     /* the (user,item) pairs to predict */
+     writeLines(pairs, "0,0", "2,1", "1,0");
+
+     PredictionJob predictor = new PredictionJob();
+     predictor.setConf(conf);
+     predictor.run(new String[] { "--output", outputDir.getAbsolutePath(), "--pairs", pairs.getAbsolutePath(),
+         "--userFeatures", userFeatures.getAbsolutePath(), "--itemFeatures", itemFeatures.getAbsolutePath(),
+         "--tempDir", tempDir.getAbsolutePath() });
+
+     /* expected values are the dot products of the matching user and item rows above;
+        the expected value goes first so that JUnit failure messages read correctly */
+     FileDataModel dataModel = new FileDataModel(new File(outputDir, "part-r-00000"));
+     assertEquals(3, dataModel.getNumUsers());
+     assertEquals(2, dataModel.getNumItems());
+     assertEquals(2.45, dataModel.getPreferenceValue(0, 0), EPSILON);
+     assertEquals(-6.6, dataModel.getPreferenceValue(2, 1), EPSILON);
+     assertEquals(-0.61, dataModel.getPreferenceValue(1, 0), EPSILON);
+   }
+
+ }
Added: mahout/trunk/examples/bin/factorize-movielens-1M.sh
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/bin/factorize-movielens-1M.sh?rev=1084789&view=auto
==============================================================================
--- mahout/trunk/examples/bin/factorize-movielens-1M.sh (added)
+++ mahout/trunk/examples/bin/factorize-movielens-1M.sh Wed Mar 23 22:33:57 2011
@@ -0,0 +1,59 @@
+ #!/bin/bash
+ #
+ # Licensed to the Apache Software Foundation (ASF) under one or more
+ # contributor license agreements. See the NOTICE file distributed with
+ # this work for additional information regarding copyright ownership.
+ # The ASF licenses this file to You under the Apache License, Version 2.0
+ # (the "License"); you may not use this file except in compliance with
+ # the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ # Instructions:
+ #
+ # Before using this script, you have to download and extract the Movielens 1M dataset
+ # from http://www.grouplens.org/node/73
+ #
+ # To run: change into the mahout directory and type:
+ # examples/bin/factorize-movielens-1M.sh /path/to/ratings.dat
+
+ # exactly one argument is expected: the path to ratings.dat
+ if [ $# -ne 1 ]
+ then
+   echo -e "\nYou have to download the Movielens 1M dataset from http://www.grouplens.org/node/73 before"
+   echo -e "you can run this example. After that extract it and supply the path to the ratings.dat file.\n"
+   echo -e "Syntax: $0 /path/to/ratings.dat\n"
+   exit 1
+ fi
+
+ # stop early instead of running the whole pipeline on an empty conversion
+ if [ ! -r "$1" ]
+ then
+   echo "Cannot read ratings file: $1" >&2
+   exit 1
+ fi
+
+ echo "creating work directory"
+ mkdir -p work/movielens
+
+ # turn the "::" separated input into CSV and keep the first three columns
+ echo "Converting ratings..."
+ sed -e 's/::/,/g' "$1" | cut -d, -f1,2,3 > work/movielens/ratings.csv
+
+ # create a 90% training set and a 10% probe set
+ bin/mahout splitDataset --input work/movielens/ratings.csv --output work/dataset \
+   --trainingPercentage 0.9 --probePercentage 0.1 --tempDir work/dataset/tmp
+
+ # run distributed ALS-WR to factorize the rating matrix based on the training set
+ bin/mahout parallelALS --input work/dataset/trainingSet/ --output work/als/out \
+   --tempDir work/als/tmp --numFeatures 20 --numIterations 10 --lambda 0.065
+
+ # compute predictions against the probe set, measure the error
+ bin/mahout evaluateFactorizationParallel --output work/als/rmse --pairs work/dataset/probeSet/ \
+   --userFeatures work/als/out/U/ --itemFeatures work/als/out/M/
+
+ # print the error
+ echo -e "\nRMSE is:\n"
+ cat work/als/rmse/rmse.txt
+ echo -e "\n\n"
+
+ echo "removing work directory"
+ rm -rf work
\ No newline at end of file
Modified: mahout/trunk/src/conf/driver.classes.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/driver.classes.props?rev=1084789&r1=1084788&r2=1084789&view=diff
==============================================================================
--- mahout/trunk/src/conf/driver.classes.props (original)
+++ mahout/trunk/src/conf/driver.classes.props Wed Mar 23 22:33:57 2011
@@ -1,6 +1,9 @@
org.apache.mahout.utils.vectors.VectorDumper = vectordump : Dump vectors from a sequence file to text
org.apache.mahout.utils.clustering.ClusterDumper = clusterdump : Dump cluster output to text
org.apache.mahout.utils.SequenceFileDumper = seqdumper : Generic Sequence File dumper
+org.apache.mahout.utils.eval.DatasetSplitter = splitDataset : split a rating dataset into training and probe parts
+org.apache.mahout.utils.eval.InMemoryFactorizationEvaluator = evaluateFactorization : compute RMSE of a rating matrix factorization against probes in memory
+org.apache.mahout.utils.eval.ParallelFactorizationEvaluator = evaluateFactorizationParallel : compute RMSE of a rating matrix factorization against probes
org.apache.mahout.clustering.kmeans.KMeansDriver = kmeans : K-means clustering
org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansDriver = fkmeans : Fuzzy K-means clustering
org.apache.mahout.clustering.lda.LDADriver = lda : Latent Dirchlet Allocation
@@ -33,3 +36,5 @@ org.apache.mahout.classifier.bayes.Wikip
org.apache.mahout.math.hadoop.stochasticsvd.SSVDCli = ssvd : Stochastic SVD
org.apache.mahout.clustering.spectral.eigencuts.EigencutsDriver = eigencuts : Eigencuts spectral clustering
org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver = spectralkmeans : Spectral k-means clustering
+org.apache.mahout.cf.taste.hadoop.als.ParallelALSFactorizationJob = parallelALS : ALS-WR factorization of a rating matrix
+org.apache.mahout.cf.taste.hadoop.als.PredictionJob = predictFromFactorization : predict preferences from a factorization of a rating matrix
\ No newline at end of file
Added: mahout/trunk/src/conf/evaluateFactorization.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/evaluateFactorization.props?rev=1084789&view=auto
==============================================================================
(empty)
Added: mahout/trunk/src/conf/evaluateFactorizationParallel.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/evaluateFactorizationParallel.props?rev=1084789&view=auto
==============================================================================
(empty)
Added: mahout/trunk/src/conf/parallelALS.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/parallelALS.props?rev=1084789&view=auto
==============================================================================
(empty)
Added: mahout/trunk/src/conf/predictFromFactorization.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/predictFromFactorization.props?rev=1084789&view=auto
==============================================================================
(empty)
Added: mahout/trunk/src/conf/splitDataset.props
URL: http://svn.apache.org/viewvc/mahout/trunk/src/conf/splitDataset.props?rev=1084789&view=auto
==============================================================================
(empty)