You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by ra...@apache.org on 2018/06/28 14:55:14 UTC
[46/51] [partial] mahout git commit: NO-JIRA Clean up MR refactor
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
deleted file mode 100644
index a99d54c..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/ParallelArraysSGDFactorizer.java
+++ /dev/null
@@ -1,265 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
-
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.impl.common.RunningAverage;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.common.RandomUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.Collection;
-import java.util.Random;
-
-/**
- * {@link Factorizer} based on Simon Funk's famous article <a href="http://sifter.org/~simon/journal/20061211.html">
- * "Netflix Update: Try this at home"</a>.
- *
- * Attempts to be as memory efficient as possible, only iterating once through the
- * {@link FactorizablePreferences} or {@link DataModel} while copying everything to primitive arrays.
- * Learning works in place on these datastructures after that.
- */
-public class ParallelArraysSGDFactorizer implements Factorizer {
-
- public static final double DEFAULT_LEARNING_RATE = 0.005;
- public static final double DEFAULT_PREVENT_OVERFITTING = 0.02;
- public static final double DEFAULT_RANDOM_NOISE = 0.005;
-
- private final int numFeatures;
- private final int numIterations;
- private final float minPreference;
- private final float maxPreference;
-
- private final Random random;
- private final double learningRate;
- private final double preventOverfitting;
-
- private final FastByIDMap<Integer> userIDMapping;
- private final FastByIDMap<Integer> itemIDMapping;
-
- private final double[][] userFeatures;
- private final double[][] itemFeatures;
-
- private final int[] userIndexes;
- private final int[] itemIndexes;
- private final float[] values;
-
- private final double defaultValue;
- private final double interval;
- private final double[] cachedEstimates;
-
-
- private static final Logger log = LoggerFactory.getLogger(ParallelArraysSGDFactorizer.class);
-
- public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) {
- this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, DEFAULT_LEARNING_RATE,
- DEFAULT_PREVENT_OVERFITTING, DEFAULT_RANDOM_NOISE);
- }
-
- public ParallelArraysSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations, double learningRate,
- double preventOverfitting, double randomNoise) {
- this(new DataModelFactorizablePreferences(dataModel), numFeatures, numIterations, learningRate, preventOverfitting,
- randomNoise);
- }
-
- public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePrefs, int numFeatures, int numIterations) {
- this(factorizablePrefs, numFeatures, numIterations, DEFAULT_LEARNING_RATE, DEFAULT_PREVENT_OVERFITTING,
- DEFAULT_RANDOM_NOISE);
- }
-
- public ParallelArraysSGDFactorizer(FactorizablePreferences factorizablePreferences, int numFeatures,
- int numIterations, double learningRate, double preventOverfitting, double randomNoise) {
-
- this.numFeatures = numFeatures;
- this.numIterations = numIterations;
- minPreference = factorizablePreferences.getMinPreference();
- maxPreference = factorizablePreferences.getMaxPreference();
-
- this.random = RandomUtils.getRandom();
- this.learningRate = learningRate;
- this.preventOverfitting = preventOverfitting;
-
- int numUsers = factorizablePreferences.numUsers();
- int numItems = factorizablePreferences.numItems();
- int numPrefs = factorizablePreferences.numPreferences();
-
- log.info("Mapping {} users...", numUsers);
- userIDMapping = new FastByIDMap<>(numUsers);
- int index = 0;
- LongPrimitiveIterator userIterator = factorizablePreferences.getUserIDs();
- while (userIterator.hasNext()) {
- userIDMapping.put(userIterator.nextLong(), index++);
- }
-
- log.info("Mapping {} items", numItems);
- itemIDMapping = new FastByIDMap<>(numItems);
- index = 0;
- LongPrimitiveIterator itemIterator = factorizablePreferences.getItemIDs();
- while (itemIterator.hasNext()) {
- itemIDMapping.put(itemIterator.nextLong(), index++);
- }
-
- this.userIndexes = new int[numPrefs];
- this.itemIndexes = new int[numPrefs];
- this.values = new float[numPrefs];
- this.cachedEstimates = new double[numPrefs];
-
- index = 0;
- log.info("Loading {} preferences into memory", numPrefs);
- RunningAverage average = new FullRunningAverage();
- for (Preference preference : factorizablePreferences.getPreferences()) {
- userIndexes[index] = userIDMapping.get(preference.getUserID());
- itemIndexes[index] = itemIDMapping.get(preference.getItemID());
- values[index] = preference.getValue();
- cachedEstimates[index] = 0;
-
- average.addDatum(preference.getValue());
-
- index++;
- if (index % 1000000 == 0) {
- log.info("Processed {} preferences", index);
- }
- }
- log.info("Processed {} preferences, done.", index);
-
- double averagePreference = average.getAverage();
- log.info("Average preference value is {}", averagePreference);
-
- double prefInterval = factorizablePreferences.getMaxPreference() - factorizablePreferences.getMinPreference();
- defaultValue = Math.sqrt((averagePreference - prefInterval * 0.1) / numFeatures);
- interval = prefInterval * 0.1 / numFeatures;
-
- userFeatures = new double[numUsers][numFeatures];
- itemFeatures = new double[numItems][numFeatures];
-
- log.info("Initializing feature vectors...");
- for (int feature = 0; feature < numFeatures; feature++) {
- for (int userIndex = 0; userIndex < numUsers; userIndex++) {
- userFeatures[userIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
- }
- for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
- itemFeatures[itemIndex][feature] = defaultValue + (random.nextDouble() - 0.5) * interval * randomNoise;
- }
- }
- }
-
- @Override
- public Factorization factorize() throws TasteException {
- for (int feature = 0; feature < numFeatures; feature++) {
- log.info("Shuffling preferences...");
- shufflePreferences();
- log.info("Starting training of feature {} ...", feature);
- for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
- if (currentIteration == numIterations - 1) {
- double rmse = trainingIterationWithRmse(feature);
- log.info("Finished training feature {} with RMSE {}", feature, rmse);
- } else {
- trainingIteration(feature);
- }
- }
- if (feature < numFeatures - 1) {
- log.info("Updating cache...");
- for (int index = 0; index < userIndexes.length; index++) {
- cachedEstimates[index] = estimate(userIndexes[index], itemIndexes[index], feature, cachedEstimates[index],
- false);
- }
- }
- }
- log.info("Factorization done");
- return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
- }
-
- private void trainingIteration(int feature) {
- for (int index = 0; index < userIndexes.length; index++) {
- train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
- }
- }
-
- private double trainingIterationWithRmse(int feature) {
- double rmse = 0.0;
- for (int index = 0; index < userIndexes.length; index++) {
- double error = train(userIndexes[index], itemIndexes[index], feature, values[index], cachedEstimates[index]);
- rmse += error * error;
- }
- return Math.sqrt(rmse / userIndexes.length);
- }
-
- private double estimate(int userIndex, int itemIndex, int feature, double cachedEstimate, boolean trailing) {
- double sum = cachedEstimate;
- sum += userFeatures[userIndex][feature] * itemFeatures[itemIndex][feature];
- if (trailing) {
- sum += (numFeatures - feature - 1) * (defaultValue + interval) * (defaultValue + interval);
- if (sum > maxPreference) {
- sum = maxPreference;
- } else if (sum < minPreference) {
- sum = minPreference;
- }
- }
- return sum;
- }
-
- public double train(int userIndex, int itemIndex, int feature, double original, double cachedEstimate) {
- double error = original - estimate(userIndex, itemIndex, feature, cachedEstimate, true);
- double[] userVector = userFeatures[userIndex];
- double[] itemVector = itemFeatures[itemIndex];
-
- userVector[feature] += learningRate * (error * itemVector[feature] - preventOverfitting * userVector[feature]);
- itemVector[feature] += learningRate * (error * userVector[feature] - preventOverfitting * itemVector[feature]);
-
- return error;
- }
-
- protected void shufflePreferences() {
- /* Durstenfeld shuffle */
- for (int currentPos = userIndexes.length - 1; currentPos > 0; currentPos--) {
- int swapPos = random.nextInt(currentPos + 1);
- swapPreferences(currentPos, swapPos);
- }
- }
-
- private void swapPreferences(int posA, int posB) {
- int tmpUserIndex = userIndexes[posA];
- int tmpItemIndex = itemIndexes[posA];
- float tmpValue = values[posA];
- double tmpEstimate = cachedEstimates[posA];
-
- userIndexes[posA] = userIndexes[posB];
- itemIndexes[posA] = itemIndexes[posB];
- values[posA] = values[posB];
- cachedEstimates[posA] = cachedEstimates[posB];
-
- userIndexes[posB] = tmpUserIndex;
- itemIndexes[posB] = tmpItemIndex;
- values[posB] = tmpValue;
- cachedEstimates[posB] = tmpEstimate;
- }
-
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- // do nothing
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
deleted file mode 100644
index 5cce02d..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track1/svd/Track1SVDRunner.java
+++ /dev/null
@@ -1,141 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track1.svd;
-
-import java.io.BufferedOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStream;
-
-import com.google.common.io.Closeables;
-import org.apache.mahout.cf.taste.common.NoSuchItemException;
-import org.apache.mahout.cf.taste.common.NoSuchUserException;
-import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.example.kddcup.track1.EstimateConverter;
-import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
-import org.apache.mahout.cf.taste.impl.common.RunningAverage;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorization;
-import org.apache.mahout.cf.taste.impl.recommender.svd.Factorizer;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-/**
- * run an SVD factorization of the KDD track1 data.
- *
- * needs at least 6-7GB of memory, tested with -Xms6700M -Xmx6700M
- *
- */
-public final class Track1SVDRunner {
-
- private static final Logger log = LoggerFactory.getLogger(Track1SVDRunner.class);
-
- private Track1SVDRunner() {
- }
-
- public static void main(String[] args) throws Exception {
-
- if (args.length != 2) {
- System.err.println("Necessary arguments: <kddDataFileDirectory> <resultFile>");
- return;
- }
-
- File dataFileDirectory = new File(args[0]);
- if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
- throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
- }
-
- File resultFile = new File(args[1]);
-
- /* the knobs to turn */
- int numFeatures = 20;
- int numIterations = 5;
- double learningRate = 0.0001;
- double preventOverfitting = 0.002;
- double randomNoise = 0.0001;
-
-
- KDDCupFactorizablePreferences factorizablePreferences =
- new KDDCupFactorizablePreferences(KDDCupDataModel.getTrainingFile(dataFileDirectory));
-
- Factorizer sgdFactorizer = new ParallelArraysSGDFactorizer(factorizablePreferences, numFeatures, numIterations,
- learningRate, preventOverfitting, randomNoise);
-
- Factorization factorization = sgdFactorizer.factorize();
-
- log.info("Estimating validation preferences...");
- int prefsProcessed = 0;
- RunningAverage average = new FullRunningAverage();
- for (Pair<PreferenceArray,long[]> validationPair
- : new DataFileIterable(KDDCupDataModel.getValidationFile(dataFileDirectory))) {
- for (Preference validationPref : validationPair.getFirst()) {
- double estimate = estimatePreference(factorization, validationPref.getUserID(), validationPref.getItemID(),
- factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
- double error = validationPref.getValue() - estimate;
- average.addDatum(error * error);
- prefsProcessed++;
- if (prefsProcessed % 100000 == 0) {
- log.info("Computed {} estimations", prefsProcessed);
- }
- }
- }
- log.info("Computed {} estimations, done.", prefsProcessed);
-
- double rmse = Math.sqrt(average.getAverage());
- log.info("RMSE {}", rmse);
-
- log.info("Estimating test preferences...");
- OutputStream out = null;
- try {
- out = new BufferedOutputStream(new FileOutputStream(resultFile));
-
- for (Pair<PreferenceArray,long[]> testPair
- : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
- for (Preference testPref : testPair.getFirst()) {
- double estimate = estimatePreference(factorization, testPref.getUserID(), testPref.getItemID(),
- factorizablePreferences.getMinPreference(), factorizablePreferences.getMaxPreference());
- byte result = EstimateConverter.convert(estimate, testPref.getUserID(), testPref.getItemID());
- out.write(result);
- }
- }
- } finally {
- Closeables.close(out, false);
- }
- log.info("wrote estimates to {}, done.", resultFile.getAbsolutePath());
- }
-
- static double estimatePreference(Factorization factorization, long userID, long itemID, float minPreference,
- float maxPreference) throws NoSuchUserException, NoSuchItemException {
- double[] userFeatures = factorization.getUserFeatures(userID);
- double[] itemFeatures = factorization.getItemFeatures(itemID);
- double estimate = 0;
- for (int feature = 0; feature < userFeatures.length; feature++) {
- estimate += userFeatures[feature] * itemFeatures[feature];
- }
- if (estimate < minPreference) {
- estimate = minPreference;
- } else if (estimate > maxPreference) {
- estimate = maxPreference;
- }
- return estimate;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
deleted file mode 100644
index ce025a9..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/HybridSimilarity.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.similarity.AbstractItemSimilarity;
-import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-
-final class HybridSimilarity extends AbstractItemSimilarity {
-
- private final ItemSimilarity cfSimilarity;
- private final ItemSimilarity contentSimilarity;
-
- HybridSimilarity(DataModel dataModel, File dataFileDirectory) throws IOException {
- super(dataModel);
- cfSimilarity = new LogLikelihoodSimilarity(dataModel);
- contentSimilarity = new TrackItemSimilarity(dataFileDirectory);
- }
-
- @Override
- public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
- return contentSimilarity.itemSimilarity(itemID1, itemID2) * cfSimilarity.itemSimilarity(itemID1, itemID2);
- }
-
- @Override
- public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
- double[] result = contentSimilarity.itemSimilarities(itemID1, itemID2s);
- double[] multipliers = cfSimilarity.itemSimilarities(itemID1, itemID2s);
- for (int i = 0; i < result.length; i++) {
- result[i] *= multipliers[i];
- }
- return result;
- }
-
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- cfSimilarity.refresh(alreadyRefreshed);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
deleted file mode 100644
index 50fd35e..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Callable.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-import org.apache.mahout.cf.taste.common.NoSuchItemException;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.TreeMap;
-import java.util.concurrent.Callable;
-import java.util.concurrent.atomic.AtomicInteger;
-
-final class Track2Callable implements Callable<UserResult> {
-
- private static final Logger log = LoggerFactory.getLogger(Track2Callable.class);
- private static final AtomicInteger COUNT = new AtomicInteger();
-
- private final Recommender recommender;
- private final PreferenceArray userTest;
-
- Track2Callable(Recommender recommender, PreferenceArray userTest) {
- this.recommender = recommender;
- this.userTest = userTest;
- }
-
- @Override
- public UserResult call() throws TasteException {
-
- int testSize = userTest.length();
- if (testSize != 6) {
- throw new IllegalArgumentException("Expecting 6 items for user but got " + userTest);
- }
- long userID = userTest.get(0).getUserID();
- TreeMap<Double,Long> estimateToItemID = new TreeMap<>(Collections.reverseOrder());
-
- for (int i = 0; i < testSize; i++) {
- long itemID = userTest.getItemID(i);
- double estimate;
- try {
- estimate = recommender.estimatePreference(userID, itemID);
- } catch (NoSuchItemException nsie) {
- // OK in the sample data provided before the contest, should never happen otherwise
- log.warn("Unknown item {}; OK unless this is the real contest data", itemID);
- continue;
- }
-
- if (!Double.isNaN(estimate)) {
- estimateToItemID.put(estimate, itemID);
- }
- }
-
- Collection<Long> itemIDs = estimateToItemID.values();
- List<Long> topThree = new ArrayList<>(itemIDs);
- if (topThree.size() > 3) {
- topThree = topThree.subList(0, 3);
- } else if (topThree.size() < 3) {
- log.warn("Unable to recommend three items for {}", userID);
- // Some NaNs - just guess at the rest then
- Collection<Long> newItemIDs = new HashSet<>(3);
- newItemIDs.addAll(itemIDs);
- int i = 0;
- while (i < testSize && newItemIDs.size() < 3) {
- newItemIDs.add(userTest.getItemID(i));
- i++;
- }
- topThree = new ArrayList<>(newItemIDs);
- }
- if (topThree.size() != 3) {
- throw new IllegalStateException();
- }
-
- boolean[] result = new boolean[testSize];
- for (int i = 0; i < testSize; i++) {
- result[i] = topThree.contains(userTest.getItemID(i));
- }
-
- if (COUNT.incrementAndGet() % 1000 == 0) {
- log.info("Completed {} users", COUNT.get());
- }
-
- return new UserResult(userID, result);
- }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
deleted file mode 100644
index 185a00d..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Recommender.java
+++ /dev/null
@@ -1,100 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-import java.util.List;
-
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefItemBasedRecommender;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.IDRescorer;
-import org.apache.mahout.cf.taste.recommender.RecommendedItem;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-
-public final class Track2Recommender implements Recommender {
-
- private final Recommender recommender;
-
- public Track2Recommender(DataModel dataModel, File dataFileDirectory) throws TasteException {
- // Change this to whatever you like!
- ItemSimilarity similarity;
- try {
- similarity = new HybridSimilarity(dataModel, dataFileDirectory);
- } catch (IOException ioe) {
- throw new TasteException(ioe);
- }
- recommender = new GenericBooleanPrefItemBasedRecommender(dataModel, similarity);
- }
-
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
- return recommender.recommend(userID, howMany);
- }
-
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
- return recommend(userID, howMany, null, includeKnownItems);
- }
-
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException {
- return recommender.recommend(userID, howMany, rescorer, false);
- }
-
- @Override
- public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
- throws TasteException {
- return recommender.recommend(userID, howMany, rescorer, includeKnownItems);
- }
-
- @Override
- public float estimatePreference(long userID, long itemID) throws TasteException {
- return recommender.estimatePreference(userID, itemID);
- }
-
- @Override
- public void setPreference(long userID, long itemID, float value) throws TasteException {
- recommender.setPreference(userID, itemID, value);
- }
-
- @Override
- public void removePreference(long userID, long itemID) throws TasteException {
- recommender.removePreference(userID, itemID);
- }
-
- @Override
- public DataModel getDataModel() {
- return recommender.getDataModel();
- }
-
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- recommender.refresh(alreadyRefreshed);
- }
-
- @Override
- public String toString() {
- return "Track1Recommender[recommender:" + recommender + ']';
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
deleted file mode 100644
index 09ade5d..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2RecommenderBuilder.java
+++ /dev/null
@@ -1,33 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-import org.apache.mahout.cf.taste.common.TasteException;
-import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.Recommender;
-
-final class Track2RecommenderBuilder implements RecommenderBuilder {
-
- @Override
- public Recommender buildRecommender(DataModel dataModel) throws TasteException {
- return new Track2Recommender(dataModel, ((KDDCupDataModel) dataModel).getDataFileDirectory());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
deleted file mode 100644
index 3cbb61c..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/Track2Runner.java
+++ /dev/null
@@ -1,100 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-import org.apache.mahout.cf.taste.example.kddcup.DataFileIterable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.model.PreferenceArray;
-import org.apache.mahout.common.Pair;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.BufferedOutputStream;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.concurrent.Future;
-
-/**
- * <p>Runs "track 2" of the KDD Cup competition using whatever recommender is inside {@link Track2Recommender}
- * and attempts to output the result in the correct contest format.</p>
- *
- * <p>Run as: {@code Track2Runner [track 2 data file directory] [output file]}</p>
- */
-public final class Track2Runner {
-
- private static final Logger log = LoggerFactory.getLogger(Track2Runner.class);
-
- private Track2Runner() {
- }
-
- public static void main(String[] args) throws Exception {
-
- File dataFileDirectory = new File(args[0]);
- if (!dataFileDirectory.exists() || !dataFileDirectory.isDirectory()) {
- throw new IllegalArgumentException("Bad data file directory: " + dataFileDirectory);
- }
-
- long start = System.currentTimeMillis();
-
- KDDCupDataModel model = new KDDCupDataModel(KDDCupDataModel.getTrainingFile(dataFileDirectory));
- Track2Recommender recommender = new Track2Recommender(model, dataFileDirectory);
-
- long end = System.currentTimeMillis();
- log.info("Loaded model in {}s", (end - start) / 1000);
- start = end;
-
- Collection<Track2Callable> callables = new ArrayList<>();
- for (Pair<PreferenceArray,long[]> tests : new DataFileIterable(KDDCupDataModel.getTestFile(dataFileDirectory))) {
- PreferenceArray userTest = tests.getFirst();
- callables.add(new Track2Callable(recommender, userTest));
- }
-
- int cores = Runtime.getRuntime().availableProcessors();
- log.info("Running on {} cores", cores);
- ExecutorService executor = Executors.newFixedThreadPool(cores);
- List<Future<UserResult>> futures = executor.invokeAll(callables);
- executor.shutdown();
-
- end = System.currentTimeMillis();
- log.info("Ran recommendations in {}s", (end - start) / 1000);
- start = end;
-
- try (OutputStream out = new BufferedOutputStream(new FileOutputStream(new File(args[1])))){
- long lastUserID = Long.MIN_VALUE;
- for (Future<UserResult> future : futures) {
- UserResult result = future.get();
- long userID = result.getUserID();
- if (userID <= lastUserID) {
- throw new IllegalStateException();
- }
- lastUserID = userID;
- out.write(result.getResultBytes());
- }
- }
-
- end = System.currentTimeMillis();
- log.info("Wrote output in {}s", (end - start) / 1000);
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
deleted file mode 100644
index abd15f8..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackData.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-import java.util.regex.Pattern;
-
-import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-
-final class TrackData {
-
- private static final Pattern PIPE = Pattern.compile("\\|");
- private static final String NO_VALUE = "None";
- static final long NO_VALUE_ID = Long.MIN_VALUE;
- private static final FastIDSet NO_GENRES = new FastIDSet();
-
- private final long trackID;
- private final long albumID;
- private final long artistID;
- private final FastIDSet genreIDs;
-
- TrackData(CharSequence line) {
- String[] tokens = PIPE.split(line);
- trackID = Long.parseLong(tokens[0]);
- albumID = parse(tokens[1]);
- artistID = parse(tokens[2]);
- if (tokens.length > 3) {
- genreIDs = new FastIDSet(tokens.length - 3);
- for (int i = 3; i < tokens.length; i++) {
- genreIDs.add(Long.parseLong(tokens[i]));
- }
- } else {
- genreIDs = NO_GENRES;
- }
- }
-
- private static long parse(String value) {
- return NO_VALUE.equals(value) ? NO_VALUE_ID : Long.parseLong(value);
- }
-
- public long getTrackID() {
- return trackID;
- }
-
- public long getAlbumID() {
- return albumID;
- }
-
- public long getArtistID() {
- return artistID;
- }
-
- public FastIDSet getGenreIDs() {
- return genreIDs;
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
deleted file mode 100644
index 3012a84..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/TrackItemSimilarity.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Collection;
-
-import org.apache.mahout.cf.taste.common.Refreshable;
-import org.apache.mahout.cf.taste.example.kddcup.KDDCupDataModel;
-import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
-import org.apache.mahout.cf.taste.impl.common.FastIDSet;
-import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
-import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
-import org.apache.mahout.common.iterator.FileLineIterable;
-
-final class TrackItemSimilarity implements ItemSimilarity {
-
- private final FastByIDMap<TrackData> trackData;
-
- TrackItemSimilarity(File dataFileDirectory) throws IOException {
- trackData = new FastByIDMap<>();
- for (String line : new FileLineIterable(KDDCupDataModel.getTrackFile(dataFileDirectory))) {
- TrackData trackDatum = new TrackData(line);
- trackData.put(trackDatum.getTrackID(), trackDatum);
- }
- }
-
- @Override
- public double itemSimilarity(long itemID1, long itemID2) {
- if (itemID1 == itemID2) {
- return 1.0;
- }
- TrackData data1 = trackData.get(itemID1);
- TrackData data2 = trackData.get(itemID2);
- if (data1 == null || data2 == null) {
- return 0.0;
- }
-
- // Arbitrarily decide that same album means "very similar"
- if (data1.getAlbumID() != TrackData.NO_VALUE_ID && data1.getAlbumID() == data2.getAlbumID()) {
- return 0.9;
- }
- // ... and same artist means "fairly similar"
- if (data1.getArtistID() != TrackData.NO_VALUE_ID && data1.getArtistID() == data2.getArtistID()) {
- return 0.7;
- }
-
- // Tanimoto coefficient similarity based on genre, but maximum value of 0.25
- FastIDSet genres1 = data1.getGenreIDs();
- FastIDSet genres2 = data2.getGenreIDs();
- if (genres1 == null || genres2 == null) {
- return 0.0;
- }
- int intersectionSize = genres1.intersectionSize(genres2);
- if (intersectionSize == 0) {
- return 0.0;
- }
- int unionSize = genres1.size() + genres2.size() - intersectionSize;
- return intersectionSize / (4.0 * unionSize);
- }
-
- @Override
- public double[] itemSimilarities(long itemID1, long[] itemID2s) {
- int length = itemID2s.length;
- double[] result = new double[length];
- for (int i = 0; i < length; i++) {
- result[i] = itemSimilarity(itemID1, itemID2s[i]);
- }
- return result;
- }
-
- @Override
- public long[] allSimilarItemIDs(long itemID) {
- FastIDSet allSimilarItemIDs = new FastIDSet();
- LongPrimitiveIterator allItemIDs = trackData.keySetIterator();
- while (allItemIDs.hasNext()) {
- long possiblySimilarItemID = allItemIDs.nextLong();
- if (!Double.isNaN(itemSimilarity(itemID, possiblySimilarItemID))) {
- allSimilarItemIDs.add(possiblySimilarItemID);
- }
- }
- return allSimilarItemIDs.toArray();
- }
-
- @Override
- public void refresh(Collection<Refreshable> alreadyRefreshed) {
- // do nothing
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
deleted file mode 100644
index e554d10..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/example/kddcup/track2/UserResult.java
+++ /dev/null
@@ -1,54 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.example.kddcup.track2;
-
-final class UserResult {
-
- private final long userID;
- private final byte[] resultBytes;
-
- UserResult(long userID, boolean[] result) {
-
- this.userID = userID;
-
- int trueCount = 0;
- for (boolean b : result) {
- if (b) {
- trueCount++;
- }
- }
- if (trueCount != 3) {
- throw new IllegalStateException();
- }
-
- resultBytes = new byte[result.length];
- for (int i = 0; i < result.length; i++) {
- resultBytes[i] = (byte) (result[i] ? '1' : '0');
- }
- }
-
- public long getUserID() {
- return userID;
- }
-
- public byte[] getResultBytes() {
- return resultBytes;
- }
-
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
deleted file mode 100644
index 22f122e..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/hadoop/example/als/netflix/NetflixDatasetConverter.java
+++ /dev/null
@@ -1,140 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.hadoop.example.als.netflix;
-
-import com.google.common.base.Preconditions;
-import org.apache.commons.io.Charsets;
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-import org.apache.mahout.cf.taste.impl.model.GenericPreference;
-import org.apache.mahout.cf.taste.model.Preference;
-import org.apache.mahout.common.iterator.FileLineIterable;
-import org.apache.mahout.common.iterator.FileLineIterator;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.regex.Pattern;
-
-/** converts the raw files provided by netflix to an appropriate input format */
-public final class NetflixDatasetConverter {
-
- private static final Logger log = LoggerFactory.getLogger(NetflixDatasetConverter.class);
-
- private static final Pattern SEPARATOR = Pattern.compile(",");
- private static final String MOVIE_DENOTER = ":";
- private static final String TAB = "\t";
- private static final String NEWLINE = "\n";
-
- private NetflixDatasetConverter() {
- }
-
- public static void main(String[] args) throws IOException {
-
- if (args.length != 4) {
- System.err.println("Usage: NetflixDatasetConverter /path/to/training_set/ /path/to/qualifying.txt "
- + "/path/to/judging.txt /path/to/destination");
- return;
- }
-
- String trainingDataDir = args[0];
- String qualifyingTxt = args[1];
- String judgingTxt = args[2];
- Path outputPath = new Path(args[3]);
-
- Configuration conf = new Configuration();
- FileSystem fs = FileSystem.get(outputPath.toUri(), conf);
-
- Preconditions.checkArgument(trainingDataDir != null, "Training Data location needs to be specified");
- log.info("Creating training set at {}/trainingSet/ratings.tsv ...", outputPath);
- try (BufferedWriter writer =
- new BufferedWriter(
- new OutputStreamWriter(
- fs.create(new Path(outputPath, "trainingSet/ratings.tsv")), Charsets.UTF_8))){
-
- int ratingsProcessed = 0;
- for (File movieRatings : new File(trainingDataDir).listFiles()) {
- try (FileLineIterator lines = new FileLineIterator(movieRatings)) {
- boolean firstLineRead = false;
- String movieID = null;
- while (lines.hasNext()) {
- String line = lines.next();
- if (firstLineRead) {
- String[] tokens = SEPARATOR.split(line);
- String userID = tokens[0];
- String rating = tokens[1];
- writer.write(userID + TAB + movieID + TAB + rating + NEWLINE);
- ratingsProcessed++;
- if (ratingsProcessed % 1000000 == 0) {
- log.info("{} ratings processed...", ratingsProcessed);
- }
- } else {
- movieID = line.replaceAll(MOVIE_DENOTER, "");
- firstLineRead = true;
- }
- }
- }
-
- }
- log.info("{} ratings processed. done.", ratingsProcessed);
- }
-
- log.info("Reading probes...");
- List<Preference> probes = new ArrayList<>(2817131);
- long currentMovieID = -1;
- for (String line : new FileLineIterable(new File(qualifyingTxt))) {
- if (line.contains(MOVIE_DENOTER)) {
- currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
- } else {
- long userID = Long.parseLong(SEPARATOR.split(line)[0]);
- probes.add(new GenericPreference(userID, currentMovieID, 0));
- }
- }
- log.info("{} probes read...", probes.size());
-
- log.info("Reading ratings, creating probe set at {}/probeSet/ratings.tsv ...", outputPath);
- try (BufferedWriter writer =
- new BufferedWriter(new OutputStreamWriter(
- fs.create(new Path(outputPath, "probeSet/ratings.tsv")), Charsets.UTF_8))){
- int ratingsProcessed = 0;
- for (String line : new FileLineIterable(new File(judgingTxt))) {
- if (line.contains(MOVIE_DENOTER)) {
- currentMovieID = Long.parseLong(line.replaceAll(MOVIE_DENOTER, ""));
- } else {
- float rating = Float.parseFloat(SEPARATOR.split(line)[0]);
- Preference pref = probes.get(ratingsProcessed);
- Preconditions.checkState(pref.getItemID() == currentMovieID);
- ratingsProcessed++;
- writer.write(pref.getUserID() + TAB + pref.getItemID() + TAB + rating + NEWLINE);
- if (ratingsProcessed % 1000000 == 0) {
- log.info("{} ratings processed...", ratingsProcessed);
- }
- }
- }
- log.info("{} ratings processed. done.", ratingsProcessed);
- }
- }
-
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
deleted file mode 100644
index 8021d00..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/BatchItemSimilaritiesGroupLens.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.similarity.precompute.example;
-
-import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
-import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
-import org.apache.mahout.cf.taste.impl.similarity.precompute.FileSimilarItemsWriter;
-import org.apache.mahout.cf.taste.impl.similarity.precompute.MultithreadedBatchItemSimilarities;
-import org.apache.mahout.cf.taste.model.DataModel;
-import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
-import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
-
-import java.io.File;
-
-/**
- * Example that precomputes all item similarities of the Movielens1M dataset
- *
- * Usage: download movielens1M from http://www.grouplens.org/node/73 , unzip it and invoke this code with the path
- * to the ratings.dat file as argument
- *
- */
-public final class BatchItemSimilaritiesGroupLens {
-
- private BatchItemSimilaritiesGroupLens() {}
-
- public static void main(String[] args) throws Exception {
-
- if (args.length != 1) {
- System.err.println("Need path to ratings.dat of the movielens1M dataset as argument!");
- System.exit(-1);
- }
-
- File resultFile = new File(System.getProperty("java.io.tmpdir"), "similarities.csv");
- if (resultFile.exists()) {
- resultFile.delete();
- }
-
- DataModel dataModel = new GroupLensDataModel(new File(args[0]));
- ItemBasedRecommender recommender = new GenericItemBasedRecommender(dataModel,
- new LogLikelihoodSimilarity(dataModel));
- BatchItemSimilarities batch = new MultithreadedBatchItemSimilarities(recommender, 5);
-
- int numSimilarities = batch.computeItemSimilarities(Runtime.getRuntime().availableProcessors(), 1,
- new FileSimilarItemsWriter(resultFile));
-
- System.out.println("Computed " + numSimilarities + " similarities for " + dataModel.getNumItems() + " items "
- + "and saved them to " + resultFile.getAbsolutePath());
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
deleted file mode 100644
index 7ee9b17..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/example/GroupLensDataModel.java
+++ /dev/null
@@ -1,96 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.cf.taste.similarity.precompute.example;
-
-import com.google.common.io.Files;
-import com.google.common.io.InputSupplier;
-import com.google.common.io.Resources;
-import java.io.File;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.OutputStreamWriter;
-import java.io.Writer;
-import java.net.URL;
-import java.util.regex.Pattern;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
-import org.apache.mahout.common.iterator.FileLineIterable;
-
-public final class GroupLensDataModel extends FileDataModel {
-
- private static final String COLON_DELIMTER = "::";
- private static final Pattern COLON_DELIMITER_PATTERN = Pattern.compile(COLON_DELIMTER);
-
- public GroupLensDataModel() throws IOException {
- this(readResourceToTempFile("/org/apache/mahout/cf/taste/example/grouplens/ratings.dat"));
- }
-
- /**
- * @param ratingsFile GroupLens ratings.dat file in its native format
- * @throws IOException if an error occurs while reading or writing files
- */
- public GroupLensDataModel(File ratingsFile) throws IOException {
- super(convertGLFile(ratingsFile));
- }
-
- private static File convertGLFile(File originalFile) throws IOException {
- // Now translate the file; remove commas, then convert "::" delimiter to comma
- File resultFile = new File(new File(System.getProperty("java.io.tmpdir")), "ratings.txt");
- if (resultFile.exists()) {
- resultFile.delete();
- }
- try (Writer writer = new OutputStreamWriter(new FileOutputStream(resultFile), Charsets.UTF_8)){
- for (String line : new FileLineIterable(originalFile, false)) {
- int lastDelimiterStart = line.lastIndexOf(COLON_DELIMTER);
- if (lastDelimiterStart < 0) {
- throw new IOException("Unexpected input format on line: " + line);
- }
- String subLine = line.substring(0, lastDelimiterStart);
- String convertedLine = COLON_DELIMITER_PATTERN.matcher(subLine).replaceAll(",");
- writer.write(convertedLine);
- writer.write('\n');
- }
- } catch (IOException ioe) {
- resultFile.delete();
- throw ioe;
- }
- return resultFile;
- }
-
- public static File readResourceToTempFile(String resourceName) throws IOException {
- InputSupplier<? extends InputStream> inSupplier;
- try {
- URL resourceURL = Resources.getResource(GroupLensDataModel.class, resourceName);
- inSupplier = Resources.newInputStreamSupplier(resourceURL);
- } catch (IllegalArgumentException iae) {
- File resourceFile = new File("src/main/java" + resourceName);
- inSupplier = Files.newInputStreamSupplier(resourceFile);
- }
- File tempFile = File.createTempFile("taste", null);
- tempFile.deleteOnExit();
- Files.copy(inSupplier, tempFile);
- return tempFile;
- }
-
- @Override
- public String toString() {
- return "GroupLensDataModel";
- }
-
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
deleted file mode 100644
index 5cec51c..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/NewsgroupHelper.java
+++ /dev/null
@@ -1,128 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier;
-
-import com.google.common.collect.ConcurrentHashMultiset;
-import com.google.common.collect.Multiset;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.apache.commons.io.Charsets;
-import org.apache.lucene.analysis.Analyzer;
-import org.apache.lucene.analysis.TokenStream;
-import org.apache.lucene.analysis.standard.StandardAnalyzer;
-import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
-import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.RandomAccessSparseVector;
-import org.apache.mahout.math.Vector;
-import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
-import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
-import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
-
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.Reader;
-import java.io.StringReader;
-import java.text.SimpleDateFormat;
-import java.util.Collection;
-import java.util.Date;
-import java.util.Locale;
-import java.util.Random;
-
-public final class NewsgroupHelper {
-
- private static final SimpleDateFormat[] DATE_FORMATS = {
- new SimpleDateFormat("", Locale.ENGLISH),
- new SimpleDateFormat("MMM-yyyy", Locale.ENGLISH),
- new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss", Locale.ENGLISH)
- };
-
- public static final int FEATURES = 10000;
- // 1997-01-15 00:01:00 GMT
- private static final long DATE_REFERENCE = 853286460;
- private static final long MONTH = 30 * 24 * 3600;
- private static final long WEEK = 7 * 24 * 3600;
-
- private final Random rand = RandomUtils.getRandom();
- private final Analyzer analyzer = new StandardAnalyzer();
- private final FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
- private final FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
-
- public FeatureVectorEncoder getEncoder() {
- return encoder;
- }
-
- public FeatureVectorEncoder getBias() {
- return bias;
- }
-
- public Random getRandom() {
- return rand;
- }
-
- public Vector encodeFeatureVector(File file, int actual, int leakType, Multiset<String> overallCounts)
- throws IOException {
- long date = (long) (1000 * (DATE_REFERENCE + actual * MONTH + 1 * WEEK * rand.nextDouble()));
- Multiset<String> words = ConcurrentHashMultiset.create();
-
- try (BufferedReader reader = Files.newReader(file, Charsets.UTF_8)) {
- String line = reader.readLine();
- Reader dateString = new StringReader(DATE_FORMATS[leakType % 3].format(new Date(date)));
- countWords(analyzer, words, dateString, overallCounts);
- while (line != null && !line.isEmpty()) {
- boolean countHeader = (
- line.startsWith("From:") || line.startsWith("Subject:")
- || line.startsWith("Keywords:") || line.startsWith("Summary:")) && leakType < 6;
- do {
- Reader in = new StringReader(line);
- if (countHeader) {
- countWords(analyzer, words, in, overallCounts);
- }
- line = reader.readLine();
- } while (line != null && line.startsWith(" "));
- }
- if (leakType < 3) {
- countWords(analyzer, words, reader, overallCounts);
- }
- }
-
- Vector v = new RandomAccessSparseVector(FEATURES);
- bias.addToVector("", 1, v);
- for (String word : words.elementSet()) {
- encoder.addToVector(word, Math.log1p(words.count(word)), v);
- }
-
- return v;
- }
-
- public static void countWords(Analyzer analyzer,
- Collection<String> words,
- Reader in,
- Multiset<String> overallCounts) throws IOException {
- TokenStream ts = analyzer.tokenStream("text", in);
- ts.addAttribute(CharTermAttribute.class);
- ts.reset();
- while (ts.incrementToken()) {
- String s = ts.getAttribute(CharTermAttribute.class).toString();
- words.add(s);
- }
- overallCounts.addAll(words);
- ts.end();
- Closeables.close(ts, true);
- }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
deleted file mode 100644
index 16e9d80..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailMapper.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.email;
-
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.WritableComparable;
-import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.math.VectorWritable;
-
-import java.io.IOException;
-import java.util.Locale;
-import java.util.regex.Pattern;
-
-/**
- * Turns the keys written by {@link org.apache.mahout.utils.email.MailProcessor} into labels that the
- * classifiers can consume.
- */
-public class PrepEmailMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
-
-  private static final Pattern DASH_DOT = Pattern.compile("-|\\.");
-  private static final Pattern SLASH = Pattern.compile("\\/");
-
-  /** If true, the label is built from the project name and the list name, otherwise from the project name only. */
-  private boolean useListName = false;
-
-  /** Reads the {@link PrepEmailVectorsDriver#USE_LIST_NAME} flag from the job configuration. */
-  @Override
-  protected void setup(Context context) throws IOException, InterruptedException {
-    String flag = context.getConfiguration().get(PrepEmailVectorsDriver.USE_LIST_NAME);
-    useListName = Boolean.parseBoolean(flag);
-  }
-
-  /**
-   * Emits the vector under a label derived from the second and, optionally, third slash-separated
-   * segment of the key. Keys with fewer than three segments are dropped.
-   */
-  @Override
-  protected void map(WritableComparable<?> key, VectorWritable value, Context context)
-    throws IOException, InterruptedException {
-    // Example key: /cocoon.apache.org/dev/200307.gz/001401c3414f$8394e160$1e01a8c0@WRPO
-    // (segment 0 is the empty string in front of the leading slash)
-    String[] segments = SLASH.split(key.toString());
-    if (segments.length < 3) {
-      return;
-    }
-    String label = escape(segments[1]);
-    if (useListName) {
-      label = label + '_' + escape(segments[2]);
-    }
-    context.write(new Text(label), value);
-  }
-
-  /** Replaces every dash and dot with an underscore and lower-cases the result. */
-  private static String escape(CharSequence value) {
-    String underscored = DASH_DOT.matcher(value).replaceAll("_");
-    return underscored.toLowerCase(Locale.ENGLISH);
-  }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
deleted file mode 100644
index da6e613..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailReducer.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.email;
-
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Reducer;
-import org.apache.mahout.math.VectorWritable;
-
-import java.io.IOException;
-import java.util.Iterator;
-
-/**
- * Passes the vectors of each label through unchanged, but writes at most a configurable number of
- * them per label.
- */
-public class PrepEmailReducer extends Reducer<Text, VectorWritable, Text, VectorWritable> {
-
-  /** Maximum number of vectors written per label; also the fallback when the job sets no value. */
-  private long maxItemsPerLabel = 10000;
-
-  /**
-   * Reads the per-label limit from {@link PrepEmailVectorsDriver#ITEMS_PER_CLASS}. If the property is
-   * not set, the field default is kept instead of failing with a {@link NumberFormatException}.
-   */
-  @Override
-  protected void setup(Context context) throws IOException, InterruptedException {
-    maxItemsPerLabel =
-        context.getConfiguration().getLong(PrepEmailVectorsDriver.ITEMS_PER_CLASS, maxItemsPerLabel);
-  }
-
-  /** Writes the first {@code maxItemsPerLabel} vectors of {@code values} under {@code key}. */
-  @Override
-  protected void reduce(Text key, Iterable<VectorWritable> values, Context context)
-    throws IOException, InterruptedException {
-    //TODO: support randomization? Likely not needed due to the SplitInput utility which does random selection
-    long written = 0;
-    Iterator<VectorWritable> iterator = values.iterator();
-    // check the limit first so that no value beyond it is pulled from the iterator
-    while (written < maxItemsPerLabel && iterator.hasNext()) {
-      context.write(key, iterator.next());
-      written++;
-    }
-  }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
deleted file mode 100644
index 8fba739..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/email/PrepEmailVectorsDriver.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.email;
-
-import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.mapreduce.Job;
-import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
-import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
-import org.apache.hadoop.util.ToolRunner;
-import org.apache.mahout.common.AbstractJob;
-import org.apache.mahout.common.HadoopUtil;
-import org.apache.mahout.common.commandline.DefaultOptionCreator;
-import org.apache.mahout.math.VectorWritable;
-
-import java.util.List;
-import java.util.Map;
-
-/**
- * Rewrites the labels generated by {@link org.apache.mahout.text.SequenceFilesFromMailArchives} and
- * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} into ones consumable by the classifiers.
- * This is a separate job b/c if it is done in the creation of sparse vectors, the Reducer collapses all the vectors.
- */
-public class PrepEmailVectorsDriver extends AbstractJob {
-
-  public static final String ITEMS_PER_CLASS = "itemsPerClass";
-  public static final String USE_LIST_NAME = "USE_LIST_NAME";
-
-  public static void main(String[] args) throws Exception {
-    Configuration conf = new Configuration();
-    ToolRunner.run(conf, new PrepEmailVectorsDriver(), args);
-  }
-
-  /**
-   * Parses the command line, then runs the {@link PrepEmailMapper} / {@link PrepEmailReducer} job.
-   *
-   * @return 0 if the job succeeded, -1 if the arguments could not be parsed or the job failed
-   */
-  @Override
-  public int run(String[] args) throws Exception {
-    addInputOption();
-    addOutputOption();
-    addOption(DefaultOptionCreator.overwriteOption().create());
-    String maxItemsHelp = "The maximum number of items per label. Can be useful for making the "
-        + "training sets the same size";
-    addOption("maxItemsPerLabel", "mipl", maxItemsHelp, String.valueOf(100000));
-    String listNameHelp = "Use the name of the list as part of the label. If not set, then "
-        + "just use the project name";
-    addOption(buildOption("useListName", "ul", listNameHelp, false, false, "false"));
-
-    Map<String,List<String>> arguments = parseArguments(args);
-    if (arguments == null) {
-      return -1;
-    }
-
-    Path inputPath = getInputPath();
-    Path outputPath = getOutputPath();
-    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
-      HadoopUtil.delete(getConf(), outputPath);
-    }
-
-    Job job = prepareJob(inputPath, outputPath, SequenceFileInputFormat.class, PrepEmailMapper.class,
-        Text.class, VectorWritable.class, PrepEmailReducer.class, Text.class, VectorWritable.class,
-        SequenceFileOutputFormat.class);
-    // hand the two options over to the mapper and the reducer through the job configuration
-    Configuration jobConf = job.getConfiguration();
-    jobConf.set(ITEMS_PER_CLASS, getOption("maxItemsPerLabel"));
-    jobConf.set(USE_LIST_NAME, String.valueOf(hasOption("useListName")));
-
-    if (job.waitForCompletion(true)) {
-      return 0;
-    }
-    return -1;
-  }
-}
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java b/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
deleted file mode 100644
index 9c0ef56..0000000
--- a/community/mahout-mr/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
+++ /dev/null
@@ -1,277 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.mahout.classifier.sequencelearning.hmm;
-
-import com.google.common.io.Resources;
-import org.apache.commons.io.Charsets;
-import org.apache.mahout.math.Matrix;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.IOException;
-import java.net.URL;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.regex.Pattern;
-
-/**
- * This class implements a sample program that uses a pre-tagged training data
- * set to train an HMM model as a POS tagger. The training data is automatically
- * downloaded from the following URL:
- * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt It then
- * trains an HMM Model using supervised learning and tests the model on the
- * following test data set:
- * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt Further
- * details regarding the data files can be found at
- * http://flexcrfs.sourceforge.net/#Case_Study
- *
- * <p>All state lives in static fields that are overwritten by every training or test run.
- */
-public final class PosTagger {
-
-  private static final Logger log = LoggerFactory.getLogger(PosTagger.class);
-
-  private static final Pattern SPACE = Pattern.compile(" ");
-  private static final Pattern SPACES = Pattern.compile("[ ]+");
-  /** Punctuation characters that are split off into tokens of their own. */
-  private static final Pattern PUNCTUATION = Pattern.compile("[,.!?:;\"]");
-  private static final Pattern DOUBLE_APOSTROPHE = Pattern.compile("''");
-
-  /**
-   * No public constructors for utility classes.
-   */
-  private PosTagger() {
-    // nothing to do here really.
-  }
-
-  /**
-   * Model trained in the example.
-   */
-  private static HmmModel taggingModel;
-
-  /**
-   * Map for storing the IDs for the POS tags (hidden states)
-   */
-  private static Map<String, Integer> tagIDs;
-
-  /**
-   * Counter for the next assigned POS tag ID. Tag IDs start at 0, so no ID is reserved:
-   * a tag that is unknown at test time is mapped to 0 and thereby shares the ID of the
-   * first tag seen during training.
-   */
-  private static int nextTagId;
-
-  /**
-   * Map for storing the IDs for observed words (observed states)
-   */
-  private static Map<String, Integer> wordIDs;
-
-  /**
-   * Counter for the next assigned word ID The value of 0 is reserved for
-   * "unknown word"
-   */
-  private static int nextWordId = 1; // 0 is reserved for "unknown word"
-
-  /**
-   * Used for storing a list of POS tags of read sentences.
-   */
-  private static List<int[]> hiddenSequences;
-
-  /**
-   * Used for storing a list of word tags of read sentences.
-   */
-  private static List<int[]> observedSequences;
-
-  /**
-   * number of read lines
-   */
-  private static int readLines;
-
-  /**
-   * Given an URL, this function fetches the data file, parses it, assigns POS
-   * Tag/word IDs and fills the hiddenSequences/observedSequences lists with
-   * data from those files. The data is expected to be in the following format
-   * (one word per line): word pos-tag np-tag. Sentences are separated by an
-   * empty line. Lines with fewer than two fields are logged and skipped, and
-   * sentences without any word are not registered.
-   *
-   * @param url       Where the data file is stored
-   * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
-   *                  training data, not needed for test data)
-   * @throws IOException in case data file cannot be read.
-   */
-  private static void readFromURL(String url, boolean assignIDs) throws IOException {
-    // initialize the data structure; array-backed lists because they are read by index later on
-    hiddenSequences = new ArrayList<>();
-    observedSequences = new ArrayList<>();
-    readLines = 0;
-
-    // word and tag IDs of the sentence that is currently being read
-    List<Integer> observedSequence = new ArrayList<>();
-    List<Integer> hiddenSequence = new ArrayList<>();
-
-    for (String line : Resources.readLines(new URL(url), Charsets.UTF_8)) {
-      if (line.isEmpty()) {
-        // an empty line closes the current sentence
-        registerSentence(observedSequence, hiddenSequence);
-        continue;
-      }
-      readLines++;
-      // we expect the format [word] [POS tag] [NP tag]
-      String[] tags = SPACE.split(line);
-      if (tags.length < 2) {
-        log.warn("Skipping malformed line (expected at least a word and a POS tag): '{}'", line);
-        continue;
-      }
-      String word = tags[0];
-      String tag = tags[1];
-      // when analyzing the training set, assign IDs
-      if (assignIDs) {
-        if (!wordIDs.containsKey(word)) {
-          wordIDs.put(word, nextWordId++);
-        }
-        if (!tagIDs.containsKey(tag)) {
-          tagIDs.put(tag, nextTagId++);
-        }
-      }
-      // determine the IDs; anything unknown is mapped to 0
-      Integer wordID = wordIDs.get(word);
-      Integer tagID = tagIDs.get(tag);
-      observedSequence.add(wordID == null ? 0 : wordID);
-      hiddenSequence.add(tagID == null ? 0 : tagID);
-    }
-
-    // if there is still something in the pipe, register it
-    registerSentence(observedSequence, hiddenSequence);
-  }
-
-  /**
-   * Appends the current sentence to observedSequences/hiddenSequences and clears both buffers.
-   * Does nothing for an empty sentence.
-   */
-  private static void registerSentence(List<Integer> observedSequence, List<Integer> hiddenSequence) {
-    if (observedSequence.isEmpty()) {
-      return;
-    }
-    hiddenSequences.add(toIntArray(hiddenSequence));
-    observedSequences.add(toIntArray(observedSequence));
-    observedSequence.clear();
-    hiddenSequence.clear();
-  }
-
-  /** Copies a list of boxed integers into a primitive array. */
-  private static int[] toIntArray(List<Integer> values) {
-    int[] result = new int[values.size()];
-    int index = 0;
-    for (int value : values) {
-      result[index++] = value;
-    }
-    return result;
-  }
-
-  /**
-   * Reads the training data, trains the tagging model on it and raises the
-   * probability that an unknown word is tagged as NNP.
-   *
-   * @param trainingURL Where the training data file is stored
-   * @throws IOException in case the training data cannot be read.
-   */
-  private static void trainModel(String trainingURL) throws IOException {
-    tagIDs = new HashMap<>(44); // we expect 44 distinct tags
-    wordIDs = new HashMap<>(19122); // we expect 19122 distinct words
-    // restart the ID counters so that they always match the freshly created maps
-    nextTagId = 0;
-    nextWordId = 1;
-    log.info("Reading and parsing training data file from URL: {}", trainingURL);
-    long start = System.currentTimeMillis();
-    readFromURL(trainingURL, true);
-    long end = System.currentTimeMillis();
-    double duration = (end - start) / 1000.0;
-    log.info("Parsing done in {} seconds!", duration);
-    // report the map sizes: tag IDs start at 0, so nextTagId - 1 would be one too low
-    log.info("Read {} lines containing {} sentences with a total of {} distinct words and {} distinct POS tags.",
-        readLines, hiddenSequences.size(), wordIDs.size(), tagIDs.size());
-    start = System.currentTimeMillis();
-    taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
-        hiddenSequences, observedSequences, 0.05);
-    // we have to adjust the model a bit,
-    // since we assume a higher probability that a given unknown word is NNP
-    // than anything else
-    Matrix emissions = taggingModel.getEmissionMatrix();
-    for (int i = 0; i < taggingModel.getNrOfHiddenStates(); ++i) {
-      emissions.setQuick(i, 0, 0.1 / taggingModel.getNrOfHiddenStates());
-    }
-    Integer nnpTag = tagIDs.get("NNP");
-    if (nnpTag == null) {
-      log.warn("The training data contains no NNP tag; unknown words are not biased towards any tag.");
-    } else {
-      emissions.setQuick(nnpTag, 0, 1 / (double) taggingModel.getNrOfHiddenStates());
-    }
-    // re-normalize the emission probabilities
-    HmmUtils.normalizeModel(taggingModel);
-    // now register the names
-    taggingModel.registerHiddenStateNames(tagIDs);
-    taggingModel.registerOutputStateNames(wordIDs);
-    end = System.currentTimeMillis();
-    duration = (end - start) / 1000.0;
-    log.info("Trained HMM models in {} seconds!", duration);
-  }
-
-  /**
-   * Reads the test data, tags every sentence with the trained model and logs
-   * the share of words whose tag differs from the expected one.
-   *
-   * @param testingURL Where the test data file is stored
-   * @throws IOException in case the test data cannot be read.
-   */
-  private static void testModel(String testingURL) throws IOException {
-    log.info("Reading and parsing test data file from URL: {}", testingURL);
-    long start = System.currentTimeMillis();
-    readFromURL(testingURL, false);
-    long end = System.currentTimeMillis();
-    double duration = (end - start) / 1000.0;
-    log.info("Parsing done in {} seconds!", duration);
-    log.info("Read {} lines containing {} sentences.", readLines, hiddenSequences.size());
-
-    start = System.currentTimeMillis();
-    int errorCount = 0;
-    int totalCount = 0;
-    for (int i = 0; i < observedSequences.size(); ++i) {
-      // fetch the viterbi path as the POS tag for this observed sequence
-      int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences.get(i), false);
-      // compare with the expected
-      int[] posExpected = hiddenSequences.get(i);
-      for (int j = 0; j < posExpected.length; ++j) {
-        totalCount++;
-        if (posEstimate[j] != posExpected[j]) {
-          errorCount++;
-        }
-      }
-    }
-    end = System.currentTimeMillis();
-    duration = (end - start) / 1000.0;
-    log.info("POS tagged test file in {} seconds!", duration);
-    if (totalCount == 0) {
-      // avoid reporting NaN when there was nothing to compare
-      log.warn("The test file contained no tagged words; no error rate can be reported.");
-      return;
-    }
-    double errorRate = (double) errorCount / totalCount;
-    log.info("Tagged the test file with an error rate of: {}", errorRate);
-  }
-
-  /**
-   * Splits a sentence into the tokens that are handed to the model: punctuation
-   * characters and {@code ''} become tokens of their own, everything else is
-   * split at blanks.
-   */
-  private static String[] tokenize(String sentence) {
-    // first, we need to isolate all punctuation characters, so that they
-    // can be recognized
-    String separated = PUNCTUATION.matcher(sentence).replaceAll(" $0 ");
-    separated = DOUBLE_APOSTROPHE.matcher(separated).replaceAll(" '' ");
-    // trim first, otherwise a sentence that starts with punctuation yields an empty first token
-    return SPACES.split(separated.trim());
-  }
-
-  /**
-   * Tags a sentence with the trained model.
-   *
-   * @param sentence the sentence to tag
-   * @return the names of the estimated POS tags, one for each token of the sentence
-   */
-  private static List<String> tagSentence(String sentence) {
-    String[] tokens = tokenize(sentence);
-    // now generate the observed sequence
-    int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays.asList(tokens), true, 0);
-    // POS tag this observedSequence
-    int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence, false);
-    // and now decode the tag names
-    return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false, null);
-  }
-
-  public static void main(String[] args) throws IOException {
-    // generate the model from URL
-    trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
-    testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
-    // tag an exemplary sentence
-    String test = "McDonalds is a huge company with many employees .";
-    // use the same tokenization as tagSentence so that words and tags line up
-    String[] testWords = tokenize(test);
-    List<String> posTags = tagSentence(test);
-    for (int i = 0; i < posTags.size(); ++i) {
-      log.info("{}[{}]", testWords[i], posTags.get(i));
-    }
-  }
-
-}