You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/27 13:14:45 UTC
[19/24] mahout git commit: MAHOUT-2034 Split MR and New Examples into
separate modules
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
new file mode 100644
index 0000000..a99d54c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
@@ -0,0 +1,265 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.Random;
+
+/**
+ * {@link Factorizer} based on Simon Funk's famous article <a href="http://sifter.org/~simon/journal/20061211.html">
+ * "Netflix Update: Try this at home"</a>.
+ *
+ * Attempts to be as memory efficient as possible, only iterating once through the
+ * {@link FactorizablePreferences} or {@link DataModel} while copying everything to primitive arrays.
+ * Learning works in place on these datastructures after that.
+ */
+public class ParallelArraysSGDFactorizer implements Factorizer {
+
+ public static final double DEFAULT_LEARNING_RATE = 0.005;
+ public static final double DEFAULT_PREVENT_OVERFITTING = 0.02;
+ public static final double DEFAULT_RANDOM_NOISE = 0.005;
+
+ private final int numFeatures;
+ private final int numIterations;
+ private final float minPreference;
+ private final float maxPreference;
+
+ private final Random random;
+ private final double learningRate;
+ private final double preventOverfitting;
+
+ private final FastByIDMap<Integer> userIDMapping;
+ private final FastByIDMap<Integer> itemIDMapping;
+
+ private final double[][] userFeatures;
+ private final double[][] itemFeatures;
+
+ private final int[] userIndexes;
+ private final int[] itemIndexes;
+ private final float[] values;
+
+ private final double defaultValue;
+ private final double interval;
+ private final double[] cachedEstimates;
+
+
+ private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);
+
+ public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) {
+ this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE,
+ DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE);
+ }
+
+ public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate,
+ double preventOverfitting, double randomNoise) {
+ this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting,
+ randomNoise);
+ }
+
+ public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) {
+ this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING,
+ DEFAULT_RANDOM_NOISE);
+ }
+
+ public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures,
+ int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
+
+ this.numFeatures = numFeatures;
+ this.numIterations = numIterations;
+ minPreference = factorizablePreferences.getMinPreference();
+ maxPreference = factorizablePreferences.getMaxPreference();
+
+ this.random = RandomUtils.getRandom();
+ this.learningRate = learningRate;
+ this.preventOverfitting = preventOverfitting;
+
+ int numUsers = factorizablePreferences.numUsers();
+ int numItems = factorizablePreferences.numItems();
+ int numPrefs = factorizablePreferences.numPreferences();
+
+ log.info("Mapping {} users...", numUsers);
+ userIDMapping = new FastByIDMap<>(numUsers);
+ int index = 0;
+ LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs();
+ while (userIterator.hasNext()) {
+ userIDMapping.put(userIterator.nextLong(), index++);
+ }
+
+ log.info("Mapping {} items", numItems);
+ itemIDMapping = new FastByIDMap<>(numItems);
+ index = 0;
+ LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs();
+ while (itemIterator.hasNext()) {
+ itemIDMapping.put(itemIterator.nextLong(), index++);
+ }
+
+ this.userIndexes = new int[numPrefs];
+ this.itemIndexes = new int[numPrefs];
+ this.values = new float[numPrefs];
+ this.cachedEstimates = new double[numPrefs];
+
+ index = 0;
+ log.info("Loading {} preferences into memory", numPrefs);
+ RunningAverage average = new FullRunningAverage();
+ for (Preference preference : factorizablePreferences.getPreferences()) {
+ userIndexes[index] = userIDMapping.get(preference.getUserID());
+ itemIndexes[index] = itemIDMapping.get(preference.getItemID());
+ values[index] = preference.getValue();
+ cachedEstimates[index] = 0;
+
+ average.addDatum(preference.getValue());
+
+ index++;
+ if (index % 1000000 == 0) {
+ log.info("Processed {} preferences", index);
+ }
+ }
+ log.info("Processed {} preferences, done.", index);
+
+ double averagePreference = average.getAverage();
+ log.info("Average preference value is {}", averagePreference);
+
+ double prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference();
+ defaultValue = Math.sqrt((averagePreference - prefInterval * 0.1) / numFeatures);
+ interval = prefInterval * 0.1 / numFeatures;
+
+ userFeatures = new double[numUsers][numFeatures];
+ itemFeatures = new double[numItems][numFeatures];
+
+ log.info("Initializing feature vectors...");
+ for (int feature = 0; feature < numFeatures; feature++) {
+ for (int userIndex = 0; userIndex < numUsers; userIndex++) {
+ userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
+ }
+ for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
+ itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
+ }
+ }
+ }
+
+ @Override
+ public Factorization factorize() throws TasteException {
+ for (int feature = 0; feature < numFeatures; feature++) {
+ log.info("Shuffling preferences...");
+ shufflePreferences();
+ log.info("Starting training of feature {} ...", feature);
+ for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
+ if (currentIteration == numIterations - 1) {
+ double rmse = trainingIterationWithRmse(feature);
+ log.info("Finished training feature {} with RMSE {}", feature, rmse);
+ } else {
+ trainingIteration(feature);
+ }
+ }
+ if (feature < numFeatures - 1) {
+ log.info("Updating cache...");
+ for (int index = 0; index < userIndexes.length; index++) {
+ cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index],
+ false);
+ }
+ }
+ }
+ log.info("Factorization done");
+ return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+ }
+
+ private void trainingIteration(int feature) {
+ for (int index = 0; index < userIndexes.length; index++) {
+ train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
+ }
+ }
+
+ private double trainingIterationWithRmse(int feature) {
+ double rmse = 0.0;
+ for (int index = 0; index < userIndexes.length; index++) {
+ double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
+ rmse += error * error;
+ }
+ return Math.sqrt(rmse / userIndexes.length);
+ }
+
+ private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) {
+ double sum = cachedEstimate;
+ sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature];
+ if (trailing) {
+ sum += (numFeatures - feature - 1) * (defaultValue + interval) * (defaultValue + interval);
+ if (sum > maxPreference) {
+ sum = maxPreference;
+ } else if (sum < minPreference) {
+ sum = minPreference;
+ }
+ }
+ return sum;
+ }
+
+ public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) {
+ double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true);
+ double[] userVector = userFeatures[userIndex];
+ double[] itemVector = itemFeatures[itemIndex];
+
+ userVector[feature] += learningRate * (error * itemVector[feature] - preventOverfitting * userVector[feature]);
+ itemVector[feature] += learningRate * (error * userVector[feature] - preventOverfitting * itemVector[feature]);
+
+ return error;
+ }
+
+ protected void shufflePreferences() {
+ /* Durstenfeld shuffle */
+ for (int currentPos = userIndexes.length - 1; currentPos > 0; currentPos--) {
+ int swapPos = random.nextInt(currentPos + 1);
+ swapPreferences(currentPos, swapPos);
+ }
+ }
+
+ private void swapPreferences(int posA, int posB) {
+ int tmpUserIndex = userIndexes[posA];
+ int tmpItemIndex = itemIndexes[posA];
+ float tmpValue = values[posA];
+ double tmpEstimate = cachedEstimates[posA];
+
+ userIndexes[posA] = userIndexes[posB];
+ itemIndexes[posA] = itemIndexes[posB];
+ values[posA] = values[posB];
+ cachedEstimates[posA] = cachedEstimates[posB];
+
+ userIndexes[posB] = tmpUserIndex;
+ itemIndexes[posB] = tmpItemIndex;
+ values[posB] = tmpValue;
+ cachedEstimates[posB] = tmpEstimate;
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // do nothing
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
new file mode 100644
index 0000000..5cce02d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
@@ -0,0 +1,141 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+
+import com.google.common.io.Closeables;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
+import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Runs an SVD factorization of the KDD Cup track 1 data, logs the RMSE on the validation set and writes
+ * the converted estimates for the test set to a result file.
+ *
+ * <p>Needs at least 6-7GB of memory; tested with {@code -Xms6700M -Xmx6700M}.</p>
+ */
+public final class Track1SVDRunner {
+
+  private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class);
+
+  private Track1SVDRunner() {
+  }
+
+  /**
+   * Entry point.
+   *
+   * @param args {@code <kddDataFileDirectory> <resultFile>}
+   * @throws IllegalArgumentException if the data file directory does not exist or is not a directory
+   * @throws Exception if reading the data, factorizing or writing the result file fails
+   */
+  public static void main(String[] args) throws Exception {
+
+    if (args.length != 2) {
+      System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>");
+      return;
+    }
+
+    File dataFileDirectory = new File(args[0]);
+    if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
+      throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
+    }
+
+    File resultFile = new File(args[1]);
+
+    /* the knobs to turn */
+    int numFeatures = 20;
+    int numIterations = 5;
+    double learningRate = 0.0001;
+    double preventOverfitting = 0.002;
+    double randomNoise = 0.0001;
+
+    KDDCupFactorizablePreferences factorizablePreferences =
+        new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory));
+
+    Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations,
+        learningRate, preventOverfitting, randomNoise);
+
+    Factorization factorization = sgdFactorizer.factorize();
+
+    // Estimates are clamped to the preference range of the training data; look it up once.
+    float minPreference = factorizablePreferences.getMinPreference();
+    float maxPreference = factorizablePreferences.getMaxPreference();
+
+    log.info("Estimating validation preferences...");
+    int prefsProcessed = 0;
+    RunningAverage average = new FullRunningAverage();
+    for (Pair<PreferenceArray,long[]> validationPair
+        : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) {
+      for (Preference validationPref : validationPair.getFirst()) {
+        double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(),
+            minPreference, maxPreference);
+        double error = validationPref.getValue() - estimate;
+        average.addDatum(error * error);
+        prefsProcessed++;
+        if (prefsProcessed % 100000 == 0) {
+          log.info("Computed {} estimations", prefsProcessed);
+        }
+      }
+    }
+    log.info("Computed {} estimations, done.", prefsProcessed);
+
+    double rmse = Math.sqrt(average.getAverage());
+    log.info("RMSE {}", rmse);
+
+    log.info("Estimating test preferences...");
+    // try-with-resources: a failure while closing is recorded as suppressed instead of masking a write failure.
+    try (OutputStream out = new BufferedOutputStream(new FileOutputStream(resultFile))) {
+      for (Pair<PreferenceArray,long[]> testPair
+          : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
+        for (Preference testPref : testPair.getFirst()) {
+          double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(),
+              minPreference, maxPreference);
+          byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID());
+          out.write(result);
+        }
+      }
+    }
+    log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath());
+  }
+
+  /**
+   * Estimates a preference as the dot product of the user's and the item's feature vectors, clamped to
+   * {@code [minPreference, maxPreference]}.
+   *
+   * @throws NoSuchUserException propagated from {@link Factorization#getUserFeatures(long)}
+   * @throws NoSuchItemException propagated from {@link Factorization#getItemFeatures(long)}
+   */
+  static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference,
+      float maxPreference) throws NoSuchUserException, NoSuchItemException {
+    double[] userFeatures = factorization.getUserFeatures(userID);
+    double[] itemFeatures = factorization.getItemFeatures(itemID);
+    double estimate = 0;
+    for (int feature = 0; feature < userFeatures.length; feature++) {
+      estimate += userFeatures[feature] * itemFeatures[feature];
+    }
+    if (estimate < minPreference) {
+      estimate = minPreference;
+    } else if (estimate > maxPreference) {
+      estimate = maxPreference;
+    }
+    return estimate;
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
new file mode 100644
index 0000000..ce025a9
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.similarity.AbstractItemSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+
+/**
+ * {@link ItemSimilarity} whose value for a pair of items is the product of a content-based similarity
+ * ({@link TrackItemSimilarity}) and a collaborative-filtering similarity ({@link LogLikelihoodSimilarity}).
+ */
+final class HybridSimilarity extends AbstractItemSimilarity {
+
+  private final ItemSimilarity cfSimilarity;
+  private final ItemSimilarity contentSimilarity;
+
+  /**
+   * @param dataModel data the collaborative-filtering similarity is computed from
+   * @param dataFileDirectory directory passed to {@link TrackItemSimilarity}
+   * @throws IOException if building the content similarity fails
+   */
+  HybridSimilarity(DataModel dataModel, File dataFileDirectory) throws IOException {
+    super(dataModel);
+    cfSimilarity = new LogLikelihoodSimilarity(dataModel);
+    contentSimilarity = new TrackItemSimilarity(dataFileDirectory);
+  }
+
+  @Override
+  public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+    return contentSimilarity.itemSimilarity(itemID1, itemID2) * cfSimilarity.itemSimilarity(itemID1, itemID2);
+  }
+
+  @Override
+  public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+    double[] contentSimilarities = contentSimilarity.itemSimilarities(itemID1, itemID2s);
+    double[] cfSimilarities = cfSimilarity.itemSimilarities(itemID1, itemID2s);
+    // Write into a fresh array rather than mutating an array owned by a delegate.
+    double[] result = new double[contentSimilarities.length];
+    for (int i = 0; i < result.length; i++) {
+      result[i] = contentSimilarities[i] * cfSimilarities[i];
+    }
+    return result;
+  }
+
+  /** Refreshes both component similarities. */
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    cfSimilarity.refresh(alreadyRefreshed);
+    contentSimilarity.refresh(alreadyRefreshed);
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
new file mode 100644
index 0000000..50fd35e
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.TreeMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+
+final class Track2Callable implements Callable<UserResult> {
+
+ private static final Logger log = LoggerFactory.getLogger(Track2Callable.class);
+ private static final AtomicInteger COUNT = new AtomicInteger();
+
+ private final Recommender recommender;
+ private final PreferenceArray userTest;
+
+ Track2Callable(Recommender recommender, PreferenceArray userTest) {
+ this.recommender = recommender;
+ this.userTest = userTest;
+ }
+
+ @Override
+ public UserResult call() throws TasteException {
+
+ int testSize = userTest.length();
+ if (testSize != 6) {
+ throw new IllegalArgumentException("Expecting 6 items for user but got " + userTest);
+ }
+ long userID = userTest.get(0).getUserID();
+ TreeMap<Double,Long> estimateToItemID = new TreeMap<>(Collections.reverseOrder());
+
+ for (int i = 0; i < testSize; i++) {
+ long itemID = userTest.getItemID(i);
+ double estimate;
+ try {
+ estimate = recommender.estimatePreference(userID, itemID);
+ } catch (NoSuchItemException nsie) {
+ // OK in the sample data provided before the contest, should never happen otherwise
+ log.warn("Unknown item {}; OK unless this is the real contest data", itemID);
+ continue;
+ }
+
+ if (!Double.isNaN(estimate)) {
+ estimateToItemID.put(estimate, itemID);
+ }
+ }
+
+ Collection<Long> itemIDs = estimateToItemID.values();
+ List<Long> topThree = new ArrayList<>(itemIDs);
+ if (topThree.size() > 3) {
+ topThree = topThree.subList(0, 3);
+ } else if (topThree.size() < 3) {
+ log.warn("Unable to recommend three items for {}", userID);
+ // Some NaNs - just guess at the rest then
+ Collection<Long> newItemIDs = new HashSet<>(3);
+ newItemIDs.addAll(itemIDs);
+ int i = 0;
+ while (i < testSize && newItemIDs.size() < 3) {
+ newItemIDs.add(userTest.getItemID(i));
+ i++;
+ }
+ topThree = new ArrayList<>(newItemIDs);
+ }
+ if (topThree.size() != 3) {
+ throw new IllegalStateException();
+ }
+
+ boolean[] result = new boolean[testSize];
+ for (int i = 0; i < testSize; i++) {
+ result[i] = topThree.contains(userTest.getItemID(i));
+ }
+
+ if (COUNT.incrementAndGet() % 1000 == 0) {
+ log.info("Completed {} users", COUNT.get());
+ }
+
+ return new UserResult(userID, result);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
new file mode 100644
index 0000000..185a00d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefItemBasedRecommender;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+
+/**
+ * {@link Recommender} for KDD Cup track 2 that delegates every call to a
+ * {@link GenericBooleanPrefItemBasedRecommender} built on a {@link HybridSimilarity}.
+ */
+public final class Track2Recommender implements Recommender {
+
+  private final Recommender recommender;
+
+  /**
+   * @param dataModel training data
+   * @param dataFileDirectory directory passed on to {@link HybridSimilarity}
+   * @throws TasteException wrapping any {@link IOException} raised while building the similarity
+   */
+  public Track2Recommender(DataModel dataModel, File dataFileDirectory) throws TasteException {
+    // Change this to whatever you like!
+    ItemSimilarity similarity;
+    try {
+      similarity = new HybridSimilarity(dataModel, dataFileDirectory);
+    } catch (IOException ioe) {
+      throw new TasteException(ioe);
+    }
+    recommender = new GenericBooleanPrefItemBasedRecommender(dataModel, similarity);
+  }
+
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+    return recommender.recommend(userID, howMany);
+  }
+
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
+    return recommend(userID, howMany, null, includeKnownItems);
+  }
+
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
+    return recommender.recommend(userID, howMany, rescorer, false);
+  }
+
+  @Override
+  public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+      throws TasteException {
+    return recommender.recommend(userID, howMany, rescorer, includeKnownItems);
+  }
+
+  @Override
+  public float estimatePreference(long userID, long itemID) throws TasteException {
+    return recommender.estimatePreference(userID, itemID);
+  }
+
+  @Override
+  public void setPreference(long userID, long itemID, float value) throws TasteException {
+    recommender.setPreference(userID, itemID, value);
+  }
+
+  @Override
+  public void removePreference(long userID, long itemID) throws TasteException {
+    recommender.removePreference(userID, itemID);
+  }
+
+  @Override
+  public DataModel getDataModel() {
+    return recommender.getDataModel();
+  }
+
+  @Override
+  public void refresh(Collection<Refreshable> alreadyRefreshed) {
+    recommender.refresh(alreadyRefreshed);
+  }
+
+  @Override
+  public String toString() {
+    return "Track2Recommender[recommender:" + recommender + ']';
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
new file mode 100644
index 0000000..09ade5d
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+final class Track2RecommenderBuilder implements RecommenderBuilder {
+
+ @Override
+ public Recommender buildRecommender(DataModel dataModel) throws TasteException {
+ return new Track2Recommender(dataModel, ((KDDCupDataModel) dataModel).getDataFileDirectory());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
new file mode 100644
index 0000000..3cbb61c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
+import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedOutputStream;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+/**
+ * <p>Runs "track 2" of the KDD Cup competition using whatever recommender is inside {@link Track2Recommender}
+ * and attempts to output the result in the correct contest format.</p>
+ *
+ * <p>Run as: {@code Track2Runner [track 2 data file directory] [output file]}</p>
+ */
+public final class Track2Runner {
+
+ private static final Logger log = LoggerFactory.getLogger(Track2Runner.class);
+
+ private Track2Runner() {
+ }
+
+ public static void main(String[] args) throws Exception {
+
+ File dataFileDirectory = new File(args[0]);
+ if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
+ throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
+ }
+
+ long start = System.currentTimeMillis();
+
+ KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory));
+ Track2Recommender recommender = new Track2Recommender(model, dataFileDirectory);
+
+ long end = System.currentTimeMillis();
+ log.info("Loaded model in {}s", (end - start) / 1000);
+ start = end;
+
+ Collection<Track2Callable> callables = new ArrayList<>();
+ for (Pair<PreferenceArray,long[]> tests : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
+ PreferenceArray userTest = tests.getFirst();
+ callables.add(new Track2Callable(recommender, userTest));
+ }
+
+ int cores = Runtime.getRuntime().availableProcessors();
+ log.info("Running on {} cores", cores);
+ ExecutorService executor = Executors.newFixedThreadPool(cores);
+ List<Future<UserResult>> futures = executor.invokeAll(callables);
+ executor.shutdown();
+
+ end = System.currentTimeMillis();
+ log.info("Ran recommendations in {}s", (end - start) / 1000);
+ start = end;
+
+ try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))){
+ long lastUserID = Long.MIN_VALUE;
+ for (Future<UserResult> future : futures) {
+ UserResult result = future.get();
+ long userID = result.getUserID();
+ if (userID <= lastUserID) {
+ throw new IllegalStateException();
+ }
+ lastUserID = userID;
+ out.write(result.getResultBytes());
+ }
+ }
+
+ end = System.currentTimeMillis();
+ log.info("Wrote output in {}s", (end - start) / 1000);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
new file mode 100644
index 0000000..abd15f8
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
+import java.util.regex.Pattern;
+
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+
+final class TrackData {
+
+ private static final Pattern PIPE = Pattern.compile("\\|");
+ private static final String NO_VALUE = "None";
+ static final long NO_VALUE_ID = Long.MIN_VALUE;
+ private static final FastIDSet NO_GENRES = new FastIDSet();
+
+ private final long trackID;
+ private final long albumID;
+ private final long artistID;
+ private final FastIDSet genreIDs;
+
+ TrackData(CharSequence line) {
+ String[] tokens = PIPE.split(line);
+ trackID = Long.parseLong(tokens[0]);
+ albumID = parse(tokens[1]);
+ artistID = parse(tokens[2]);
+ if (tokens.length > 3) {
+ genreIDs = new FastIDSet(tokens.length - 3);
+ for (int i = 3; i < tokens.length; i++) {
+ genreIDs.add(Long.parseLong(tokens[i]));
+ }
+ } else {
+ genreIDs = NO_GENRES;
+ }
+ }
+
+ private static long parse(String value) {
+ return NO_VALUE.equals(value) ? NO_VALUE_ID : Long.parseLong(value);
+ }
+
+ public long getTrackID() {
+ return trackID;
+ }
+
+ public long getAlbumID() {
+ return albumID;
+ }
+
+ public long getArtistID() {
+ return artistID;
+ }
+
+ public FastIDSet getGenreIDs() {
+ return genreIDs;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
new file mode 100644
index 0000000..3012a84
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
@@ -0,0 +1,106 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.common.iterator.FileLineIterable;
+
+final class TrackItemSimilarity implements ItemSimilarity {
+
+ private final FastByIDMap<TrackData> trackData;
+
+ TrackItemSimilarity(File dataFileDirectory) throws IOException {
+ trackData = new FastByIDMap<>();
+ for (String line : new FileLineIterable(KDDCupDataModel.getTrackFile(dataFileDirectory))) {
+ TrackData trackDatum = new TrackData(line);
+ trackData.put(trackDatum.getTrackID(), trackDatum);
+ }
+ }
+
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) {
+ if (itemID1 == itemID2) {
+ return 1.0;
+ }
+ TrackData data1 = trackData.get(itemID1);
+ TrackData data2 = trackData.get(itemID2);
+ if (data1 == null || data2 == null) {
+ return 0.0;
+ }
+
+ // Arbitrarily decide that same album means "very similar"
+ if (data1.getAlbumID() != TrackData.NO_VALUE_ID && data1.getAlbumID() == data2.getAlbumID()) {
+ return 0.9;
+ }
+ // ... and same artist means "fairly similar"
+ if (data1.getArtistID() != TrackData.NO_VALUE_ID && data1.getArtistID() == data2.getArtistID()) {
+ return 0.7;
+ }
+
+ // Tanimoto coefficient similarity based on genre, but maximum value of 0.25
+ FastIDSet genres1 = data1.getGenreIDs();
+ FastIDSet genres2 = data2.getGenreIDs();
+ if (genres1 == null || genres2 == null) {
+ return 0.0;
+ }
+ int intersectionSize = genres1.intersectionSize(genres2);
+ if (intersectionSize == 0) {
+ return 0.0;
+ }
+ int unionSize = genres1.size() + genres2.size() - intersectionSize;
+ return intersectionSize / (4.0 * unionSize);
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) {
+ int length = itemID2s.length;
+ double[] result = new double[length];
+ for (int i = 0; i < length; i++) {
+ result[i] = itemSimilarity(itemID1, itemID2s[i]);
+ }
+ return result;
+ }
+
+ @Override
+ public long[] allSimilarItemIDs(long itemID) {
+ FastIDSet allSimilarItemIDs = new FastIDSet();
+ LongPrimitiveIterator allItemIDs = trackData.keySetIterator();
+ while (allItemIDs.hasNext()) {
+ long possiblySimilarItemID = allItemIDs.nextLong();
+ if (!Double.isNaN(itemSimilarity(itemID, possiblySimilarItemID))) {
+ allSimilarItemIDs.add(possiblySimilarItemID);
+ }
+ }
+ return allSimilarItemIDs.toArray();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // do nothing
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
new file mode 100644
index 0000000..e554d10
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.example.kddcup.track2;
+
/**
 * One user's answer: the user ID plus one ASCII {@code '1'} or {@code '0'} byte per boolean flag.
 */
final class UserResult {

  /** Exact number of {@code true} flags every result must contain. */
  private static final int REQUIRED_TRUE_COUNT = 3;

  private final long userID;
  private final byte[] resultBytes;

  /**
   * @param userID the user these flags belong to
   * @param result flags to encode; exactly three must be {@code true}
   * @throws IllegalStateException if the number of {@code true} flags is not exactly three
   */
  UserResult(long userID, boolean[] result) {
    this.userID = userID;

    int trueCount = 0;
    for (boolean b : result) {
      if (b) {
        trueCount++;
      }
    }
    if (trueCount != REQUIRED_TRUE_COUNT) {
      throw new IllegalStateException("User " + userID + " has " + trueCount
          + " true results; expected exactly " + REQUIRED_TRUE_COUNT);
    }

    resultBytes = new byte[result.length];
    for (int i = 0; i < result.length; i++) {
      resultBytes[i] = (byte) (result[i] ? '1' : '0');
    }
  }

  public long getUserID() {
    return userID;
  }

  /** Returns a copy of the encoded flags, so callers cannot alter this result. */
  public byte[] getResultBytes() {
    return resultBytes.clone();
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
new file mode 100644
index 0000000..22f122e
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
@@ -0,0 +1,140 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.example.als.netflix;
+
+import com.google.common.base.Preconditions;
+import org.apache.commons.io.Charsets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.common.iterator.FileLineIterable;
+import org.apache.mahout.common.iterator.FileLineIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.regex.Pattern;
+
+/** converts the raw files provided by netflix to an appropriate input format */
+public final class NetflixDatasetConverter {
+
+ private static final Logger log = LoggerFactory.getLogger(NetflixDatasetConverter.class);
+
+ private static final Pattern SEPARATOR = Pattern.compile(",");
+ private static final String MOVIE_DENOTER = ":";
+ private static final String TAB = "\t";
+ private static final String NEWLINE = "\n";
+
+ private NetflixDatasetConverter() {
+ }
+
+ public static void main(String[] args) throws IOException {
+
+ if (args.length != 4) {
+ System.err.println("Usage: NetflixDatasetConverter /path/to/training_set/ /path/to/qualifying.txt "
+ + "/path/to/judging.txt /path/to/destination");
+ return;
+ }
+
+ String trainingDataDir = args[0];
+ String qualifyingTxt = args[1];
+ String judgingTxt = args[2];
+ Path outputPath = new Path(args[3]);
+
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(outputPath.toUri(), conf);
+
+ Preconditions.checkArgument(trainingDataDir != null, "Training Data location needs to be specified");
+ log.info("Creating training set at {}/trainingSet/ratings.tsv ...", outputPath);
+ try (BufferedWriter writer =
+ new BufferedWriter(
+ new OutputStreamWriter(
+ fs.create(new Path(outputPath, "trainingSet/ratings.tsv")), Charsets.UTF_8))){
+
+ int ratingsProcessed = 0;
+ for (File movieRatings : new File(trainingDataDir).listFiles()) {
+ try (FileLineIterator lines = new FileLineIterator(movieRatings)) {
+ boolean firstLineRead = false;
+ String movieID = null;
+ while (lines.hasNext()) {
+ String line = lines.next();
+ if (firstLineRead) {
+ String[] tokens = SEPARATOR.split(line);
+ String userID = tokens[0];
+ String rating = tokens[1];
+ writer.write(userID + TAB + movieID + TAB + rating + NEWLINE);
+ ratingsProcessed++;
+ if (ratingsProcessed % 1000000 == 0) {
+ log.info("{} ratings processed...", ratingsProcessed);
+ }
+ } else {
+ movieID = line.replaceAll(MOVIE_DENOTER, "");
+ firstLineRead = true;
+ }
+ }
+ }
+
+ }
+ log.info("{} ratings processed. done.", ratingsProcessed);
+ }
+
+ log.info("Reading probes...");
+ List<Preference> probes = new ArrayList<>(2817131);
+ long currentMovieID = -1;
+ for (String line : new FileLineIterable(new File(qualifyingTxt))) {
+ if (line.contains(MOVIE_DENOTER)) {
+ currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
+ } else {
+ long userID = Long.parseLong(SEPARATOR.split(line)[0]);
+ probes.add(new GenericPreference(userID, currentMovieID, 0));
+ }
+ }
+ log.info("{} probes read...", probes.size());
+
+ log.info("Reading ratings, creating probe set at {}/probeSet/ratings.tsv ...", outputPath);
+ try (BufferedWriter writer =
+ new BufferedWriter(new OutputStreamWriter(
+ fs.create(new Path(outputPath, "probeSet/ratings.tsv")), Charsets.UTF_8))){
+ int ratingsProcessed = 0;
+ for (String line : new FileLineIterable(new File(judgingTxt))) {
+ if (line.contains(MOVIE_DENOTER)) {
+ currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
+ } else {
+ float rating = Float.parseFloat(SEPARATOR.split(line)[0]);
+ Preference pref = probes.get(ratingsProcessed);
+ Preconditions.checkState(pref.getItemID() == currentMovieID);
+ ratingsProcessed++;
+ writer.write(pref.getUserID() + TAB + pref.getItemID() + TAB + rating + NEWLINE);
+ if (ratingsProcessed % 1000000 == 0) {
+ log.info("{} ratings processed...", ratingsProcessed);
+ }
+ }
+ }
+ log.info("{} ratings processed. done.", ratingsProcessed);
+ }
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
new file mode 100644
index 0000000..8021d00
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute.example;
+
+import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.precompute.FileSimilarItemsWriter;
+import org.apache.mahout.cf.taste.impl.similarity.precompute.MultithreadedBatchItemSimilarities;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
+
+import java.io.File;
+
+/**
+ * Example that precomputes all item similarities of the Movielens1M dataset
+ *
+ * Usage: download movielens1M from http://www.grouplens.org/node/73 , unzip it and invoke this code with the path
+ * to the ratings.dat file as argument
+ *
+ */
+public final class BatchItemSimilaritiesGroupLens {
+
+ private BatchItemSimilaritiesGroupLens() {}
+
+ public static void main(String[] args) throws Exception {
+
+ if (args.length != 1) {
+ System.err.println("Need path to ratings.dat of the movielens1M dataset as argument!");
+ System.exit(-1);
+ }
+
+ File resultFile = new File(System.getProperty("java.io.tmpdir"), "similarities.csv");
+ if (resultFile.exists()) {
+ resultFile.delete();
+ }
+
+ DataModel dataModel = new GroupLensDataModel(new File(args[0]));
+ ItemBasedRecommender recommender = new GenericItemBasedRecommender(dataModel,
+ new LogLikelihoodSimilarity(dataModel));
+ BatchItemSimilarities batch = new MultithreadedBatchItemSimilarities(recommender, 5);
+
+ int numSimilarities = batch.computeItemSimilarities(Runtime.getRuntime().availableProcessors(), 1,
+ new FileSimilarItemsWriter(resultFile));
+
+ System.out.println("Computed " + numSimilarities + " similarities for " + dataModel.getNumItems() + " items "
+ + "and saved them to " + resultFile.getAbsolutePath());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
new file mode 100644
index 0000000..7ee9b17
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute.example;
+
+import com.google.common.io.Files;
+import com.google.common.io.InputSupplier;
+import com.google.common.io.Resources;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.net.URL;
+import java.util.regex.Pattern;
+import org.apache.commons.io.Charsets;
+import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
+import org.apache.mahout.common.iterator.FileLineIterable;
+
+public final class GroupLensDataModel extends FileDataModel {
+
+ private static final String COLON_DELIMTER = "::";
+ private static final Pattern COLON_DELIMITER_PATTERN = Pattern.compile(COLON_DELIMTER);
+
+ public GroupLensDataModel() throws IOException {
+ this(readResourceToTempFile("/org/apache/mahout/cf/taste/example/grouplens/ratings.dat"));
+ }
+
+ /**
+ * @param ratingsFile GroupLens ratings.dat file in its native format
+ * @throws IOException if an error occurs while reading or writing files
+ */
+ public GroupLensDataModel(File ratingsFile) throws IOException {
+ super(convertGLFile(ratingsFile));
+ }
+
+ private static File convertGLFile(File originalFile) throws IOException {
+ // Now translate the file; remove commas, then convert "::" delimiter to comma
+ File resultFile = new File(new File(System.getProperty("java.io.tmpdir")), "ratings.txt");
+ if (resultFile.exists()) {
+ resultFile.delete();
+ }
+ try (Writer writer = new OutputStreamWriter(new FileOutputStream(resultFile), Charsets.UTF_8)){
+ for (String line : new FileLineIterable(originalFile, false)) {
+ int lastDelimiterStart = line.lastIndexOf(COLON_DELIMTER);
+ if (lastDelimiterStart < 0) {
+ throw new IOException("Unexpected input format on line: " + line);
+ }
+ String subLine = line.substring(0, lastDelimiterStart);
+ String convertedLine = COLON_DELIMITER_PATTERN.matcher(subLine).replaceAll(",");
+ writer.write(convertedLine);
+ writer.write('\n');
+ }
+ } catch (IOException ioe) {
+ resultFile.delete();
+ throw ioe;
+ }
+ return resultFile;
+ }
+
+ public static File readResourceToTempFile(String resourceName) throws IOException {
+ InputSupplier<? extends InputStream> inSupplier;
+ try {
+ URL resourceURL = Resources.getResource(GroupLensDataModel.class, resourceName);
+ inSupplier = Resources.newInputStreamSupplier(resourceURL);
+ } catch (IllegalArgumentException iae) {
+ File resourceFile = new File("src/main/java" + resourceName);
+ inSupplier = Files.newInputStreamSupplier(resourceFile);
+ }
+ File tempFile = File.createTempFile("taste", null);
+ tempFile.deleteOnExit();
+ Files.copy(inSupplier, tempFile);
+ return tempFile;
+ }
+
+ @Override
+ public String toString() {
+ return "GroupLensDataModel";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
new file mode 100644
index 0000000..5cec51c
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import com.google.common.collect.ConcurrentHashMultiset;
+import com.google.common.collect.Multiset;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.commons.io.Charsets;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+import java.text.SimpleDateFormat;
+import java.util.Collection;
+import java.util.Date;
+import java.util.Locale;
+import java.util.Random;
+
+public final class NewsgroupHelper {
+
+ private static final SimpleDateFormat[] DATE_FORMATS = {
+ new SimpleDateFormat("", Locale.ENGLISH),
+ new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
+ new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
+ };
+
+ public static final int FEATURES = 10000;
+ // 1997-01-15 00:01:00 GMT
+ private static final long DATE_REFERENCE = 853286460;
+ private static final long MONTH = 30 * 24 * 3600;
+ private static final long WEEK = 7 * 24 * 3600;
+
+ private final Random rand = RandomUtils.getRandom();
+ private final Analyzer analyzer = new StandardAnalyzer();
+ private final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
+ private final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
+
+ public FeatureVectorEncoder getEncoder() {
+ return encoder;
+ }
+
+ public FeatureVectorEncoder getBias() {
+ return bias;
+ }
+
+ public Random getRandom() {
+ return rand;
+ }
+
+ public Vector encodeFeatureVector(File file, int actual, int leakType, Multiset<String> overallCounts)
+ throws IOException {
+ long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK * rand.nextDouble()));
+ Multiset<String> words = ConcurrentHashMultiset.create();
+
+ try (BufferedReader reader = Files.newReader(file, Charsets.UTF_8)) {
+ String line = reader.readLine();
+ Reader dateString = new StringReader(DATE_FORMATS[leakType % 3].format(new Date(date)));
+ countWords(analyzer, words, dateString, overallCounts);
+ while (line != null && !line.isEmpty()) {
+ boolean countHeader = (
+ line.startsWith("From:") || line.startsWith("Subject:")
+ || line.startsWith("Keywords:") || line.startsWith("Summary:")) && leakType < 6;
+ do {
+ Reader in = new StringReader(line);
+ if (countHeader) {
+ countWords(analyzer, words, in, overallCounts);
+ }
+ line = reader.readLine();
+ } while (line != null && line.startsWith(" "));
+ }
+ if (leakType < 3) {
+ countWords(analyzer, words, reader, overallCounts);
+ }
+ }
+
+ Vector v = new RandomAccessSparseVector(FEATURES);
+ bias.addToVector("", 1, v);
+ for (String word : words.elementSet()) {
+ encoder.addToVector(word, Math.log1p(words.count(word)), v);
+ }
+
+ return v;
+ }
+
+ public static void countWords(Analyzer analyzer,
+ Collection<String> words,
+ Reader in,
+ Multiset<String> overallCounts) throws IOException {
+ TokenStream ts = analyzer.tokenStream("text", in);
+ ts.addAttribute(CharTermAttribute.class);
+ ts.reset();
+ while (ts.incrementToken()) {
+ String s = ts.getAttribute(CharTermAttribute.class).toString();
+ words.add(s);
+ }
+ overallCounts.addAll(words);
+ ts.end();
+ Closeables.close(ts, true);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
new file mode 100644
index 0000000..16e9d80
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.email;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Locale;
+import java.util.regex.Pattern;
+
+/**
+ * Convert the labels created by the {@link org.apache.mahout.utils.email.MailProcessor} to one consumable
+ * by the classifiers
+ */
+public class PrepEmailMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
+
+ private static final Pattern DASH_DOT = Pattern.compile("-|\\.");
+ private static final Pattern SLASH = Pattern.compile("\\/");
+
+ private boolean useListName = false; //if true, use the project name and the list name in label creation
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ useListName = Boolean.parseBoolean(context.getConfiguration().get(PrepEmailVectorsDriver.USE_LIST_NAME));
+ }
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ String input = key.toString();
+ ///Example: /cocoon.apache.org/dev/200307.gz/001401c3414f$8394e160$1e01a8c0@WRPO
+ String[] splits = SLASH.split(input);
+ //we need the first two splits;
+ if (splits.length >= 3) {
+ StringBuilder bldr = new StringBuilder();
+ bldr.append(escape(splits[1]));
+ if (useListName) {
+ bldr.append('_').append(escape(splits[2]));
+ }
+ context.write(new Text(bldr.toString()), value);
+ }
+
+ }
+
+ private static String escape(CharSequence value) {
+ return DASH_DOT.matcher(value).replaceAll("_").toLowerCase(Locale.ENGLISH);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
new file mode 100644
index 0000000..da6e613
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.email;
+
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+public class PrepEmailReducer extends Reducer<Text, VectorWritable, Text, VectorWritable> {
+
+ private long maxItemsPerLabel = 10000;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ maxItemsPerLabel = Long.parseLong(context.getConfiguration().get(PrepEmailVectorsDriver.ITEMS_PER_CLASS));
+ }
+
+ @Override
+ protected void reduce(Text key, Iterable<VectorWritable> values, Context context)
+ throws IOException, InterruptedException {
+ //TODO: support randomization? Likely not needed due to the SplitInput utility which does random selection
+ long i = 0;
+ Iterator<VectorWritable> iterator = values.iterator();
+ while (i < maxItemsPerLabel && iterator.hasNext()) {
+ context.write(key, iterator.next());
+ i++;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
new file mode 100644
index 0000000..8fba739
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.email;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.VectorWritable;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Convert the labels generated by {@link org.apache.mahout.text.SequenceFilesFromMailArchives} and
+ * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} to ones consumable by the classifiers. We do this
+ * here b/c if it is done in the creation of sparse vectors, the Reducer collapses all the vectors.
+ */
+public class PrepEmailVectorsDriver extends AbstractJob {
+
+ public static final String ITEMS_PER_CLASS = "itemsPerClass";
+ public static final String USE_LIST_NAME = "USE_LIST_NAME";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new PrepEmailVectorsDriver(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption("maxItemsPerLabel", "mipl", "The maximum number of items per label. Can be useful for making the "
+ + "training sets the same size", String.valueOf(100000));
+ addOption(buildOption("useListName", "ul", "Use the name of the list as part of the label. If not set, then "
+ + "just use the project name", false, false, "false"));
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ Job convertJob = prepareJob(input, output, SequenceFileInputFormat.class, PrepEmailMapper.class, Text.class,
+ VectorWritable.class, PrepEmailReducer.class, Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+ convertJob.getConfiguration().set(ITEMS_PER_CLASS, getOption("maxItemsPerLabel"));
+ convertJob.getConfiguration().set(USE_LIST_NAME, String.valueOf(hasOption("useListName")));
+
+ boolean succeeded = convertJob.waitForCompletion(true);
+ return succeeded ? 0 : -1;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/02f75f99/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
new file mode 100644
index 0000000..9c0ef56
--- /dev/null
+++ b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
@@ -0,0 +1,277 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
import com.google.common.io.Resources;
import org.apache.commons.io.Charsets;
import org.apache.mahout.math.Matrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
+
+/**
+ * This class implements a sample program that uses a pre-tagged training data
+ * set to train an HMM model as a POS tagger. The training data is automatically
+ * downloaded from the following URL:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt It then
+ * trains an HMM Model using supervised learning and tests the model on the
+ * following test data set:
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further
+ * details regarding the data files can be found at
+ * http://flexcrfs.sourceforge.net/#Case_Study
+ */
+public final class PosTagger {
+
+ private static final Logger log = LoggerFactory.getLogger(PosTagger.class);
+
+ private static final Pattern SPACE = Pattern.compile(" ");
+ private static final Pattern SPACES = Pattern.compile("[ ]+");
+
+ /**
+ * No public constructors for utility classes.
+ */
+ private PosTagger() {
+ // nothing to do here really.
+ }
+
+ /**
+ * Model trained in the example.
+ */
+ private static HmmModel taggingModel;
+
+ /**
+ * Map for storing the IDs for the POS tags (hidden states)
+ */
+ private static Map<String, Integer> tagIDs;
+
+ /**
+ * Counter for the next assigned POS tag ID The value of 0 is reserved for
+ * "unknown POS tag"
+ */
+ private static int nextTagId;
+
+ /**
+ * Map for storing the IDs for observed words (observed states)
+ */
+ private static Map<String, Integer> wordIDs;
+
+ /**
+ * Counter for the next assigned word ID The value of 0 is reserved for
+ * "unknown word"
+ */
+ private static int nextWordId = 1; // 0 is reserved for "unknown word"
+
+ /**
+ * Used for storing a list of POS tags of read sentences.
+ */
+ private static List<int[]> hiddenSequences;
+
+ /**
+ * Used for storing a list of word tags of read sentences.
+ */
+ private static List<int[]> observedSequences;
+
+ /**
+ * number of read lines
+ */
+ private static int readLines;
+
+ /**
+ * Given an URL, this function fetches the data file, parses it, assigns POS
+ * Tag/word IDs and fills the hiddenSequences/observedSequences lists with
+ * data from those files. The data is expected to be in the following format
+ * (one word per line): word pos-tag np-tag sentences are closed with the .
+ * pos tag
+ *
+ * @param url Where the data file is stored
+ * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
+ * training data, not needed for test data)
+ * @throws IOException in case data file cannot be read.
+ */
+ private static void readFromURL(String url, boolean assignIDs) throws IOException {
+ // initialize the data structure
+ hiddenSequences = new LinkedList<>();
+ observedSequences = new LinkedList<>();
+ readLines = 0;
+
+ // now read line by line of the input file
+ List<Integer> observedSequence = new LinkedList<>();
+ List<Integer> hiddenSequence = new LinkedList<>();
+
+ for (String line :Resources.readLines(new URL(url), Charsets.UTF_8)) {
+ if (line.isEmpty()) {
+ // new sentence starts
+ int[] observedSequenceArray = new int[observedSequence.size()];
+ int[] hiddenSequenceArray = new int[hiddenSequence.size()];
+ for (int i = 0; i < observedSequence.size(); ++i) {
+ observedSequenceArray[i] = observedSequence.get(i);
+ hiddenSequenceArray[i] = hiddenSequence.get(i);
+ }
+ // now register those arrays
+ hiddenSequences.add(hiddenSequenceArray);
+ observedSequences.add(observedSequenceArray);
+ // and reset the linked lists
+ observedSequence.clear();
+ hiddenSequence.clear();
+ continue;
+ }
+ readLines++;
+ // we expect the format [word] [POS tag] [NP tag]
+ String[] tags = SPACE.split(line);
+ // when analyzing the training set, assign IDs
+ if (assignIDs) {
+ if (!wordIDs.containsKey(tags[0])) {
+ wordIDs.put(tags[0], nextWordId++);
+ }
+ if (!tagIDs.containsKey(tags[1])) {
+ tagIDs.put(tags[1], nextTagId++);
+ }
+ }
+ // determine the IDs
+ Integer wordID = wordIDs.get(tags[0]);
+ Integer tagID = tagIDs.get(tags[1]);
+ // now construct the current sequence
+ if (wordID == null) {
+ observedSequence.add(0);
+ } else {
+ observedSequence.add(wordID);
+ }
+
+ if (tagID == null) {
+ hiddenSequence.add(0);
+ } else {
+ hiddenSequence.add(tagID);
+ }
+ }
+
+ // if there is still something in the pipe, register it
+ if (!observedSequence.isEmpty()) {
+ int[] observedSequenceArray = new int[observedSequence.size()];
+ int[] hiddenSequenceArray = new int[hiddenSequence.size()];
+ for (int i = 0; i < observedSequence.size(); ++i) {
+ observedSequenceArray[i] = observedSequence.get(i);
+ hiddenSequenceArray[i] = hiddenSequence.get(i);
+ }
+ // now register those arrays
+ hiddenSequences.add(hiddenSequenceArray);
+ observedSequences.add(observedSequenceArray);
+ }
+ }
+
+ private static void trainModel(String trainingURL) throws IOException {
+ tagIDs = new HashMap<>(44); // we expect 44 distinct tags
+ wordIDs = new HashMap<>(19122); // we expect 19122
+ // distinct words
+ log.info("Reading and parsing training data file from URL: {}", trainingURL);
+ long start = System.currentTimeMillis();
+ readFromURL(trainingURL, true);
+ long end = System.currentTimeMillis();
+ double duration = (end - start) / 1000.0;
+ log.info("Parsing done in {} seconds!", duration);
+ log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
+ readLines, hiddenSequences.size(), nextWordId - 1, nextTagId - 1);
+ start = System.currentTimeMillis();
+ taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
+ hiddenSequences, observedSequences, 0.05);
+ // we have to adjust the model a bit,
+ // since we assume a higher probability that a given unknown word is NNP
+ // than anything else
+ Matrix emissions = taggingModel.getEmissionMatrix();
+ for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) {
+ emissions.setQuick(i, 0, 0.1 / taggingModel.getNrOfHiddenStates());
+ }
+ int nnptag = tagIDs.get("NNP");
+ emissions.setQuick(nnptag, 0, 1 / (double) taggingModel.getNrOfHiddenStates());
+ // re-normalize the emission probabilities
+ HmmUtils.normalizeModel(taggingModel);
+ // now register the names
+ taggingModel.registerHiddenStateNames(tagIDs);
+ taggingModel.registerOutputStateNames(wordIDs);
+ end = System.currentTimeMillis();
+ duration = (end - start) / 1000.0;
+ log.info("Trained HMM models in {} seconds!", duration);
+ }
+
+ private static void testModel(String testingURL) throws IOException {
+ log.info("Reading and parsing test data file from URL: {}", testingURL);
+ long start = System.currentTimeMillis();
+ readFromURL(testingURL, false);
+ long end = System.currentTimeMillis();
+ double duration = (end - start) / 1000.0;
+ log.info("Parsing done in {} seconds!", duration);
+ log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());
+
+ start = System.currentTimeMillis();
+ int errorCount = 0;
+ int totalCount = 0;
+ for (int i = 0; i < observedSequences.size(); ++i) {
+ // fetch the viterbi path as the POS tag for this observed sequence
+ int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false);
+ // compare with the expected
+ int[] posExpected = hiddenSequences.get(i);
+ for (int j = 0; j < posExpected.length; ++j) {
+ totalCount++;
+ if (posEstimate[j] != posExpected[j]) {
+ errorCount++;
+ }
+ }
+ }
+ end = System.currentTimeMillis();
+ duration = (end - start) / 1000.0;
+ log.info("POS tagged test file in {} seconds!", duration);
+ double errorRate = (double) errorCount / totalCount;
+ log.info("Tagged the test file with an error rate of: {}", errorRate);
+ }
+
+ private static List<String> tagSentence(String sentence) {
+ // first, we need to isolate all punctuation characters, so that they
+ // can be recognized
+ sentence = sentence.replaceAll("[,.!?:;\"]", " $0 ");
+ sentence = sentence.replaceAll("''", " '' ");
+ // now we tokenize the sentence
+ String[] tokens = SPACES.split(sentence);
+ // now generate the observed sequence
+ int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0);
+ // POS tag this observedSequence
+ int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false);
+ // and now decode the tag names
+ return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null);
+ }
+
+ public static void main(String[] args) throws IOException {
+ // generate the model from URL
+ trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
+ testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
+ // tag an exemplary sentence
+ String test = "McDonalds is a huge company with many employees .";
+ String[] testWords = SPACE.split(test);
+ List<String> posTags = tagSentence(test);
+ for (int i = 0; i < posTags.size(); ++i) {
+ log.info("{}[{}]", testWords[i], posTags.get(i));
+ }
+ }
+
+}