Posted to commits@mahout.apache.org by ss...@apache.org on 2012/04/27 17:31:58 UTC

svn commit: r1331469 - /mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ImplicitLinearRegressionFactorizer.java

Author: ssc
Date: Fri Apr 27 15:31:58 2012
New Revision: 1331469

URL: http://svn.apache.org/viewvc?rev=1331469&view=rev
Log:
MAHOUT-737 Implicit Alternating Least Squares SVD

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ImplicitLinearRegressionFactorizer.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ImplicitLinearRegressionFactorizer.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ImplicitLinearRegressionFactorizer.java?rev=1331469&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ImplicitLinearRegressionFactorizer.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ImplicitLinearRegressionFactorizer.java Fri Apr 27 15:31:58 2012
@@ -0,0 +1,395 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DiagonalMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.QRDecomposition;
+import org.apache.mahout.math.SparseMatrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
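+/**
+ * Factorizes a preference matrix with an implicit-feedback variant of alternating least
+ * squares (MAHOUT-737), along the lines of Hu, Koren and Volinsky,
+ * "Collaborative Filtering for Implicit Feedback Datasets": the user and item feature
+ * matrices are recomputed in turn, and each recomputation solves one regularized
+ * least-squares problem per user or item.
+ *
+ * <p>A minimal usage sketch (variable names are illustrative only):</p>
+ * <pre>
+ *   DataModel model = ...; // any Taste DataModel
+ *   Factorizer factorizer = new ImplicitLinearRegressionFactorizer(model, 64, 10, 0.1);
+ *   Recommender recommender = new SVDRecommender(model, factorizer);
+ * </pre>
+ */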
+public final class ImplicitLinearRegressionFactorizer extends AbstractFactorizer {
+
+    private static final Logger log = LoggerFactory.getLogger(ImplicitLinearRegressionFactorizer.class);
+    private final double preventOverfitting;
+    /** number of features used to compute this factorization */
+    private final int numFeatures;
+    /** number of iterations */
+    private final int numIterations;
+    private final DataModel dataModel;
+    /** The user feature matrix. */
+    private double[][] userMatrix;
+    /** The item feature matrix. */
+    private double[][] itemMatrix;
+    private Matrix userTransUser;
+    private Matrix itemTransItem;
+    private Collection<Callable<Void>> fVectorCallables;
+    private boolean recomputeUserFeatures;
+    private RunningAverage avrChange;
+
+    public ImplicitLinearRegressionFactorizer(DataModel dataModel) throws TasteException {
+        this(dataModel, 64, 10, 0.1);
+    }
+
+    public ImplicitLinearRegressionFactorizer(DataModel dataModel, int numFeatures, int numIterations,
+            double preventOverfitting) throws TasteException {
+
+        super(dataModel);
+        this.dataModel = dataModel;
+        this.numFeatures = numFeatures;
+        this.numIterations = numIterations;
+        this.preventOverfitting = preventOverfitting;
+        fVectorCallables = Lists.newArrayList();
+        avrChange = new FullRunningAverage();
+    }
+
+    @Override
+    public Factorization factorize() throws TasteException {
+        Random random = RandomUtils.getRandom();
+        userMatrix = new double[dataModel.getNumUsers()][numFeatures];
+        itemMatrix = new double[dataModel.getNumItems()][numFeatures];
+
+        /* start with the user side */
+        recomputeUserFeatures = true;
+
+        double average = getAveragePreference();
+
+        double prefInterval = dataModel.getMaxPreference() - dataModel.getMinPreference();
+        double defaultValue = Math.sqrt((average - prefInterval * 0.1) / numFeatures);
+        double interval = prefInterval * 0.1 / numFeatures;
+
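+        /* initialize the feature matrices with small random deviations around defaultValue,
+         * chosen so that the initial dot products roughly reproduce the average preference */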
+        for (int feature = 0; feature < numFeatures; feature++) {
+            for (int userIndex = 0; userIndex < dataModel.getNumUsers(); userIndex++) {
+                userMatrix[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * random.nextDouble();
+            }
+            for (int itemIndex = 0; itemIndex < dataModel.getNumItems(); itemIndex++) {
+                itemMatrix[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * random.nextDouble();
+            }
+        }
+        train();
+        return createFactorization(userMatrix, itemMatrix);
+    }
+
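+    /**
+     * Runs {@code numIterations} alternating passes: each pass fixes one side of the
+     * factorization, solves a regularized least-squares problem for every feature vector
+     * on the other side, and then switches sides.
+     */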
+    public void train() throws TasteException {
+        for (int i = 0; i < numIterations; i++) {
+            if (recomputeUserFeatures) {
+                LongPrimitiveIterator userIds = dataModel.getUserIDs();
+                /* first recompute Y^TY, which every per-user solve needs */
+                log.info("Calculating Y^TY");
+                reCalculateTrans(recomputeUserFeatures);
+                log.info("Building callables for users.");
+                while (userIds.hasNext()) {
+                    long userId = userIds.nextLong();
+                    int useridx = userIndex(userId);
+                    buildCallables(buildConfidenceMatrixForUser(userId), buildPreferenceVectorForUser(userId), useridx);
+                }
+                finishProcessing();
+            } else {
+                LongPrimitiveIterator itemIds = dataModel.getItemIDs();
+                /* first recompute X^TX, which every per-item solve needs */
+                log.info("Calculating X^TX");
+                reCalculateTrans(recomputeUserFeatures);
+                log.info("Building callables for items.");
+                while (itemIds.hasNext()) {
+                    long itemId = itemIds.nextLong();
+                    int itemidx = itemIndex(itemId);
+                    buildCallables(buildConfidenceMatrixForItem(itemId), buildPreferenceVectorForItem(itemId), itemidx);
+                }
+                finishProcessing();
+            }
+        }
+    }
+
+    public Matrix buildPreferenceVectorForUser(long realId) throws TasteException {
+        Matrix ids = new SparseMatrix(1, dataModel.getNumItems());
+        for (Preference pref : dataModel.getPreferencesFromUser(realId)) {
+            int itemidx = itemIndex(pref.getItemID());
+            ids.setQuick(0, itemidx, pref.getValue());
+        }
+        return ids;
+    }
+
+    private Matrix buildConfidenceMatrixForItem(long itemId) throws TasteException {
+        PreferenceArray prefs = dataModel.getPreferencesForItem(itemId);
+        Matrix confidenceMatrix = new SparseMatrix(dataModel.getNumUsers(), dataModel.getNumUsers());
+        for (Preference pref : prefs) {
+            long userId = pref.getUserID();
+            int userIdx = userIndex(userId);
+            confidenceMatrix.setQuick(userIdx, userIdx, 1);
+        }
+        return new DiagonalMatrix(confidenceMatrix);
+    }
+
+    private Matrix buildConfidenceMatrixForUser(long userId) throws TasteException {
+        PreferenceArray prefs = dataModel.getPreferencesFromUser(userId);
+        Matrix confidenceMatrix = new SparseMatrix(dataModel.getNumItems(), dataModel.getNumItems());
+        for (Preference pref : prefs) {
+            long itemId = pref.getItemID();
+            int itemIdx = itemIndex(itemId);
+            confidenceMatrix.setQuick(itemIdx, itemIdx, 1);
+        }
+        return new DiagonalMatrix(confidenceMatrix);
+    }
+
+    private Matrix buildPreferenceVectorForItem(long realId) throws TasteException {
+        Matrix ids = new SparseMatrix(1, dataModel.getNumUsers());
+        for (Preference pref : dataModel.getPreferencesForItem(realId)) {
+            int useridx = userIndex(pref.getUserID());
+            ids.setQuick(0, useridx, pref.getValue());
+        }
+        return ids;
+    }
+
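+    /** Creates an identity matrix (a diagonal matrix of ones) of the given size. */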
+    private Matrix ones(int size) {
+        double[] vector = new double[size];
+        for (int i = 0; i < size; i++) {
+            vector[i] = 1;
+        }
+        Matrix ones = new DiagonalMatrix(vector);
+        return ones;
+    }
+
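+    /** Average over the full preference matrix, counting unobserved user-item pairs as zero. */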
+    private double getAveragePreference() throws TasteException {
+        RunningAverage average = new FullRunningAverage();
+        LongPrimitiveIterator it = dataModel.getUserIDs();
+        while (it.hasNext()) {
+            int count = 0;
+            PreferenceArray prefs;
+            try {
+                prefs = dataModel.getPreferencesFromUser(it.nextLong());
+                for (Preference pref : prefs) {
+                    average.addDatum(pref.getValue());
+                    count++;
+                }
+            } catch (NoSuchUserException ex) {
+                continue;
+            }
+            /* add the remaining zeros */
+            for (int i = 0; i < (dataModel.getNumItems() - count); i++) {
+                average.addDatum(0);
+            }
+        }
+        return average.getAverage();
+    }
+
+    /**
+     * Recalculates Y^TY or X^TX, which the per-vector solves need.
+     * @param recomputeUserFeatures if true, user features will be recomputed next, so Y^TY
+     *        is rebuilt from the current item matrix; otherwise X^TX is rebuilt from the
+     *        current user matrix
+     */
+    public void reCalculateTrans(boolean recomputeUserFeatures) {
+        if (!recomputeUserFeatures) {
+            Matrix uMatrix = new DenseMatrix(userMatrix);
+            userTransUser = uMatrix.transpose().times(uMatrix);
+        } else {
+            Matrix iMatrix = new DenseMatrix(itemMatrix);
+            itemTransItem = iMatrix.transpose().times(iMatrix);
+        }
+    }
+
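+    /**
+     * Stores the newly solved feature vector in the user or item matrix and records the
+     * cosine similarity between the old and the new vector as a convergence measure.
+     */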
+    private synchronized void updateMatrix(int id, Matrix m) {
+        double normA = 0;
+        double normB = 0;
+        double aTb = 0;
+        for (int feature = 0; feature < numFeatures; feature++) {
+            if (recomputeUserFeatures) {
+                normA += userMatrix[id][feature] * userMatrix[id][feature];
+                normB += m.get(feature, 0) * m.get(feature, 0);
+                aTb += userMatrix[id][feature] * m.get(feature, 0);
+                userMatrix[id][feature] = m.get(feature, 0);
+            } else {
+                normA += itemMatrix[id][feature] * itemMatrix[id][feature];
+                normB += m.get(feature, 0) * m.get(feature, 0);
+                aTb += itemMatrix[id][feature] * m.get(feature, 0);
+                itemMatrix[id][feature] = m.get(feature, 0);
+            }
+        }
+        /* cosine similarity between the old and the new feature vector */
+        double cosine = aTb / (Math.sqrt(normA) * Math.sqrt(normB));
+        if (Double.isNaN(cosine)) {
+            log.info("Cosine similarity is NaN, recomputeUserFeatures={} id={}", recomputeUserFeatures, id);
+        } else {
+            avrChange.addDatum(cosine);
+        }
+    }
+
+    public void resetCallables() {
+        fVectorCallables = Lists.newArrayList();
+    }
+
+    private void resetAvrChange() {
+        log.info("Avr Change: {}", avrChange.getAverage());
+        avrChange = new FullRunningAverage();
+    }
+
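+    /**
+     * Queues a least-squares solve for one feature vector; queued tasks are executed
+     * in batches of 200 per available processor.
+     */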
+    public void buildCallables(Matrix C, Matrix prefVector, int id) throws TasteException {
+        fVectorCallables.add(new FeatureVectorCallable(C, prefVector, id));
+        if (fVectorCallables.size() % (200 * Runtime.getRuntime().availableProcessors()) == 0) {
+            execute(fVectorCallables);
+            resetCallables();
+        }
+    }
+
+    public void finishProcessing() throws TasteException {
+        /* run any remaining queued tasks */
+        if (!fVectorCallables.isEmpty()) {
+            execute(fVectorCallables);
+        }
+        resetCallables();
+        if ((recomputeUserFeatures && avrChange.getCount() != userMatrix.length)
+                || (!recomputeUserFeatures && avrChange.getCount() != itemMatrix.length)) {
+            log.info("Matrix length is not equal to count");
+        }
+        resetAvrChange();
+        recomputeUserFeatures = !recomputeUserFeatures;
+    }
+
+    public Matrix identityV(int size) {
+        return ones(size);
+    }
+
+    void execute(Collection<Callable<Void>> callables) throws TasteException {
+        callables = wrapWithStatsCallables(callables);
+        int numProcessors = Runtime.getRuntime().availableProcessors();
+        ExecutorService executor = Executors.newFixedThreadPool(numProcessors);
+        log.info("Executing {} tasks in {} threads", callables.size(), numProcessors);
+        try {
+            List<Future<Void>> futures = executor.invokeAll(callables);
+            /* calling get() surfaces any exception thrown by a task */
+            for (Future<Void> future : futures) {
+                future.get();
+            }
+        } catch (InterruptedException ie) {
+            Thread.currentThread().interrupt();
+            throw new TasteException("Interrupted during factorization", ie);
+        } catch (ExecutionException ee) {
+            throw new TasteException("Error during factorization", ee);
+        } finally {
+            executor.shutdown();
+        }
+    }
+
+    private Collection<Callable<Void>> wrapWithStatsCallables(Collection<Callable<Void>> callables) {
+        int size = callables.size();
+        Collection<Callable<Void>> wrapped = Lists.newArrayListWithExpectedSize(size);
+        int count = 1;
+        RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
+        for (Callable<Void> callable : callables) {
+            boolean logStats = count++ % 1000 == 0;
+            wrapped.add(new StatsCallable(callable, logStats, timing));
+        }
+        return wrapped;
+    }
+
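+    /**
+     * Solves the regularized least-squares problem for a single feature vector. For a user u
+     * (and symmetrically for an item) this computes
+     * x_u = (Y^T C_u Y + lambda * I)^{-1} Y^T C_u p(u),
+     * where Y is the item feature matrix, C_u the diagonal confidence matrix of u, p(u) the
+     * preference vector of u and lambda the regularization weight (preventOverfitting).
+     * Y^T C_u Y is assembled as the precomputed Y^T Y plus the correction Y^T (C_u - I) Y.
+     */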
+    private class FeatureVectorCallable implements Callable<Void> {
+
+        private final Matrix C;
+        private final Matrix prefVector;
+        private final int id;
+
+        private FeatureVectorCallable(Matrix C, Matrix prefVector, int id) {
+            this.C = C;
+            this.prefVector = prefVector;
+            this.id = id;
+        }
+
+        @Override
+        public Void call() throws Exception {
+            Matrix XTCX;
+            if (recomputeUserFeatures) {
+                Matrix I = identityV(dataModel.getNumItems());
+                Matrix I2 = identityV(numFeatures);
+                Matrix iTi = itemTransItem.clone();
+                Matrix itemM = new DenseMatrix(itemMatrix);
+                XTCX = iTi.plus(itemM.transpose().times(C.minus(I)).times(itemM));
+
+                Matrix inverse = solve(XTCX.plus(I2.times(preventOverfitting)), I2);
+                Matrix results = inverse.times(itemM.transpose().times(C)).times(prefVector.transpose());
+                updateMatrix(id, results);
+            } else {
+                Matrix I = identityV(dataModel.getNumUsers());
+                Matrix I2 = identityV(numFeatures);
+                Matrix uTu = userTransUser.clone();
+                Matrix userM = new DenseMatrix(userMatrix);
+                XTCX = uTu.plus(userM.transpose().times(C.minus(I)).times(userM));
+
+                Matrix inverse = solve(XTCX.plus(I2.times(preventOverfitting)), I2);
+                Matrix results = inverse.times(userM.transpose().times(C)).times(prefVector.transpose());
+                updateMatrix(id, results);
+            }
+            return null;
+        }
+    }
+
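+    /** Solves A X = Y via QR decomposition; called with Y = I to invert the regularized normal matrix. */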
+    private Matrix solve(Matrix A, Matrix y) {
+        return new QRDecomposition(A).solve(y);
+    }
+
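+    /** Wraps a task, measures its wall-clock time, and periodically logs average task time and memory use. */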
+    private static class StatsCallable implements Callable<Void> {
+
+        private final Callable<Void> delegate;
+        private final boolean logStats;
+        private final RunningAverageAndStdDev timing;
+
+        private StatsCallable(Callable<Void> delegate, boolean logStats, RunningAverageAndStdDev timing) {
+            this.delegate = delegate;
+            this.logStats = logStats;
+            this.timing = timing;
+        }
+
+        @Override
+        public Void call() throws Exception {
+            long start = System.currentTimeMillis();
+            delegate.call();
+            long end = System.currentTimeMillis();
+            timing.addDatum(end - start);
+            if (logStats) {
+                Runtime runtime = Runtime.getRuntime();
+                int average = (int) timing.getAverage();
+                log.info("Average time per task: {}ms", average);
+                long totalMemory = runtime.totalMemory();
+                long memory = totalMemory - runtime.freeMemory();
+                log.info("Approximate memory used: {}MB / {}MB", memory / 1000000L, totalMemory / 1000000L);
+            }
+            return null;
+        }
+    }
+}