You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:07:32 UTC
[01/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Repository: mahout
Updated Branches:
refs/heads/master 0853c069f -> b988c493b
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasuresTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasuresTest.java b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasuresTest.java
new file mode 100644
index 0000000..c8a8c51
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasuresTest.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class VectorSimilarityMeasuresTest extends MahoutTestCase {
+
+ static double distributedSimilarity(double[] one,
+ double[] two,
+ Class<? extends VectorSimilarityMeasure> similarityMeasureClass) {
+ double rand = computeSimilarity(one, two, similarityMeasureClass, new RandomAccessSparseVector(one.length));
+ double seq = computeSimilarity(one, two, similarityMeasureClass, new SequentialAccessSparseVector(one.length));
+ double dense = computeSimilarity(one, two, similarityMeasureClass, new DenseVector(one.length));
+ assertEquals(seq, rand, 1.0e-10);
+ assertEquals(seq, dense, 1.0e-10);
+ assertEquals(dense, rand, 1.0e-10);
+ return seq;
+ }
+
+ private static double computeSimilarity(double[] one, double[] two,
+ Class<? extends VectorSimilarityMeasure> similarityMeasureClass,
+ Vector like) {
+ VectorSimilarityMeasure similarityMeasure = ClassUtils.instantiateAs(similarityMeasureClass,
+ VectorSimilarityMeasure.class);
+ Vector oneNormalized = similarityMeasure.normalize(asVector(one, like));
+ Vector twoNormalized = similarityMeasure.normalize(asVector(two, like));
+
+ double normOne = similarityMeasure.norm(oneNormalized);
+ double normTwo = similarityMeasure.norm(twoNormalized);
+
+ double dot = 0;
+ for (int n = 0; n < one.length; n++) {
+ if (oneNormalized.get(n) != 0 && twoNormalized.get(n) != 0) {
+ dot += similarityMeasure.aggregate(oneNormalized.get(n), twoNormalized.get(n));
+ }
+ }
+
+ return similarityMeasure.similarity(dot, normOne, normTwo, one.length);
+ }
+
+ static Vector asVector(double[] values, Vector like) {
+ Vector vector = like.like();
+ for (int dim = 0; dim < values.length; dim++) {
+ vector.set(dim, values[dim]);
+ }
+ return vector;
+ }
+
+ @Test
+ public void testCooccurrenceCountSimilarity() {
+ double similarity = distributedSimilarity(
+ new double[] { 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0 },
+ new double[] { 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1 }, CooccurrenceCountSimilarity.class);
+
+ assertEquals(5.0, similarity, 0);
+ }
+
+ @Test
+ public void testTanimotoCoefficientSimilarity() {
+ double similarity = distributedSimilarity(
+ new double[] { 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0 },
+ new double[] { 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1 }, TanimotoCoefficientSimilarity.class);
+
+ assertEquals(0.454545455, similarity, EPSILON);
+ }
+
+ @Test
+ public void testCityblockSimilarity() {
+ double similarity = distributedSimilarity(
+ new double[] { 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0 },
+ new double[] { 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1 }, CityBlockSimilarity.class);
+
+ assertEquals(0.142857143, similarity, EPSILON);
+ }
+
+ @Test
+ public void testLoglikelihoodSimilarity() {
+ double similarity = distributedSimilarity(
+ new double[] { 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0 },
+ new double[] { 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1 }, LoglikelihoodSimilarity.class);
+
+ assertEquals(0.03320155369284261, similarity, EPSILON);
+ }
+
+ @Test
+ public void testCosineSimilarity() {
+ double similarity = distributedSimilarity(
+ new double[] { 0, 2, 0, 0, 8, 3, 0, 6, 0, 1, 2, 2, 0 },
+ new double[] { 3, 0, 0, 0, 7, 0, 2, 2, 1, 3, 2, 1, 1 }, CosineSimilarity.class);
+
+ assertEquals(0.769846046, similarity, EPSILON);
+ }
+
+ @Test
+ public void testPearsonCorrelationSimilarity() {
+ double similarity = distributedSimilarity(
+ new double[] { 0, 2, 0, 0, 8, 3, 0, 6, 0, 1, 1, 2, 1 },
+ new double[] { 3, 0, 0, 0, 7, 0, 2, 2, 1, 3, 2, 4, 3 }, PearsonCorrelationSimilarity.class);
+
+ assertEquals(0.5303300858899108, similarity, EPSILON);
+ }
+
+ @Test
+ public void testEuclideanDistanceSimilarity() {
+ double similarity = distributedSimilarity(
+ new double[] { 0, 2, 0, 0, 8, 3, 0, 6, 0, 1, 1, 2, 1 },
+ new double[] { 3, 0, 0, 0, 7, 0, 2, 2, 1, 3, 2, 4, 4 }, EuclideanDistanceSimilarity.class);
+
+ assertEquals(0.11268865367232477, similarity, EPSILON);
+ }
+}
[16/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java
new file mode 100644
index 0000000..61a9f56
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/ProjectionSearch.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.BoundType;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import com.google.common.collect.TreeMultiset;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Does approximate nearest neighbor dudes search by projecting the data.
+ */
+public class ProjectionSearch extends UpdatableSearcher {
+
+ /**
+ * A lists of tree sets containing the scalar projections of each vector.
+ * The elements in a TreeMultiset are WeightedThing<Integer>, where the weight is the scalar
+ * projection of the vector at the index pointed to by the Integer from the referenceVectors list
+ * on the basis vector whose index is the same as the index of the TreeSet in the List.
+ */
+ private List<TreeMultiset<WeightedThing<Vector>>> scalarProjections;
+
+ /**
+ * The list of random normalized projection vectors forming a basis.
+ * The TreeSet of scalar projections at index i in scalarProjections corresponds to the vector
+ * at index i from basisVectors.
+ */
+ private Matrix basisMatrix;
+
+ /**
+ * The number of elements to consider on both sides in the ball around the vector found by the
+ * search in a TreeSet from scalarProjections.
+ */
+ private final int searchSize;
+
+ private final int numProjections;
+ private boolean initialized = false;
+
+ private void initialize(int numDimensions) {
+ if (initialized) {
+ return;
+ }
+ initialized = true;
+ basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
+ scalarProjections = Lists.newArrayList();
+ for (int i = 0; i < numProjections; ++i) {
+ scalarProjections.add(TreeMultiset.<WeightedThing<Vector>>create());
+ }
+ }
+
+ public ProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
+ super(distanceMeasure);
+ Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
+ "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
+
+ this.searchSize = searchSize;
+ this.numProjections = numProjections;
+ }
+
+ /**
+ * Adds a WeightedVector into the set of projections for later searching.
+ * @param vector The WeightedVector to add.
+ */
+ @Override
+ public void add(Vector vector) {
+ initialize(vector.size());
+ Vector projection = basisMatrix.times(vector);
+ // Add the the new vector and the projected distance to each set separately.
+ int i = 0;
+ for (TreeMultiset<WeightedThing<Vector>> s : scalarProjections) {
+ s.add(new WeightedThing<>(vector, projection.get(i++)));
+ }
+ int numVectors = scalarProjections.get(0).size();
+ for (TreeMultiset<WeightedThing<Vector>> s : scalarProjections) {
+ Preconditions.checkArgument(s.size() == numVectors, "Number of vectors in projection sets "
+ + "differ");
+ double firstWeight = s.firstEntry().getElement().getWeight();
+ for (WeightedThing<Vector> w : s) {
+ Preconditions.checkArgument(firstWeight <= w.getWeight(), "Weights not in non-decreasing "
+ + "order");
+ firstWeight = w.getWeight();
+ }
+ }
+ }
+
+ /**
+ * Returns the number of scalarProjections that we can search
+ * @return The number of scalarProjections added to the search so far.
+ */
+ @Override
+ public int size() {
+ if (scalarProjections == null) {
+ return 0;
+ }
+ return scalarProjections.get(0).size();
+ }
+
+ /**
+ * Searches for the query vector returning the closest limit referenceVectors.
+ *
+ * @param query the vector to search for.
+ * @param limit the number of results to return.
+ * @return a list of Vectors wrapped in WeightedThings where the "thing"'s weight is the
+ * distance.
+ */
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ Set<Vector> candidates = Sets.newHashSet();
+
+ Iterator<? extends Vector> projections = basisMatrix.iterator();
+ for (TreeMultiset<WeightedThing<Vector>> v : scalarProjections) {
+ Vector basisVector = projections.next();
+ WeightedThing<Vector> projectedQuery = new WeightedThing<>(query,
+ query.dot(basisVector));
+ for (WeightedThing<Vector> candidate : Iterables.concat(
+ Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
+ Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
+ candidates.add(candidate.getValue());
+ }
+ }
+
+ // If searchSize * scalarProjections.size() is small enough not to cause much memory pressure,
+ // this is probably just as fast as a priority queue here.
+ List<WeightedThing<Vector>> top = Lists.newArrayList();
+ for (Vector candidate : candidates) {
+ top.add(new WeightedThing<>(candidate, distanceMeasure.distance(query, candidate)));
+ }
+ Collections.sort(top);
+ return top.subList(0, Math.min(limit, top.size()));
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only one the nearest vector is needed, use this method, NOT search(query, limit) because
+ * it's faster (less overhead).
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different than the query (this
+ * only matters if the query is among the searched vectors), otherwise,
+ * returns the closest vector to the query (even the same vector).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+ double bestDistance = Double.POSITIVE_INFINITY;
+ Vector bestVector = null;
+
+ Iterator<? extends Vector> projections = basisMatrix.iterator();
+ for (TreeMultiset<WeightedThing<Vector>> v : scalarProjections) {
+ Vector basisVector = projections.next();
+ WeightedThing<Vector> projectedQuery = new WeightedThing<>(query, query.dot(basisVector));
+ for (WeightedThing<Vector> candidate : Iterables.concat(
+ Iterables.limit(v.tailMultiset(projectedQuery, BoundType.CLOSED), searchSize),
+ Iterables.limit(v.headMultiset(projectedQuery, BoundType.OPEN).descendingMultiset(), searchSize))) {
+ double distance = distanceMeasure.distance(query, candidate.getValue());
+ if (distance < bestDistance && (!differentThanQuery || !candidate.getValue().equals(query))) {
+ bestDistance = distance;
+ bestVector = candidate.getValue();
+ }
+ }
+ }
+
+ return new WeightedThing<>(bestVector, bestDistance);
+ }
+
+ @Override
+ public Iterator<Vector> iterator() {
+ return new AbstractIterator<Vector>() {
+ private final Iterator<WeightedThing<Vector>> projected = scalarProjections.get(0).iterator();
+ @Override
+ protected Vector computeNext() {
+ if (!projected.hasNext()) {
+ return endOfData();
+ }
+ return projected.next().getValue();
+ }
+ };
+ }
+
+ @Override
+ public boolean remove(Vector vector, double epsilon) {
+ WeightedThing<Vector> toRemove = searchFirst(vector, false);
+ if (toRemove.getWeight() < epsilon) {
+ Iterator<? extends Vector> basisVectors = basisMatrix.iterator();
+ for (TreeMultiset<WeightedThing<Vector>> projection : scalarProjections) {
+ if (!projection.remove(new WeightedThing<>(vector, vector.dot(basisVectors.next())))) {
+ throw new RuntimeException("Internal inconsistency in ProjectionSearch");
+ }
+ }
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public void clear() {
+ if (scalarProjections == null) {
+ return;
+ }
+ for (TreeMultiset<WeightedThing<Vector>> set : scalarProjections) {
+ set.clear();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
new file mode 100644
index 0000000..dd387b5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/Searcher.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.List;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Describes how to search a bunch of vectors.
+ * The vectors can be of any type (weighted, sparse, ...) but only the values of the vector matter
+ * when searching (weights, indices, ...) will not.
+ *
+ * When iterating through a Searcher, the Vectors added to it are returned.
+ */
+public abstract class Searcher implements Iterable<Vector> {
+ protected DistanceMeasure distanceMeasure;
+
+ protected Searcher(DistanceMeasure distanceMeasure) {
+ this.distanceMeasure = distanceMeasure;
+ }
+
+ public DistanceMeasure getDistanceMeasure() {
+ return distanceMeasure;
+ }
+
+ /**
+ * Add a new Vector to the Searcher that will be checked when getting
+ * the nearest neighbors.
+ *
+ * The vector IS NOT CLONED. Do not modify the vector externally otherwise the internal
+ * Searcher data structures could be invalidated.
+ */
+ public abstract void add(Vector vector);
+
+ /**
+ * Returns the number of WeightedVectors being searched for nearest neighbors.
+ */
+ public abstract int size();
+
+ /**
+ * When querying the Searcher for the closest vectors, a list of WeightedThing<Vector>s is
+ * returned. The value of the WeightedThing is the neighbor and the weight is the
+ * the distance (calculated by some metric - see a concrete implementation) between the query
+ * and neighbor.
+ * The actual type of vector in the pair is the same as the vector added to the Searcher.
+ * @param query the vector to search for
+ * @param limit the number of results to return
+ * @return the list of weighted vectors closest to the query
+ */
+ public abstract List<WeightedThing<Vector>> search(Vector query, int limit);
+
+ public List<List<WeightedThing<Vector>>> search(Iterable<? extends Vector> queries, int limit) {
+ List<List<WeightedThing<Vector>>> results = Lists.newArrayListWithExpectedSize(Iterables.size(queries));
+ for (Vector query : queries) {
+ results.add(search(query, limit));
+ }
+ return results;
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only one the nearest vector is needed, use this method, NOT search(query, limit) because
+ * it's faster (less overhead).
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different than the query (this
+ * only matters if the query is among the searched vectors), otherwise,
+ * returns the closest vector to the query (even the same vector).
+ * @return the weighted vector closest to the query
+ */
+ public abstract WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery);
+
+ public List<WeightedThing<Vector>> searchFirst(Iterable<? extends Vector> queries, boolean differentThanQuery) {
+ List<WeightedThing<Vector>> results = Lists.newArrayListWithExpectedSize(Iterables.size(queries));
+ for (Vector query : queries) {
+ results.add(searchFirst(query, differentThanQuery));
+ }
+ return results;
+ }
+
+ /**
+ * Adds all the data elements in the Searcher.
+ *
+ * @param data an iterable of WeightedVectors to add.
+ */
+ public void addAll(Iterable<? extends Vector> data) {
+ for (Vector vector : data) {
+ add(vector);
+ }
+ }
+
+ /**
+ * Adds all the data elements in the Searcher.
+ *
+ * @param data an iterable of MatrixSlices to add.
+ */
+ public void addAllMatrixSlices(Iterable<MatrixSlice> data) {
+ for (MatrixSlice slice : data) {
+ add(slice.vector());
+ }
+ }
+
+ public void addAllMatrixSlicesAsWeightedVectors(Iterable<MatrixSlice> data) {
+ for (MatrixSlice slice : data) {
+ add(new WeightedVector(slice.vector(), 1, slice.index()));
+ }
+ }
+
+ public boolean remove(Vector v, double epsilon) {
+ throw new UnsupportedOperationException("Can't remove a vector from a "
+ + this.getClass().getName());
+ }
+
+ public void clear() {
+ throw new UnsupportedOperationException("Can't remove vectors from a "
+ + this.getClass().getName());
+ }
+
+ /**
+ * Returns a bounded size priority queue, in reverse order that keeps track of the best nearest neighbor vectors.
+ * @param limit maximum size of the heap.
+ * @return the priority queue.
+ */
+ public static PriorityQueue<WeightedThing<Vector>> getCandidateQueue(int limit) {
+ return new PriorityQueue<WeightedThing<Vector>>(limit) {
+ @Override
+ protected boolean lessThan(WeightedThing<Vector> a, WeightedThing<Vector> b) {
+ return a.getWeight() > b.getWeight();
+ }
+ };
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
new file mode 100644
index 0000000..68365c7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/UpdatableSearcher.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Describes how we search vectors. A class should extend UpdatableSearch only if it can handle a remove function.
+ */
+public abstract class UpdatableSearcher extends Searcher {
+
+ protected UpdatableSearcher(DistanceMeasure distanceMeasure) {
+ super(distanceMeasure);
+ }
+
+ @Override
+ public abstract boolean remove(Vector v, double epsilon);
+
+ @Override
+ public abstract void clear();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java b/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
new file mode 100644
index 0000000..79fe4b6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/random/RandomProjector.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.lang.math.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+public final class RandomProjector {
+ private RandomProjector() {
+ }
+
+ /**
+ * Generates a basis matrix of size projectedVectorSize x vectorSize. Multiplying a a vector by
+ * this matrix results in the projected vector.
+ *
+ * The rows of the matrix are sampled from a multi normal distribution.
+ *
+ * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+ * @param vectorSize initial vector size
+ * @return a projection matrix
+ */
+ public static Matrix generateBasisNormal(int projectedVectorSize, int vectorSize) {
+ Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+ basisMatrix.assign(new Normal());
+ for (MatrixSlice row : basisMatrix) {
+ row.vector().assign(row.normalize());
+ }
+ return basisMatrix;
+ }
+
+ /**
+ * Generates a basis matrix of size projectedVectorSize x vectorSize. Multiplying a a vector by
+ * this matrix results in the projected vector.
+ *
+ * The rows of a matrix are sample from a distribution where:
+ * - +1 has probability 1/2,
+ * - -1 has probability 1/2
+ *
+ * See Achlioptas, D. (2003). Database-friendly random projections: Johnson-Lindenstrauss with binary coins.
+ * Journal of Computer and System Sciences, 66(4), 671–687. doi:10.1016/S0022-0000(03)00025-4
+ *
+ * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+ * @param vectorSize initial vector size
+ * @return a projection matrix
+ */
+ public static Matrix generateBasisPlusMinusOne(int projectedVectorSize, int vectorSize) {
+ Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+ for (int i = 0; i < projectedVectorSize; ++i) {
+ for (int j = 0; j < vectorSize; ++j) {
+ basisMatrix.set(i, j, RandomUtils.nextInt(2) == 0 ? 1 : -1);
+ }
+ }
+ for (MatrixSlice row : basisMatrix) {
+ row.vector().assign(row.normalize());
+ }
+ return basisMatrix;
+ }
+
+ /**
+ * Generates a basis matrix of size projectedVectorSize x vectorSize. Multiplying a a vector by
+ * this matrix results in the projected vector.
+ *
+ * The rows of a matrix are sample from a distribution where:
+ * - 0 has probability 2/3,
+ * - +1 has probability 1/6,
+ * - -1 has probability 1/6
+ *
+ * See Achlioptas, D. (2003). Database-friendly random projections: Johnson-Lindenstrauss with binary coins.
+ * Journal of Computer and System Sciences, 66(4), 671–687. doi:10.1016/S0022-0000(03)00025-4
+ *
+ * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+ * @param vectorSize initial vector size
+ * @return a projection matrix
+ */
+ public static Matrix generateBasisZeroPlusMinusOne(int projectedVectorSize, int vectorSize) {
+ Matrix basisMatrix = new DenseMatrix(projectedVectorSize, vectorSize);
+ Multinomial<Double> choice = new Multinomial<>();
+ choice.add(0.0, 2 / 3.0);
+ choice.add(Math.sqrt(3.0), 1 / 6.0);
+ choice.add(-Math.sqrt(3.0), 1 / 6.0);
+ for (int i = 0; i < projectedVectorSize; ++i) {
+ for (int j = 0; j < vectorSize; ++j) {
+ basisMatrix.set(i, j, choice.sample());
+ }
+ }
+ for (MatrixSlice row : basisMatrix) {
+ row.vector().assign(row.normalize());
+ }
+ return basisMatrix;
+ }
+
+ /**
+ * Generates a list of projectedVectorSize vectors, each of size vectorSize. This looks like a
+ * matrix of size (projectedVectorSize, vectorSize).
+ * @param projectedVectorSize final projected size of a vector (number of projection vectors)
+ * @param vectorSize initial vector size
+ * @return a list of projection vectors
+ */
+ public static List<Vector> generateVectorBasis(int projectedVectorSize, int vectorSize) {
+ DoubleFunction random = new Normal();
+ List<Vector> basisVectors = Lists.newArrayList();
+ for (int i = 0; i < projectedVectorSize; ++i) {
+ Vector basisVector = new DenseVector(vectorSize);
+ basisVector.assign(random);
+ basisVector.normalize();
+ basisVectors.add(basisVector);
+ }
+ return basisVectors;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java b/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java
new file mode 100644
index 0000000..f7724f7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/ssvd/SequentialOutOfCoreSvd.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.ssvd;
+
+import org.apache.mahout.math.CholeskyDecomposition;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.RandomTrinaryMatrix;
+import org.apache.mahout.math.SingularValueDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+/**
+ * Sequential block-oriented out of core SVD algorithm.
+ * <p/>
+ * The basic algorithm (in-core version) is that we do a random projects, get a basis of that and
+ * then re-project the original matrix using that basis. This re-projected matrix allows us to get
+ * an approximate SVD of the original matrix.
+ * <p/>
+ * The input to this program is a list of files that contain the sub-matrices A_i. The result is a
+ * vector of singular values and optionally files that contain the left and right singular vectors.
+ * <p/>
+ * Mathematically, to decompose A, we do this:
+ * <p/>
+ * Y = A * \Omega
+ * <p/>
+ * Q R = Y
+ * <p/>
+ * B = Q" A
+ * <p/>
+ * U D V' = B
+ * <p/>
+ * (Q U) D V' \approx A
+ * <p/>
+ * To do this out of core, we break A into blocks each with the same number of rows. This gives a
+ * block-wise version of Y. As we are computing Y, we can also accumulate Y' Y and when done, we
+ * can use a Cholesky decomposition to do the QR decomposition of Y in a latent form. That gives us
+ * B in block-wise form and we can do the same trick to get an LQ of B. The L part can be
+ * decomposed in memory. Then we can recombine to get the final decomposition.
+ * <p/>
+ * The details go like this. Start with a block form of A.
+ * <p/>
+ * Y_i = A_i * \Omega
+ * <p/>
+ * Instead of doing a QR decomposition of Y, we do a Cholesky decomposition of Y' Y. This is a
+ * small in-memory operation. Q is large and dense and won't fit in memory.
+ * <p/>
+ * R' R = \sum_i Y_i' Y_i
+ * <p/>
+ * For reference, R is all we need to compute explicitly. Q will be computed on the fly when
+ * needed.
+ * <p/>
+ * Q = Y R^-1
+ * <p/>
+ * B = Q" A = \sum_i (A \Omega R^-1)' A_i
+ * <p/>
+ * As B is generated, it needs to be segmented in row-wise blocks since it is wide but not tall.
+ * This storage requires something like a map-reduce to accumulate the partial sums. In this code,
+ * we do this by re-reading previously computed chunks and augmenting them.
+ * <p/>
+ * While the pieces of B are being computed, we can accumulate B B' in preparation for a second
+ * Cholesky decomposition
+ * <p/>
+ * L L' = B B' = sum B_j B_j'
+ * <p/>
+ * Again, this is an LQ decomposition of BB', but we don't compute the Q part explicitly. L will be
+ * small and thus tractable.
+ * <p/>
+ * Finally, we do the actual SVD decomposition.
+ * <p/>
+ * U_0 D V_0' = L
+ * <p/>
+ * D contains the singular values of A. The left and right singular values can be reconstructed
+ * using Y and B. Note that both of these reconstructions can be done with single passes through
+ * the blocked forms of Y and B.
+ * <p/>
+ * U = A \Omega R^{-1} U_0
+ * <p/>
+ * V = B' L'^{-1} V_0
+ */
+public class SequentialOutOfCoreSvd {
+
+ private final CholeskyDecomposition l2;
+ private final SingularValueDecomposition svd;
+ private final CholeskyDecomposition r2;
+ private final int columnsPerSlice;
+ private final int seed;
+ private final int dim;
+
+ public SequentialOutOfCoreSvd(Iterable<File> partsOfA, File tmpDir, int internalDimension, int columnsPerSlice)
+ throws IOException {
+ this.columnsPerSlice = columnsPerSlice;
+ this.dim = internalDimension;
+
+ seed = 1;
+ Matrix y2 = null;
+
+ // step 1, compute R as in R'R = Y'Y where Y = A \Omega
+ for (File file : partsOfA) {
+ MatrixWritable m = new MatrixWritable();
+ try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
+ m.readFields(in);
+ }
+
+ Matrix aI = m.get();
+ Matrix omega = new RandomTrinaryMatrix(seed, aI.columnSize(), internalDimension, false);
+ Matrix y = aI.times(omega);
+
+ if (y2 == null) {
+ y2 = y.transpose().times(y);
+ } else {
+ y2.assign(y.transpose().times(y), Functions.PLUS);
+ }
+ }
+ r2 = new CholeskyDecomposition(y2);
+
+ // step 2, compute B
+ int ncols = 0;
+ for (File file : partsOfA) {
+ MatrixWritable m = new MatrixWritable();
+ try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
+ m.readFields(in);
+ }
+ Matrix aI = m.get();
+ ncols = Math.max(ncols, aI.columnSize());
+
+ Matrix omega = new RandomTrinaryMatrix(seed, aI.numCols(), internalDimension, false);
+ for (int j = 0; j < aI.numCols(); j += columnsPerSlice) {
+ Matrix yI = aI.times(omega);
+ Matrix aIJ = aI.viewPart(0, aI.rowSize(), j, Math.min(columnsPerSlice, aI.columnSize() - j));
+ Matrix bIJ = r2.solveRight(yI).transpose().times(aIJ);
+ addToSavedCopy(bFile(tmpDir, j), bIJ);
+ }
+ }
+
+ // step 3, compute BB', L and SVD(L)
+ Matrix b2 = new DenseMatrix(internalDimension, internalDimension);
+ MatrixWritable bTmp = new MatrixWritable();
+ for (int j = 0; j < ncols; j += columnsPerSlice) {
+ if (bFile(tmpDir, j).exists()) {
+ try (DataInputStream in = new DataInputStream(new FileInputStream(bFile(tmpDir, j)))) {
+ bTmp.readFields(in);
+ }
+
+ b2.assign(bTmp.get().times(bTmp.get().transpose()), Functions.PLUS);
+ }
+ }
+ l2 = new CholeskyDecomposition(b2);
+ svd = new SingularValueDecomposition(l2.getL());
+ }
+
+ public void computeV(File tmpDir, int ncols) throws IOException {
+ // step 5, compute pieces of V
+ for (int j = 0; j < ncols; j += columnsPerSlice) {
+ File bPath = bFile(tmpDir, j);
+ if (bPath.exists()) {
+ MatrixWritable m = new MatrixWritable();
+ try (DataInputStream in = new DataInputStream(new FileInputStream(bPath))) {
+ m.readFields(in);
+ }
+ m.set(l2.solveRight(m.get().transpose()).times(svd.getV()));
+ try (DataOutputStream out = new DataOutputStream(new FileOutputStream(
+ new File(tmpDir, String.format("V-%s", bPath.getName().replaceAll(".*-", "")))))) {
+ m.write(out);
+ }
+ }
+ }
+ }
+
+ public void computeU(Iterable<File> partsOfA, File tmpDir) throws IOException {
+ // step 4, compute pieces of U
+ for (File file : partsOfA) {
+ MatrixWritable m = new MatrixWritable();
+ m.readFields(new DataInputStream(new FileInputStream(file)));
+ Matrix aI = m.get();
+
+ Matrix y = aI.times(new RandomTrinaryMatrix(seed, aI.numCols(), dim, false));
+ Matrix uI = r2.solveRight(y).times(svd.getU());
+ m.set(uI);
+ try (DataOutputStream out = new DataOutputStream(new FileOutputStream(
+ new File(tmpDir, String.format("U-%s", file.getName().replaceAll(".*-", "")))))) {
+ m.write(out);
+ }
+ }
+ }
+
+ private static void addToSavedCopy(File file, Matrix matrix) throws IOException {
+ MatrixWritable mw = new MatrixWritable();
+ if (file.exists()) {
+ try (DataInputStream in = new DataInputStream(new FileInputStream(file))) {
+ mw.readFields(in);
+ }
+ mw.get().assign(matrix, Functions.PLUS);
+ } else {
+ mw.set(matrix);
+ }
+ try (DataOutputStream out = new DataOutputStream(new FileOutputStream(file))) {
+ mw.write(out);
+ }
+ }
+
+ private static File bFile(File tmpDir, int j) {
+ return new File(tmpDir, String.format("B-%09d", j));
+ }
+
+ public Vector getSingularValues() {
+ return new DenseVector(svd.getSingularValues());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java b/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
new file mode 100644
index 0000000..4485bbe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/stats/GlobalOnlineAuc.java
@@ -0,0 +1,168 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Random;
+
+/**
+ * Computes a running estimate of AUC (see http://en.wikipedia.org/wiki/Receiver_operating_characteristic).
+ * <p/>
+ * Since AUC is normally a global property of labeled scores, it is almost always computed in a
+ * batch fashion. The probabilistic definition (the probability that a random element of one set
+ * has a higher score than a random element of another set) gives us a way to estimate this
+ * on-line.
+ *
+ * @see GroupedOnlineAuc
+ */
+public class GlobalOnlineAuc implements OnlineAuc {
+ enum ReplacementPolicy {
+ FIFO, FAIR, RANDOM
+ }
+
+ // increasing this to 100 causes very small improvements in accuracy. Decreasing it to 2
+ // causes substantial degradation for the FAIR and RANDOM policies, but almost no change
+ // for the FIFO policy
+ public static final int HISTORY = 10;
+
+ // defines the exponential averaging window for results
+ private int windowSize = Integer.MAX_VALUE;
+
+ // FIFO has distinctly the best properties as a policy. See OnlineAucTest for details
+ private ReplacementPolicy policy = ReplacementPolicy.FIFO;
+ private final Random random = RandomUtils.getRandom();
+ private Matrix scores;
+ private Vector averages;
+ private Vector samples;
+
+ public GlobalOnlineAuc() {
+ int numCategories = 2;
+ scores = new DenseMatrix(numCategories, HISTORY);
+ scores.assign(Double.NaN);
+ averages = new DenseVector(numCategories);
+ averages.assign(0.5);
+ samples = new DenseVector(numCategories);
+ }
+
+ @Override
+ public double addSample(int category, String groupKey, double score) {
+ return addSample(category, score);
+ }
+
+ @Override
+ public double addSample(int category, double score) {
+ int n = (int) samples.get(category);
+ if (n < HISTORY) {
+ scores.set(category, n, score);
+ } else {
+ switch (policy) {
+ case FIFO:
+ scores.set(category, n % HISTORY, score);
+ break;
+ case FAIR:
+ int j1 = random.nextInt(n + 1);
+ if (j1 < HISTORY) {
+ scores.set(category, j1, score);
+ }
+ break;
+ case RANDOM:
+ int j2 = random.nextInt(HISTORY);
+ scores.set(category, j2, score);
+ break;
+ default:
+ throw new IllegalStateException("Unknown policy: " + policy);
+ }
+ }
+
+ samples.set(category, n + 1);
+
+ if (samples.minValue() >= 1) {
+ // compare to previous scores for other category
+ Vector row = scores.viewRow(1 - category);
+ double m = 0.0;
+ double count = 0.0;
+ for (Vector.Element element : row.all()) {
+ double v = element.get();
+ if (Double.isNaN(v)) {
+ continue;
+ }
+ count++;
+ if (score > v) {
+ m++;
+ // } else if (score < v) {
+ // m += 0
+ } else if (score == v) {
+ m += 0.5;
+ }
+ }
+ averages.set(category, averages.get(category)
+ + (m / count - averages.get(category)) / Math.min(windowSize, samples.get(category)));
+ }
+ return auc();
+ }
+
+ @Override
+ public double auc() {
+ // return an unweighted average of all averages.
+ return (1 - averages.get(0) + averages.get(1)) / 2;
+ }
+
+ public double value() {
+ return auc();
+ }
+
+ @Override
+ public void setPolicy(ReplacementPolicy policy) {
+ this.policy = policy;
+ }
+
+ @Override
+ public void setWindowSize(int windowSize) {
+ this.windowSize = windowSize;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(windowSize);
+ out.writeInt(policy.ordinal());
+ MatrixWritable.writeMatrix(out, scores);
+ VectorWritable.writeVector(out, averages);
+ VectorWritable.writeVector(out, samples);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ windowSize = in.readInt();
+ policy = ReplacementPolicy.values()[in.readInt()];
+
+ scores = MatrixWritable.readMatrix(in);
+ averages = VectorWritable.readVector(in);
+ samples = VectorWritable.readVector(in);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java b/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java
new file mode 100644
index 0000000..3fa1b79
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/stats/GroupedOnlineAuc.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import com.google.common.collect.Maps;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Implements a variant on AUC where the result returned is an average of several AUC measurements
+ * made on sub-groups of the overall data. Controlling for the grouping factor allows the effects
+ * of the grouping factor on the model to be ignored. This is useful, for instance, when using a
+ * classifier as a click prediction engine. In that case you want AUC to refer only to the ranking
+ * of items for a particular user, not to the discrimination of users from each other. Grouping by
+ * user (or user cluster) helps avoid optimizing for the wrong quality.
+ */
+public class GroupedOnlineAuc implements OnlineAuc {
+ private final Map<String, OnlineAuc> map = Maps.newHashMap();
+ private GlobalOnlineAuc.ReplacementPolicy policy;
+ private int windowSize;
+
+ @Override
+ public double addSample(int category, String groupKey, double score) {
+ if (groupKey == null) {
+ addSample(category, score);
+ }
+
+ OnlineAuc group = map.get(groupKey);
+ if (group == null) {
+ group = new GlobalOnlineAuc();
+ if (policy != null) {
+ group.setPolicy(policy);
+ }
+ if (windowSize > 0) {
+ group.setWindowSize(windowSize);
+ }
+ map.put(groupKey, group);
+ }
+ return group.addSample(category, score);
+ }
+
+ @Override
+ public double addSample(int category, double score) {
+ throw new UnsupportedOperationException("Can't add to " + this.getClass() + " without group key");
+ }
+
+ @Override
+ public double auc() {
+ double sum = 0;
+ for (OnlineAuc auc : map.values()) {
+ sum += auc.auc();
+ }
+ return sum / map.size();
+ }
+
+ @Override
+ public void setPolicy(GlobalOnlineAuc.ReplacementPolicy policy) {
+ this.policy = policy;
+ for (OnlineAuc auc : map.values()) {
+ auc.setPolicy(policy);
+ }
+ }
+
+ @Override
+ public void setWindowSize(int windowSize) {
+ this.windowSize = windowSize;
+ for (OnlineAuc auc : map.values()) {
+ auc.setWindowSize(windowSize);
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(map.size());
+ for (Map.Entry<String,OnlineAuc> entry : map.entrySet()) {
+ out.writeUTF(entry.getKey());
+ PolymorphicWritable.write(out, entry.getValue());
+ }
+ out.writeInt(policy.ordinal());
+ out.writeInt(windowSize);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int n = in.readInt();
+ map.clear();
+ for (int i = 0; i < n; i++) {
+ String key = in.readUTF();
+ map.put(key, PolymorphicWritable.read(in, OnlineAuc.class));
+ }
+ policy = GlobalOnlineAuc.ReplacementPolicy.values()[in.readInt()];
+ windowSize = in.readInt();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java b/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
new file mode 100644
index 0000000..d21ae6b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Describes the generic outline of how to compute AUC. Currently there are two
+ * implementations of this, one for computing a global estimate of AUC and the other
+ * for computing average grouped AUC. Grouped AUC is useful when misusing a classifier
+ * as a recommendation system.
+ */
+public interface OnlineAuc extends Writable {
+ double addSample(int category, String groupKey, double score);
+
+ double addSample(int category, double score);
+
+ double auc();
+
+ void setPolicy(GlobalOnlineAuc.ReplacementPolicy policy);
+
+ void setWindowSize(int windowSize);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java b/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java
new file mode 100644
index 0000000..4b9e8a9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/stats/Sampler.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.Vector;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Discrete distribution sampler:
+ *
+ * Samples from a given discrete distribution: you provide a source of randomness and a Vector
+ * (cardinality N) which describes a distribution over [0,N), and calls to sample() sample
+ * from 0 to N using this distribution
+ */
+public class Sampler {
+
+ private final Random random;
+ private final double[] sampler;
+
+ public Sampler(Random random) {
+ this.random = random;
+ sampler = null;
+ }
+
+ public Sampler(Random random, double[] sampler) {
+ this.random = random;
+ this.sampler = sampler;
+ }
+
+ public Sampler(Random random, Vector distribution) {
+ this.random = random;
+ this.sampler = samplerFor(distribution);
+ }
+
+ public int sample(Vector distribution) {
+ return sample(samplerFor(distribution));
+ }
+
+ public int sample() {
+ Preconditions.checkNotNull(sampler,
+ "Sampler must have been constructed with a distribution, or else sample(Vector) should be used to sample");
+ return sample(sampler);
+ }
+
+ private static double[] samplerFor(Vector vectorDistribution) {
+ int size = vectorDistribution.size();
+ double[] partition = new double[size];
+ double norm = vectorDistribution.norm(1);
+ double sum = 0;
+ for (int i = 0; i < size; i++) {
+ sum += vectorDistribution.get(i) / norm;
+ partition[i] = sum;
+ }
+ return partition;
+ }
+
+ private int sample(double[] sampler) {
+ int index = Arrays.binarySearch(sampler, random.nextDouble());
+ return index < 0 ? -(index + 1) : index;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java b/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java
new file mode 100644
index 0000000..8a1f8f8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/DictionaryVectorizer.java
@@ -0,0 +1,416 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.collocations.llr.CollocDriver;
+import org.apache.mahout.vectorizer.collocations.llr.LLRReducer;
+import org.apache.mahout.vectorizer.common.PartialVectorMerger;
+import org.apache.mahout.vectorizer.term.TFPartialVectorReducer;
+import org.apache.mahout.vectorizer.term.TermCountCombiner;
+import org.apache.mahout.vectorizer.term.TermCountMapper;
+import org.apache.mahout.vectorizer.term.TermCountReducer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * This class converts a set of input documents in the sequence file format to vectors. The Sequence file
+ * input should have a {@link Text} key containing the unique document identifier and a {@link StringTuple}
+ * value containing the tokenized document. You may use {@link DocumentProcessor} to tokenize the document.
+ * This is a dictionary based Vectorizer.
+ */
+public final class DictionaryVectorizer extends AbstractJob implements Vectorizer {
+ private static final Logger log = LoggerFactory.getLogger(DictionaryVectorizer.class);
+
+ public static final String DOCUMENT_VECTOR_OUTPUT_FOLDER = "tf-vectors";
+ public static final String MIN_SUPPORT = "min.support";
+ public static final String MAX_NGRAMS = "max.ngrams";
+ public static final int DEFAULT_MIN_SUPPORT = 2;
+ public static final String DICTIONARY_FILE = "dictionary.file-";
+
+ private static final int MAX_CHUNKSIZE = 10000;
+ private static final int MIN_CHUNKSIZE = 100;
+ private static final String OUTPUT_FILES_PATTERN = "part-*";
+ // 4 byte overhead for each entry in the OpenObjectIntHashMap
+ private static final int DICTIONARY_BYTE_OVERHEAD = 4;
+ private static final String VECTOR_OUTPUT_FOLDER = "partial-vectors-";
+ private static final String DICTIONARY_JOB_FOLDER = "wordcount";
+
+ /**
+ * Cannot be initialized. Use the static functions
+ */
+ private DictionaryVectorizer() {
+ }
+
+ //TODO: move more of SparseVectorsFromSequenceFile in here, and then fold SparseVectorsFrom with
+ // EncodedVectorsFrom to have one framework.
+
+ @Override
+ public void createVectors(Path input, Path output, VectorizerConfig config)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ createTermFrequencyVectors(input,
+ output,
+ config.getTfDirName(),
+ config.getConf(),
+ config.getMinSupport(),
+ config.getMaxNGramSize(),
+ config.getMinLLRValue(),
+ config.getNormPower(),
+ config.isLogNormalize(),
+ config.getNumReducers(),
+ config.getChunkSizeInMegabytes(),
+ config.isSequentialAccess(),
+ config.isNamedVectors());
+ }
+
+ /**
+ * Create Term Frequency (Tf) Vectors from the input set of documents in {@link SequenceFile} format. This
+ * tries to fix the maximum memory used by the feature chunk per node thereby splitting the process across
+ * multiple map/reduces.
+ *
+ * @param input
+ * input directory of the documents in {@link SequenceFile} format
+ * @param output
+ * output directory where {@link org.apache.mahout.math.RandomAccessSparseVector}'s of the document
+ * are generated
+ * @param tfVectorsFolderName
+ * The name of the folder in which the final output vectors will be stored
+ * @param baseConf
+ * job configuration
+ * @param normPower
+ * L_p norm to be computed
+ * @param logNormalize
+ * whether to use log normalization
+ * @param minSupport
+ * the minimum frequency of the feature in the entire corpus to be considered for inclusion in the
+ * sparse vector
+ * @param maxNGramSize
+ * 1 = unigram, 2 = unigram and bigram, 3 = unigram, bigram and trigram
+ * @param minLLRValue
+ * minValue of log likelihood ratio to used to prune ngrams
+ * @param chunkSizeInMegabytes
+ * the size in MB of the feature => id chunk to be kept in memory at each node during Map/Reduce
+ * stage. Its recommended you calculated this based on the number of cores and the free memory
+ * available to you per node. Say, you have 2 cores and around 1GB extra memory to spare we
+ * recommend you use a split size of around 400-500MB so that two simultaneous reducers can create
+ * partial vectors without thrashing the system due to increased swapping
+ */
+ public static void createTermFrequencyVectors(Path input,
+ Path output,
+ String tfVectorsFolderName,
+ Configuration baseConf,
+ int minSupport,
+ int maxNGramSize,
+ float minLLRValue,
+ float normPower,
+ boolean logNormalize,
+ int numReducers,
+ int chunkSizeInMegabytes,
+ boolean sequentialAccess,
+ boolean namedVectors)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING || normPower >= 0,
+ "If specified normPower must be nonnegative", normPower);
+ Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING
+ || (normPower > 1 && !Double.isInfinite(normPower))
+ || !logNormalize,
+ "normPower must be > 1 and not infinite if log normalization is chosen", normPower);
+ if (chunkSizeInMegabytes < MIN_CHUNKSIZE) {
+ chunkSizeInMegabytes = MIN_CHUNKSIZE;
+ } else if (chunkSizeInMegabytes > MAX_CHUNKSIZE) { // 10GB
+ chunkSizeInMegabytes = MAX_CHUNKSIZE;
+ }
+ if (minSupport < 0) {
+ minSupport = DEFAULT_MIN_SUPPORT;
+ }
+
+ Path dictionaryJobPath = new Path(output, DICTIONARY_JOB_FOLDER);
+
+ log.info("Creating dictionary from {} and saving at {}", input, dictionaryJobPath);
+
+ int[] maxTermDimension = new int[1];
+ List<Path> dictionaryChunks;
+ if (maxNGramSize == 1) {
+ startWordCounting(input, dictionaryJobPath, baseConf, minSupport);
+ dictionaryChunks =
+ createDictionaryChunks(dictionaryJobPath, output, baseConf, chunkSizeInMegabytes, maxTermDimension);
+ } else {
+ CollocDriver.generateAllGrams(input, dictionaryJobPath, baseConf, maxNGramSize,
+ minSupport, minLLRValue, numReducers);
+ dictionaryChunks =
+ createDictionaryChunks(new Path(new Path(output, DICTIONARY_JOB_FOLDER),
+ CollocDriver.NGRAM_OUTPUT_DIRECTORY),
+ output,
+ baseConf,
+ chunkSizeInMegabytes,
+ maxTermDimension);
+ }
+
+ int partialVectorIndex = 0;
+ Collection<Path> partialVectorPaths = Lists.newArrayList();
+ for (Path dictionaryChunk : dictionaryChunks) {
+ Path partialVectorOutputPath = new Path(output, VECTOR_OUTPUT_FOLDER + partialVectorIndex++);
+ partialVectorPaths.add(partialVectorOutputPath);
+ makePartialVectors(input, baseConf, maxNGramSize, dictionaryChunk, partialVectorOutputPath,
+ maxTermDimension[0], sequentialAccess, namedVectors, numReducers);
+ }
+
+ Configuration conf = new Configuration(baseConf);
+
+ Path outputDir = new Path(output, tfVectorsFolderName);
+ PartialVectorMerger.mergePartialVectors(partialVectorPaths, outputDir, conf, normPower, logNormalize,
+ maxTermDimension[0], sequentialAccess, namedVectors, numReducers);
+ HadoopUtil.delete(conf, partialVectorPaths);
+ }
+
+ /**
+ * Read the feature frequency List which is built at the end of the Word Count Job and assign ids to them.
+ * This will use constant memory and will run at the speed of your disk read
+ */
+ private static List<Path> createDictionaryChunks(Path wordCountPath,
+ Path dictionaryPathBase,
+ Configuration baseConf,
+ int chunkSizeInMegabytes,
+ int[] maxTermDimension) throws IOException {
+ List<Path> chunkPaths = Lists.newArrayList();
+
+ Configuration conf = new Configuration(baseConf);
+
+ FileSystem fs = FileSystem.get(wordCountPath.toUri(), conf);
+
+ long chunkSizeLimit = chunkSizeInMegabytes * 1024L * 1024L;
+ int chunkIndex = 0;
+ Path chunkPath = new Path(dictionaryPathBase, DICTIONARY_FILE + chunkIndex);
+ chunkPaths.add(chunkPath);
+
+ SequenceFile.Writer dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class, IntWritable.class);
+
+ try {
+ long currentChunkSize = 0;
+ Path filesPattern = new Path(wordCountPath, OUTPUT_FILES_PATTERN);
+ int i = 0;
+ for (Pair<Writable,Writable> record
+ : new SequenceFileDirIterable<>(filesPattern, PathType.GLOB, null, null, true, conf)) {
+ if (currentChunkSize > chunkSizeLimit) {
+ Closeables.close(dictWriter, false);
+ chunkIndex++;
+
+ chunkPath = new Path(dictionaryPathBase, DICTIONARY_FILE + chunkIndex);
+ chunkPaths.add(chunkPath);
+
+ dictWriter = new SequenceFile.Writer(fs, conf, chunkPath, Text.class, IntWritable.class);
+ currentChunkSize = 0;
+ }
+
+ Writable key = record.getFirst();
+ int fieldSize = DICTIONARY_BYTE_OVERHEAD + key.toString().length() * 2 + Integer.SIZE / 8;
+ currentChunkSize += fieldSize;
+ dictWriter.append(key, new IntWritable(i++));
+ }
+ maxTermDimension[0] = i;
+ } finally {
+ Closeables.close(dictWriter, false);
+ }
+
+ return chunkPaths;
+ }
+
+ /**
+ * Create a partial vector using a chunk of features from the input documents. The input documents has to be
+ * in the {@link SequenceFile} format
+ *
+ * @param input
+ * input directory of the documents in {@link SequenceFile} format
+ * @param baseConf
+ * job configuration
+ * @param maxNGramSize
+ * maximum size of ngrams to generate
+ * @param dictionaryFilePath
+ * location of the chunk of features and the id's
+ * @param output
+ * output directory were the partial vectors have to be created
+ * @param dimension
+ * @param sequentialAccess
+ * output vectors should be optimized for sequential access
+ * @param namedVectors
+ * output vectors should be named, retaining key (doc id) as a label
+ * @param numReducers
+ * the desired number of reducer tasks
+ */
+ private static void makePartialVectors(Path input,
+ Configuration baseConf,
+ int maxNGramSize,
+ Path dictionaryFilePath,
+ Path output,
+ int dimension,
+ boolean sequentialAccess,
+ boolean namedVectors,
+ int numReducers)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf values
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ conf.setInt(PartialVectorMerger.DIMENSION, dimension);
+ conf.setBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, sequentialAccess);
+ conf.setBoolean(PartialVectorMerger.NAMED_VECTOR, namedVectors);
+ conf.setInt(MAX_NGRAMS, maxNGramSize);
+ DistributedCache.addCacheFile(dictionaryFilePath.toUri(), conf);
+
+ Job job = new Job(conf);
+ job.setJobName("DictionaryVectorizer::MakePartialVectors: input-folder: " + input
+ + ", dictionary-file: " + dictionaryFilePath);
+ job.setJarByClass(DictionaryVectorizer.class);
+
+ job.setMapOutputKeyClass(Text.class);
+ job.setMapOutputValueClass(StringTuple.class);
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(VectorWritable.class);
+ FileInputFormat.setInputPaths(job, input);
+
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setMapperClass(Mapper.class);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setReducerClass(TFPartialVectorReducer.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setNumReduceTasks(numReducers);
+
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ /**
+ * Count the frequencies of words in parallel using Map/Reduce. The input documents have to be in
+ * {@link SequenceFile} format
+ */
+ private static void startWordCounting(Path input, Path output, Configuration baseConf, int minSupport)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf values
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ conf.setInt(MIN_SUPPORT, minSupport);
+
+ Job job = new Job(conf);
+
+ job.setJobName("DictionaryVectorizer::WordCount: input-folder: " + input);
+ job.setJarByClass(DictionaryVectorizer.class);
+
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(LongWritable.class);
+
+ FileInputFormat.setInputPaths(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setMapperClass(TermCountMapper.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setCombinerClass(TermCountCombiner.class);
+ job.setReducerClass(TermCountReducer.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption("tfDirName", "tf", "The folder to store the TF calculations", "tfDirName");
+ addOption("minSupport", "s", "(Optional) Minimum Support. Default Value: 2", "2");
+ addOption("maxNGramSize", "ng", "(Optional) The maximum size of ngrams to create"
+ + " (2 = bigrams, 3 = trigrams, etc) Default Value:1");
+ addOption("minLLR", "ml", "(Optional)The minimum Log Likelihood Ratio(Float) Default is "
+ + LLRReducer.DEFAULT_MIN_LLR);
+ addOption("norm", "n", "The norm to use, expressed as either a float or \"INF\" "
+ + "if you want to use the Infinite norm. "
+ + "Must be greater or equal to 0. The default is not to normalize");
+ addOption("logNormalize", "lnorm", "(Optional) Whether output vectors should be logNormalize. "
+ + "If set true else false", "false");
+ addOption(DefaultOptionCreator.numReducersOption().create());
+ addOption("chunkSize", "chunk", "The chunkSize in MegaBytes. 100-10000 MB", "100");
+ addOption(DefaultOptionCreator.methodOption().create());
+ addOption("namedVector", "nv", "(Optional) Whether output vectors should be NamedVectors. "
+ + "If set true else false", "false");
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+ String tfDirName = getOption("tfDirName", "tfDir");
+ int minSupport = getInt("minSupport", 2);
+ int maxNGramSize = getInt("maxNGramSize", 1);
+ float minLLRValue = getFloat("minLLR", LLRReducer.DEFAULT_MIN_LLR);
+ float normPower = getFloat("norm", PartialVectorMerger.NO_NORMALIZING);
+ boolean logNormalize = hasOption("logNormalize");
+ int numReducers = getInt(DefaultOptionCreator.MAX_REDUCERS_OPTION);
+ int chunkSizeInMegs = getInt("chunkSize", 100);
+ boolean sequential = hasOption("sequential");
+ boolean namedVecs = hasOption("namedVectors");
+ //TODO: add support for other paths
+ createTermFrequencyVectors(getInputPath(), getOutputPath(),
+ tfDirName,
+ getConf(), minSupport, maxNGramSize, minLLRValue,
+ normPower, logNormalize, numReducers, chunkSizeInMegs, sequential, namedVecs);
+ return 0;
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new DictionaryVectorizer(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java b/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java
new file mode 100644
index 0000000..2c3c236
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/DocumentProcessor.java
@@ -0,0 +1,99 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.vectorizer.document.SequenceFileTokenizerMapper;
+
+/**
+ * This class converts a set of input documents in the sequence file format of {@link StringTuple}s.The
+ * {@link org.apache.hadoop.io.SequenceFile} input should have a {@link Text} key
+ * containing the unique document identifier and a
+ * {@link Text} value containing the whole document. The document should be stored in UTF-8 encoding which is
+ * recognizable by hadoop. It uses the given {@link Analyzer} to process the document into
+ * {@link org.apache.lucene.analysis.Token}s.
+ *
+ */
+public final class DocumentProcessor {
+
+ public static final String TOKENIZED_DOCUMENT_OUTPUT_FOLDER = "tokenized-documents";
+ public static final String ANALYZER_CLASS = "analyzer.class";
+
+ /**
+ * Cannot be initialized. Use the static functions
+ */
+ private DocumentProcessor() {
+
+ }
+
+ /**
+ * Convert the input documents into token array using the {@link StringTuple} The input documents has to be
+ * in the {@link org.apache.hadoop.io.SequenceFile} format
+ *
+ * @param input
+ * input directory of the documents in {@link org.apache.hadoop.io.SequenceFile} format
+ * @param output
+ * output directory were the {@link StringTuple} token array of each document has to be created
+ * @param analyzerClass
+ * The Lucene {@link Analyzer} for tokenizing the UTF-8 text
+ */
+ public static void tokenizeDocuments(Path input,
+ Class<? extends Analyzer> analyzerClass,
+ Path output,
+ Configuration baseConf)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf values
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ conf.set(ANALYZER_CLASS, analyzerClass.getName());
+
+ Job job = new Job(conf);
+ job.setJobName("DocumentProcessor::DocumentTokenizer: input-folder: " + input);
+ job.setJarByClass(DocumentProcessor.class);
+
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(StringTuple.class);
+ FileInputFormat.setInputPaths(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setMapperClass(SequenceFileTokenizerMapper.class);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setNumReduceTasks(0);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java b/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java
new file mode 100644
index 0000000..1cf7ad7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/EncodedVectorsFromSequenceFiles.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.LuceneTextValueEncoder;
+
+/**
+ * Converts a given set of sequence files into SparseVectors
+ */
+public final class EncodedVectorsFromSequenceFiles extends AbstractJob {
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new EncodedVectorsFromSequenceFiles(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.analyzerOption().create());
+ addOption(buildOption("sequentialAccessVector", "seq",
+ "(Optional) Whether output vectors should be SequentialAccessVectors. "
+ + "If set true else false",
+ false, false, null));
+ addOption(buildOption("namedVector", "nv",
+ "Create named vectors using the key. False by default", false, false, null));
+ addOption("cardinality", "c",
+ "The cardinality to use for creating the vectors. Default is 5000", "5000");
+ addOption("encoderFieldName", "en",
+ "The name of the encoder to be passed to the FeatureVectorEncoder constructor. Default is text. "
+ + "Note this is not the class name of a FeatureValueEncoder, but is instead the construction "
+ + "argument.",
+ "text");
+ addOption("encoderClass", "ec",
+ "The class name of the encoder to be used. Default is " + LuceneTextValueEncoder.class.getName(),
+ LuceneTextValueEncoder.class.getName());
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+
+ Class<? extends Analyzer> analyzerClass = getAnalyzerClassFromOption();
+
+ Configuration conf = getConf();
+
+ boolean sequentialAccessOutput = hasOption("sequentialAccessVector");
+
+ boolean namedVectors = hasOption("namedVector");
+ int cardinality = 5000;
+ if (hasOption("cardinality")) {
+ cardinality = Integer.parseInt(getOption("cardinality"));
+ }
+ String encoderName = "text";
+ if (hasOption("encoderFieldName")) {
+ encoderName = getOption("encoderFieldName");
+ }
+ String encoderClass = LuceneTextValueEncoder.class.getName();
+ if (hasOption("encoderClass")) {
+ encoderClass = getOption("encoderClass");
+ ClassUtils.instantiateAs(encoderClass, FeatureVectorEncoder.class, new Class[] { String.class },
+ new Object[] { encoderName }); //try instantiating it
+ }
+
+ SimpleTextEncodingVectorizer vectorizer = new SimpleTextEncodingVectorizer();
+ VectorizerConfig config = new VectorizerConfig(conf, analyzerClass.getName(), encoderClass, encoderName,
+ sequentialAccessOutput, namedVectors, cardinality);
+
+ vectorizer.createVectors(input, output, config);
+
+ return 0;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java b/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java
new file mode 100644
index 0000000..63ccea4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/EncodingMapper.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.vectorizer;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.lucene.AnalyzerUtils;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.LuceneTextValueEncoder;
+
+import java.io.IOException;
+
+/**
+ * The Mapper that does the work of encoding text
+ */
+public class EncodingMapper extends Mapper<Text, Text, Text, VectorWritable> {
+
+ public static final String USE_NAMED_VECTORS = "namedVectors";
+ public static final String USE_SEQUENTIAL = "sequential";
+ public static final String ANALYZER_NAME = "analyzer";
+ public static final String ENCODER_FIELD_NAME = "encoderFieldName";
+ public static final String ENCODER_CLASS = "encoderClass";
+ public static final String CARDINALITY = "cardinality";
+ private boolean sequentialVectors;
+ private boolean namedVectors;
+ private FeatureVectorEncoder encoder;
+ private int cardinality;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ sequentialVectors = conf.getBoolean(USE_SEQUENTIAL, false);
+ namedVectors = conf.getBoolean(USE_NAMED_VECTORS, false);
+ String analyzerName = conf.get(ANALYZER_NAME, StandardAnalyzer.class.getName());
+ Analyzer analyzer;
+ try {
+ analyzer = AnalyzerUtils.createAnalyzer(analyzerName);
+ } catch (ClassNotFoundException e) {
+ //TODO: hmmm, don't like this approach
+ throw new IOException("Unable to create Analyzer for name: " + analyzerName, e);
+ }
+
+ String encoderName = conf.get(ENCODER_FIELD_NAME, "text");
+ cardinality = conf.getInt(CARDINALITY, 5000);
+ String encClass = conf.get(ENCODER_CLASS);
+ encoder = ClassUtils.instantiateAs(encClass,
+ FeatureVectorEncoder.class,
+ new Class[]{String.class},
+ new Object[]{encoderName});
+ if (encoder instanceof LuceneTextValueEncoder) {
+ ((LuceneTextValueEncoder) encoder).setAnalyzer(analyzer);
+ }
+ }
+
+ @Override
+ protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
+ Vector vector;
+ if (sequentialVectors) {
+ vector = new SequentialAccessSparseVector(cardinality);
+ } else {
+ vector = new RandomAccessSparseVector(cardinality);
+ }
+ if (namedVectors) {
+ vector = new NamedVector(vector, key.toString());
+ }
+ encoder.addToVector(value.toString(), vector);
+ context.write(new Text(key.toString()), new VectorWritable(vector));
+ }
+}
[25/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
new file mode 100644
index 0000000..4bffb2b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+
+public final class StreamingKMeansUtilsMR {
+
+ private StreamingKMeansUtilsMR() {
+ }
+
+ /**
+ * Instantiates a searcher from a given configuration.
+ * @param conf the configuration
+ * @return the instantiated searcher
+ * @throws RuntimeException if the distance measure class cannot be instantiated
+ * @throws IllegalStateException if an unknown searcher class was requested
+ */
+ public static UpdatableSearcher searcherFromConfiguration(Configuration conf) {
+ DistanceMeasure distanceMeasure;
+ String distanceMeasureClass = conf.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ try {
+ distanceMeasure = (DistanceMeasure) Class.forName(distanceMeasureClass).getConstructor().newInstance();
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to instantiate distanceMeasure", e);
+ }
+
+ int numProjections = conf.getInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, 20);
+ int searchSize = conf.getInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, 10);
+
+ String searcherClass = conf.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION);
+
+ if (searcherClass.equals(BruteSearch.class.getName())) {
+ return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
+ new Class[]{DistanceMeasure.class}, new Object[]{distanceMeasure});
+ } else if (searcherClass.equals(FastProjectionSearch.class.getName())
+ || searcherClass.equals(ProjectionSearch.class.getName())) {
+ return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class,
+ new Class[]{DistanceMeasure.class, int.class, int.class},
+ new Object[]{distanceMeasure, numProjections, searchSize});
+ } else if (searcherClass.equals(LocalitySensitiveHashSearch.class.getName())) {
+ return ClassUtils.instantiateAs(searcherClass, LocalitySensitiveHashSearch.class,
+ new Class[]{DistanceMeasure.class, int.class},
+ new Object[]{distanceMeasure, searchSize});
+ } else {
+ throw new IllegalStateException("Unknown class instantiation requested");
+ }
+ }
+
+ /**
+ * Returns an Iterable of centroids from an Iterable of VectorWritables by creating a new Centroid containing
+ * a RandomAccessSparseVector as a delegate for each VectorWritable.
+ * @param inputIterable VectorWritable Iterable to get Centroids from
+ * @return the new Centroids
+ */
+ public static Iterable<Centroid> getCentroidsFromVectorWritable(Iterable<VectorWritable> inputIterable) {
+ return Iterables.transform(inputIterable, new Function<VectorWritable, Centroid>() {
+ private int numVectors = 0;
+ @Override
+ public Centroid apply(VectorWritable input) {
+ Preconditions.checkNotNull(input);
+ return new Centroid(numVectors++, new RandomAccessSparseVector(input.get()), 1);
+ }
+ });
+ }
+
+ /**
+ * Returns an Iterable of Centroid from an Iterable of Vector by either casting each Vector to Centroid (if the
+ * instance extends Centroid) or create a new Centroid based on that Vector.
+ * The implicit expectation is that the input will not have interleaving types of vectors. Otherwise, the numbering
+ * of new Centroids will become invalid.
+ * @param input Iterable of Vectors to cast
+ * @return the new Centroids
+ */
+ public static Iterable<Centroid> castVectorsToCentroids(Iterable<Vector> input) {
+ return Iterables.transform(input, new Function<Vector, Centroid>() {
+ private int numVectors = 0;
+ @Override
+ public Centroid apply(Vector input) {
+ Preconditions.checkNotNull(input);
+ if (input instanceof Centroid) {
+ return (Centroid) input;
+ } else {
+ return new Centroid(numVectors++, input, 1);
+ }
+ }
+ });
+ }
+
+ /**
+ * Writes centroids to a sequence file.
+ * @param centroids the centroids to write.
+ * @param path the path of the output file.
+ * @param conf the configuration for the HDFS to write the file to.
+ * @throws java.io.IOException
+ */
+ public static void writeCentroidsToSequenceFile(Iterable<Centroid> centroids, Path path, Configuration conf)
+ throws IOException {
+ SequenceFile.Writer writer = null;
+ try {
+ writer = SequenceFile.createWriter(FileSystem.get(conf), conf,
+ path, IntWritable.class, CentroidWritable.class);
+ int i = 0;
+ for (Centroid centroid : centroids) {
+ writer.append(new IntWritable(i++), new CentroidWritable(centroid));
+ }
+ } finally {
+ Closeables.close(writer, true);
+ }
+ }
+
+ public static void writeVectorsToSequenceFile(Iterable<? extends Vector> datapoints, Path path, Configuration conf)
+ throws IOException {
+ SequenceFile.Writer writer = null;
+ try {
+ writer = SequenceFile.createWriter(FileSystem.get(conf), conf,
+ path, IntWritable.class, VectorWritable.class);
+ int i = 0;
+ for (Vector vector : datapoints) {
+ writer.append(new IntWritable(i++), new VectorWritable(vector));
+ }
+ } finally {
+ Closeables.close(writer, true);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java
new file mode 100644
index 0000000..55b7848
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.tools;
+
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.Iterator;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Iterables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+
+public class ResplitSequenceFiles {
+
+ private String inputFile;
+ private String outputFileBase;
+ private int numSplits;
+
+ private Configuration conf;
+ private FileSystem fs;
+
+ private ResplitSequenceFiles() {}
+
+ private void writeSplit(Iterator<Pair<Writable, Writable>> inputIterator,
+ int numSplit, int numEntriesPerSplit) throws IOException {
+ SequenceFile.Writer splitWriter = null;
+ for (int j = 0; j < numEntriesPerSplit; ++j) {
+ Pair<Writable, Writable> item = inputIterator.next();
+ if (splitWriter == null) {
+ splitWriter = SequenceFile.createWriter(fs, conf,
+ new Path(outputFileBase + "-" + numSplit), item.getFirst().getClass(), item.getSecond().getClass());
+ }
+ splitWriter.append(item.getFirst(), item.getSecond());
+ }
+ if (splitWriter != null) {
+ splitWriter.close();
+ }
+ }
+
+ private void run(PrintWriter printWriter) throws IOException {
+ conf = new Configuration();
+ SequenceFileDirIterable<Writable, Writable> inputIterable = new
+ SequenceFileDirIterable<Writable, Writable>(new Path(inputFile), PathType.LIST, conf);
+ fs = FileSystem.get(conf);
+
+ int numEntries = Iterables.size(inputIterable);
+ int numEntriesPerSplit = numEntries / numSplits;
+ int numEntriesLastSplit = numEntriesPerSplit + numEntries - numEntriesPerSplit * numSplits;
+ Iterator<Pair<Writable, Writable>> inputIterator = inputIterable.iterator();
+
+ printWriter.printf("Writing %d splits\n", numSplits);
+ for (int i = 0; i < numSplits - 1; ++i) {
+ printWriter.printf("Writing split %d\n", i);
+ writeSplit(inputIterator, i, numEntriesPerSplit);
+ }
+ printWriter.printf("Writing split %d\n", numSplits - 1);
+ writeSplit(inputIterator, numSplits - 1, numEntriesLastSplit);
+ }
+
+ private boolean parseArgs(String[] args) {
+ DefaultOptionBuilder builder = new DefaultOptionBuilder();
+
+ Option help = builder.withLongName("help").withDescription("print this list").create();
+
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+ Option inputFileOption = builder.withLongName("input")
+ .withShortName("i")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("input").withMaximum(1).create())
+ .withDescription("what the base folder for sequence files is (they all must have the same key/value type")
+ .create();
+
+ Option outputFileOption = builder.withLongName("output")
+ .withShortName("o")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("output").withMaximum(1).create())
+ .withDescription("the base name of the file split that the files will be split it; the i'th split has the "
+ + "suffix -i")
+ .create();
+
+ Option numSplitsOption = builder.withLongName("numSplits")
+ .withShortName("ns")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("numSplits").withMaximum(1).create())
+ .withDescription("how many splits to use for the given files")
+ .create();
+
+ Group normalArgs = new GroupBuilder()
+ .withOption(help)
+ .withOption(inputFileOption)
+ .withOption(outputFileOption)
+ .withOption(numSplitsOption)
+ .create();
+
+ Parser parser = new Parser();
+ parser.setHelpOption(help);
+ parser.setHelpTrigger("--help");
+ parser.setGroup(normalArgs);
+ parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130));
+ CommandLine cmdLine = parser.parseAndHelp(args);
+
+ if (cmdLine == null) {
+ return false;
+ }
+
+ inputFile = (String) cmdLine.getValue(inputFileOption);
+ outputFileBase = (String) cmdLine.getValue(outputFileOption);
+ numSplits = Integer.parseInt((String) cmdLine.getValue(numSplitsOption));
+ return true;
+ }
+
+ public static void main(String[] args) throws IOException {
+ ResplitSequenceFiles runner = new ResplitSequenceFiles();
+ if (runner.parseArgs(args)) {
+ runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java b/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java
new file mode 100644
index 0000000..11bc34a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java
@@ -0,0 +1,94 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown;
+
+import java.io.File;
+
+import org.apache.hadoop.fs.Path;
+
+/**
+ * Contains list of all internal paths used in top down clustering.
+ */
+public final class PathDirectory {
+
+ public static final String TOP_LEVEL_CLUSTER_DIRECTORY = "topLevelCluster";
+ public static final String POST_PROCESS_DIRECTORY = "clusterPostProcessed";
+ public static final String CLUSTERED_POINTS_DIRECTORY = "clusteredPoints";
+ public static final String BOTTOM_LEVEL_CLUSTER_DIRECTORY = "bottomLevelCluster";
+
+ private PathDirectory() {
+ }
+
+ /**
+ * All output of top level clustering is stored in output directory/topLevelCluster.
+ *
+ * @param output
+ * the output path of clustering.
+ * @return The top level Cluster Directory.
+ */
+ public static Path getTopLevelClusterPath(Path output) {
+ return new Path(output + File.separator + TOP_LEVEL_CLUSTER_DIRECTORY);
+ }
+
+ /**
+ * The output of top level clusters is post processed and kept in this path.
+ *
+ * @param outputPathProvidedByUser
+ * the output path of clustering.
+ * @return the path where the output of top level cluster post processor is kept.
+ */
+ public static Path getClusterPostProcessorOutputDirectory(Path outputPathProvidedByUser) {
+ return new Path(outputPathProvidedByUser + File.separator + POST_PROCESS_DIRECTORY);
+ }
+
+ /**
+ * The top level clustered points before post processing is generated here.
+ *
+ * @param output
+ * the output path of clustering.
+ * @return the clustered points directory
+ */
+ public static Path getClusterOutputClusteredPoints(Path output) {
+ return new Path(output + File.separator + CLUSTERED_POINTS_DIRECTORY + File.separator, "*");
+ }
+
+ /**
+ * Each cluster produced by top level clustering is processed in output/"bottomLevelCluster"/clusterId.
+ *
+ * @param output
+ * @param clusterId
+ * @return the bottom level clustering path.
+ */
+ public static Path getBottomLevelClusterPath(Path output, String clusterId) {
+ return new Path(output + File.separator + BOTTOM_LEVEL_CLUSTER_DIRECTORY + File.separator + clusterId);
+ }
+
+ /**
+ * Each clusters path name is its clusterId. The vectors reside in separate files inside it.
+ *
+ * @param clusterPostProcessorOutput
+ * the path of cluster post processor output.
+ * @param clusterId
+ * the id of the cluster.
+ * @return the cluster path for cluster id.
+ */
+ public static Path getClusterPathForClusterId(Path clusterPostProcessorOutput, String clusterId) {
+ return new Path(clusterPostProcessorOutput + File.separator + clusterId);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java
new file mode 100644
index 0000000..083b543
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java
@@ -0,0 +1,103 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * Reads the number of clusters produced by the clustering algorithm.
+ */
+public final class ClusterCountReader {
+
+ private ClusterCountReader() {
+ }
+
+ /**
+ * Reads the number of clusters present by reading the clusters-*-final file.
+ *
+ * @param clusterOutputPath The output path provided to the clustering algorithm.
+ * @param conf The hadoop configuration.
+ * @return the number of final clusters.
+ */
+ public static int getNumberOfClusters(Path clusterOutputPath, Configuration conf) throws IOException {
+ FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+ int numberOfClusters = 0;
+ Iterator<?> it = new SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
+ PathType.LIST,
+ PathFilters.partFilter(),
+ null,
+ true,
+ conf);
+ while (it.hasNext()) {
+ it.next();
+ numberOfClusters++;
+ }
+ return numberOfClusters;
+ }
+
+ /**
+ * Generates a list of all cluster ids by reading the clusters-*-final file.
+ *
+ * @param clusterOutputPath The output path provided to the clustering algorithm.
+ * @param conf The hadoop configuration.
+ * @return An ArrayList containing the final cluster ids.
+ */
+ public static Map<Integer, Integer> getClusterIDs(Path clusterOutputPath, Configuration conf, boolean keyIsClusterId)
+ throws IOException {
+ Map<Integer, Integer> clusterIds = new HashMap<>();
+ FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+ //System.out.println("LOOK HERE: " + clusterOutputPath);
+ Iterator<ClusterWritable> it = new SequenceFileDirValueIterator<>(clusterFiles[0].getPath(),
+ PathType.LIST,
+ PathFilters.partFilter(),
+ null,
+ true,
+ conf);
+ int i = 0;
+ while (it.hasNext()) {
+ Integer key;
+ Integer value;
+ if (keyIsClusterId) { // key is the cluster id, value is i, the index we will use
+ key = it.next().getValue().getId();
+ value = i;
+ } else {
+ key = i;
+ value = it.next().getValue().getId();
+ }
+ clusterIds.put(key, value);
+ i++;
+ }
+ return clusterIds;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java
new file mode 100644
index 0000000..44a944d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.Writer;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.clustering.topdown.PathDirectory;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * This class reads the output of any clustering algorithm, and, creates separate directories for different
+ * clusters. Each cluster directory's name is its clusterId. Each and every point is written in the cluster
+ * directory associated with that point.
+ * <p/>
+ * This class incorporates a sequential algorithm and is appropriate for use for data which has been clustered
+ * sequentially.
+ * <p/>
+ * The sequential and non sequential version, both are being used from {@link ClusterOutputPostProcessorDriver}.
+ */
+public final class ClusterOutputPostProcessor {
+
+ private Path clusteredPoints;
+ private final FileSystem fileSystem;
+ private final Configuration conf;
+ private final Path clusterPostProcessorOutput;
+ private final Map<String, Path> postProcessedClusterDirectories = Maps.newHashMap();
+ private long uniqueVectorId = 0L;
+ private final Map<String, SequenceFile.Writer> writersForClusters;
+
+ public ClusterOutputPostProcessor(Path clusterOutputToBeProcessed,
+ Path output,
+ Configuration hadoopConfiguration) throws IOException {
+ this.clusterPostProcessorOutput = output;
+ this.clusteredPoints = PathDirectory.getClusterOutputClusteredPoints(clusterOutputToBeProcessed);
+ this.conf = hadoopConfiguration;
+ this.writersForClusters = Maps.newHashMap();
+ fileSystem = clusteredPoints.getFileSystem(conf);
+ }
+
+ /**
+ * This method takes the clustered points output by the clustering algorithms as input and writes them into
+ * their respective clusters.
+ */
+ public void process() throws IOException {
+ createPostProcessDirectory();
+ for (Pair<?, WeightedVectorWritable> record
+ : new SequenceFileDirIterable<Writable, WeightedVectorWritable>(clusteredPoints, PathType.GLOB, PathFilters.partFilter(),
+ null, false, conf)) {
+ String clusterId = record.getFirst().toString().trim();
+ putVectorInRespectiveCluster(clusterId, record.getSecond());
+ }
+ IOUtils.close(writersForClusters.values());
+ writersForClusters.clear();
+ }
+
+ /**
+ * Creates the directory to put post processed clusters.
+ */
+ private void createPostProcessDirectory() throws IOException {
+ if (!fileSystem.exists(clusterPostProcessorOutput)
+ && !fileSystem.mkdirs(clusterPostProcessorOutput)) {
+ throw new IOException("Error creating cluster post processor directory");
+ }
+ }
+
+ /**
+ * Finds out the cluster directory of the vector and writes it into the specified cluster.
+ */
+ private void putVectorInRespectiveCluster(String clusterId, WeightedVectorWritable point) throws IOException {
+ Writer writer = findWriterForVector(clusterId);
+ postProcessedClusterDirectories.put(clusterId,
+ PathDirectory.getClusterPathForClusterId(clusterPostProcessorOutput, clusterId));
+ writeVectorToCluster(writer, point);
+ }
+
+ /**
+ * Finds out the path in cluster where the point is supposed to be written.
+ */
+ private Writer findWriterForVector(String clusterId) throws IOException {
+ Path clusterDirectory = PathDirectory.getClusterPathForClusterId(clusterPostProcessorOutput, clusterId);
+ Writer writer = writersForClusters.get(clusterId);
+ if (writer == null) {
+ Path pathToWrite = new Path(clusterDirectory, new Path("part-m-0"));
+ writer = new Writer(fileSystem, conf, pathToWrite, LongWritable.class, VectorWritable.class);
+ writersForClusters.put(clusterId, writer);
+ }
+ return writer;
+ }
+
+ /**
+ * Writes vector to the cluster directory.
+ */
+ private void writeVectorToCluster(Writer writer, WeightedVectorWritable point) throws IOException {
+ writer.append(new LongWritable(uniqueVectorId++), new VectorWritable(point.getVector()));
+ writer.sync();
+ }
+
+ /**
+ * @return the set of all post processed cluster paths.
+ */
+ public Map<String, Path> getPostProcessedClusterDirectories() {
+ return postProcessedClusterDirectories;
+ }
+
+ public void setClusteredPoints(Path clusteredPoints) {
+ this.clusteredPoints = clusteredPoints;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java
new file mode 100644
index 0000000..82a3071
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java
@@ -0,0 +1,182 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+/**
+ * Post processes the output of clustering algorithms and groups them into respective clusters. Ideal to be
+ * used for top down clustering. It can also be used if the clustering output needs to be grouped into their
+ * respective clusters.
+ */
+public final class ClusterOutputPostProcessorDriver extends AbstractJob {
+
+ /**
+ * CLI to run clustering post processor. The input to post processor is the ouput path specified to the
+ * clustering.
+ */
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.methodOption().create());
+ addOption(DefaultOptionCreator.overwriteOption().create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+ Path input = getInputPath();
+ Path output = getOutputPath();
+
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+ DefaultOptionCreator.SEQUENTIAL_METHOD);
+ run(input, output, runSequential);
+ return 0;
+
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new ClusterOutputPostProcessorDriver(), args);
+ }
+
+ /**
+ * Post processes the output of clustering algorithms and groups them into respective clusters. Each
+ * cluster's vectors are written into a directory named after its clusterId.
+ *
+ * @param input The output path provided to the clustering algorithm, whose would be post processed. Hint: The
+ * path of the directory containing clusters-*-final and clusteredPoints.
+ * @param output The post processed data would be stored at this path.
+ * @param runSequential If set to true, post processes it sequentially, else, uses. MapReduce. Hint: If the clustering
+ * was done sequentially, make it sequential, else vice versa.
+ */
+ public static void run(Path input, Path output, boolean runSequential) throws IOException,
+ InterruptedException,
+ ClassNotFoundException {
+ if (runSequential) {
+ postProcessSeq(input, output);
+ } else {
+ Configuration conf = new Configuration();
+ postProcessMR(conf, input, output);
+ movePartFilesToRespectiveDirectories(conf, output);
+ }
+
+ }
+
+ /**
+ * Process Sequentially. Reads the vectors one by one, and puts them into respective directory, named after
+ * their clusterId.
+ *
+ * @param input The output path provided to the clustering algorithm, whose would be post processed. Hint : The
+ * path of the directory containing clusters-*-final and clusteredPoints.
+ * @param output The post processed data would be stored at this path.
+ */
+ private static void postProcessSeq(Path input, Path output) throws IOException {
+ ClusterOutputPostProcessor clusterOutputPostProcessor = new ClusterOutputPostProcessor(input, output,
+ new Configuration());
+ clusterOutputPostProcessor.process();
+ }
+
+ /**
+ * Process as a map reduce job. The numberOfReduceTasks is set to the number of clusters present in the
+ * output. So that each cluster's vector is written in its own part file.
+ *
+ * @param conf The hadoop configuration.
+ * @param input The output path provided to the clustering algorithm, whose would be post processed. Hint : The
+ * path of the directory containing clusters-*-final and clusteredPoints.
+ * @param output The post processed data would be stored at this path.
+ */
+ private static void postProcessMR(Configuration conf, Path input, Path output) throws IOException,
+ InterruptedException,
+ ClassNotFoundException {
+ System.out.println("WARNING: If you are running in Hadoop local mode, please use the --sequential option, "
+ + "as the MapReduce option will not work properly");
+ int numberOfClusters = ClusterCountReader.getNumberOfClusters(input, conf);
+ conf.set("clusterOutputPath", input.toString());
+ Job job = new Job(conf, "ClusterOutputPostProcessor Driver running over input: " + input);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(ClusterOutputPostProcessorMapper.class);
+ job.setMapOutputKeyClass(IntWritable.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+ job.setReducerClass(ClusterOutputPostProcessorReducer.class);
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setNumReduceTasks(numberOfClusters);
+ job.setJarByClass(ClusterOutputPostProcessorDriver.class);
+
+ FileInputFormat.addInputPath(job, new Path(input, new Path("clusteredPoints")));
+ FileOutputFormat.setOutputPath(job, output);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("ClusterOutputPostProcessor Job failed processing " + input);
+ }
+ }
+
+ /**
+ * The mapreduce version of the post processor writes different clusters into different part files. This
+ * method reads the part files and moves them into directories named after their clusterIds.
+ *
+ * @param conf The hadoop configuration.
+ * @param output The post processed data would be stored at this path.
+ */
+ private static void movePartFilesToRespectiveDirectories(Configuration conf, Path output) throws IOException {
+ FileSystem fileSystem = output.getFileSystem(conf);
+ for (FileStatus fileStatus : fileSystem.listStatus(output, PathFilters.partFilter())) {
+ SequenceFileIterator<Writable, Writable> it =
+ new SequenceFileIterator<>(fileStatus.getPath(), true, conf);
+ if (it.hasNext()) {
+ renameFile(it.next().getFirst(), fileStatus, conf);
+ }
+ it.close();
+ }
+ }
+
+ /**
+ * Using @FileSystem rename method to move the file.
+ */
+ private static void renameFile(Writable key, FileStatus fileStatus, Configuration conf) throws IOException {
+ Path path = fileStatus.getPath();
+ FileSystem fileSystem = path.getFileSystem(conf);
+ Path subDir = new Path(key.toString());
+ Path renameTo = new Path(path.getParent(), subDir);
+ fileSystem.mkdirs(renameTo);
+ fileSystem.rename(path, renameTo);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java
new file mode 100644
index 0000000..6834362
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Mapper for post processing cluster output.
+ */
+public class ClusterOutputPostProcessorMapper extends
+ Mapper<IntWritable, WeightedVectorWritable, IntWritable, VectorWritable> {
+
+ private Map<Integer, Integer> newClusterMappings;
+ private VectorWritable outputVector;
+
+ //read the current cluster ids, and populate the cluster mapping hash table
+ @Override
+ public void setup(Context context) throws IOException {
+ Configuration conf = context.getConfiguration();
+ //this give the clusters-x-final directory where the cluster ids can be read
+ Path clusterOutputPath = new Path(conf.get("clusterOutputPath"));
+ //we want the key to be the cluster id, the value to be the index
+ newClusterMappings = ClusterCountReader.getClusterIDs(clusterOutputPath, conf, true);
+ outputVector = new VectorWritable();
+ }
+
+ @Override
+ public void map(IntWritable key, WeightedVectorWritable val, Context context)
+ throws IOException, InterruptedException {
+ // by pivoting on the cluster mapping value, we can make sure that each unique cluster goes to it's own reducer,
+ // since they are numbered from 0 to k-1, where k is the number of clusters
+ outputVector.set(val.getVector());
+ context.write(new IntWritable(newClusterMappings.get(key.get())), outputVector);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java
new file mode 100644
index 0000000..58dada4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Reducer for post processing cluster output.
+ */
+public class ClusterOutputPostProcessorReducer
+ extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ private Map<Integer, Integer> reverseClusterMappings;
+
+ //read the current cluster ids, and populate the hash cluster mapping hash table
+ @Override
+ public void setup(Context context) throws IOException {
+ Configuration conf = context.getConfiguration();
+ Path clusterOutputPath = new Path(conf.get("clusterOutputPath"));
+ //we want to the key to be the index, the value to be the cluster id
+ reverseClusterMappings = ClusterCountReader.getClusterIDs(clusterOutputPath, conf, false);
+ }
+
+ /**
+ * The key is the remapped cluster id and the values contains the vectors in that cluster.
+ */
+ @Override
+ protected void reduce(IntWritable key, Iterable<VectorWritable> values, Context context) throws IOException,
+ InterruptedException {
+ //remap the cluster back to its original id
+ //and then output the vectors with their correct
+ //cluster id.
+ IntWritable outKey = new IntWritable(reverseClusterMappings.get(key.get()));
+ System.out.println(outKey + " this: " + this);
+ for (VectorWritable value : values) {
+ context.write(outKey, value);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/AbstractJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/AbstractJob.java b/mr/src/main/java/org/apache/mahout/common/AbstractJob.java
new file mode 100644
index 0000000..ec77749
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/AbstractJob.java
@@ -0,0 +1,658 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.OutputFormat;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.Tool;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.lucene.AnalyzerUtils;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+
+/**
+ * <p>Superclass of many Mahout Hadoop "jobs". A job drives configuration and launch of one or
+ * more maps and reduces in order to accomplish some task.</p>
+ *
+ * <p>Command line arguments available to all subclasses are:</p>
+ *
+ * <ul>
+ * <li>--tempDir (path): Specifies a directory where the job may place temp files
+ * (default "temp")</li>
+ * <li>--help: Show help message</li>
+ * </ul>
+ *
+ * <p>In addition, note some key command line parameters that are parsed by Hadoop, which jobs
+ * may need to set:</p>
+ *
+ * <ul>
+ * <li>-Dmapred.job.name=(name): Sets the Hadoop task names. It will be suffixed by
+ * the mapper and reducer class names</li>
+ * <li>-Dmapred.output.compress={true,false}: Compress final output (default true)</li>
+ * <li>-Dmapred.input.dir=(path): input file, or directory containing input files (required)</li>
+ * <li>-Dmapred.output.dir=(path): path to write output files (required)</li>
+ * </ul>
+ *
+ * <p>Note that because of how Hadoop parses arguments, all "-D" arguments must appear before all other
+ * arguments.</p>
+ */
+public abstract class AbstractJob extends Configured implements Tool {
+
+ private static final Logger log = LoggerFactory.getLogger(AbstractJob.class);
+
+ /** option used to specify the input path */
+ private Option inputOption;
+
+ /** option used to specify the output path */
+ private Option outputOption;
+
+ /** input path, populated by {@link #parseArguments(String[])} */
+ protected Path inputPath;
+ protected File inputFile; //the input represented as a file
+
+ /** output path, populated by {@link #parseArguments(String[])} */
+ protected Path outputPath;
+ protected File outputFile; //the output represented as a file
+
+ /** temp path, populated by {@link #parseArguments(String[])} */
+ protected Path tempPath;
+
+ protected Map<String, List<String>> argMap;
+
+ /** internal list of options that have been added */
+ private final List<Option> options;
+ private Group group;
+
+ protected AbstractJob() {
+ options = Lists.newLinkedList();
+ }
+
+ /** Returns the input path established by a call to {@link #parseArguments(String[])}.
+ * The source of the path may be an input option added using {@link #addInputOption()}
+ * or it may be the value of the {@code mapred.input.dir} configuration
+ * property.
+ */
+ protected Path getInputPath() {
+ return inputPath;
+ }
+
+ /** Returns the output path established by a call to {@link #parseArguments(String[])}.
+ * The source of the path may be an output option added using {@link #addOutputOption()}
+ * or it may be the value of the {@code mapred.input.dir} configuration
+ * property.
+ */
+ protected Path getOutputPath() {
+ return outputPath;
+ }
+
+ protected Path getOutputPath(String path) {
+ return new Path(outputPath, path);
+ }
+
+ protected File getInputFile() {
+ return inputFile;
+ }
+
+ protected File getOutputFile() {
+ return outputFile;
+ }
+
+
+ protected Path getTempPath() {
+ return tempPath;
+ }
+
+ protected Path getTempPath(String directory) {
+ return new Path(tempPath, directory);
+ }
+
+ @Override
+ public Configuration getConf() {
+ Configuration result = super.getConf();
+ if (result == null) {
+ return new Configuration();
+ }
+ return result;
+ }
+
+ /** Add an option with no argument whose presence can be checked for using
+ * {@code containsKey} method on the map returned by {@link #parseArguments(String[])};
+ */
+ protected void addFlag(String name, String shortName, String description) {
+ options.add(buildOption(name, shortName, description, false, false, null));
+ }
+
+ /** Add an option to the the set of options this job will parse when
+ * {@link #parseArguments(String[])} is called. This options has an argument
+ * with null as its default value.
+ */
+ protected void addOption(String name, String shortName, String description) {
+ options.add(buildOption(name, shortName, description, true, false, null));
+ }
+
+ /** Add an option to the the set of options this job will parse when
+ * {@link #parseArguments(String[])} is called.
+ *
+ * @param required if true the {@link #parseArguments(String[])} will throw
+ * fail with an error and usage message if this option is not specified
+ * on the command line.
+ */
+ protected void addOption(String name, String shortName, String description, boolean required) {
+ options.add(buildOption(name, shortName, description, true, required, null));
+ }
+
+ /** Add an option to the the set of options this job will parse when
+ * {@link #parseArguments(String[])} is called. If this option is not
+ * specified on the command line the default value will be
+ * used.
+ *
+ * @param defaultValue the default argument value if this argument is not
+ * found on the command-line. null is allowed.
+ */
+ protected void addOption(String name, String shortName, String description, String defaultValue) {
+ options.add(buildOption(name, shortName, description, true, false, defaultValue));
+ }
+
+ /** Add an arbitrary option to the set of options this job will parse when
+ * {@link #parseArguments(String[])} is called. If this option has no
+ * argument, use {@code containsKey} on the map returned by
+ * {@code parseArguments} to check for its presence. Otherwise, the
+ * string value of the option will be placed in the map using a key
+ * equal to this options long name preceded by '--'.
+ * @return the option added.
+ */
+ protected Option addOption(Option option) {
+ options.add(option);
+ return option;
+ }
+
+ protected Group getGroup() {
+ return group;
+ }
+
+ /** Add the default input directory option, '-i' which takes a directory
+ * name as an argument. When {@link #parseArguments(String[])} is
+ * called, the inputPath will be set based upon the value for this option.
+ * If this method is called, the input is required.
+ */
+ protected void addInputOption() {
+ this.inputOption = addOption(DefaultOptionCreator.inputOption().create());
+ }
+
+ /** Add the default output directory option, '-o' which takes a directory
+ * name as an argument. When {@link #parseArguments(String[])} is
+ * called, the outputPath will be set based upon the value for this option.
+ * If this method is called, the output is required.
+ */
+ protected void addOutputOption() {
+ this.outputOption = addOption(DefaultOptionCreator.outputOption().create());
+ }
+
+ /** Build an option with the given parameters. Name and description are
+ * required.
+ *
+ * @param name the long name of the option prefixed with '--' on the command-line
+ * @param shortName the short name of the option, prefixed with '-' on the command-line
+ * @param description description of the option displayed in help method
+ * @param hasArg true if the option has an argument.
+ * @param required true if the option is required.
+ * @param defaultValue default argument value, can be null.
+ * @return the option.
+ */
+ protected static Option buildOption(String name,
+ String shortName,
+ String description,
+ boolean hasArg,
+ boolean required,
+ String defaultValue) {
+
+ return buildOption(name, shortName, description, hasArg, 1, 1, required, defaultValue);
+ }
+
+ protected static Option buildOption(String name,
+ String shortName,
+ String description,
+ boolean hasArg, int min, int max,
+ boolean required,
+ String defaultValue) {
+
+ DefaultOptionBuilder optBuilder = new DefaultOptionBuilder().withLongName(name).withDescription(description)
+ .withRequired(required);
+
+ if (shortName != null) {
+ optBuilder.withShortName(shortName);
+ }
+
+ if (hasArg) {
+ ArgumentBuilder argBuilder = new ArgumentBuilder().withName(name).withMinimum(min).withMaximum(max);
+
+ if (defaultValue != null) {
+ argBuilder = argBuilder.withDefault(defaultValue);
+ }
+
+ optBuilder.withArgument(argBuilder.create());
+ }
+
+ return optBuilder.create();
+ }
+
+ /**
+ * @param name The name of the option
+ * @return the {@link org.apache.commons.cli2.Option} with the name, else null
+ */
+ protected Option getCLIOption(String name) {
+ for (Option option : options) {
+ if (option.getPreferredName().equals(name)) {
+ return option;
+ }
+ }
+ return null;
+ }
+
+ /** Parse the arguments specified based on the options defined using the
+ * various {@code addOption} methods. If -h is specified or an
+ * exception is encountered print help and return null. Has the
+ * side effect of setting inputPath and outputPath
+ * if {@code addInputOption} or {@code addOutputOption}
+ * or {@code mapred.input.dir} or {@code mapred.output.dir}
+ * are present in the Configuration.
+ *
+ * @return a {@code Map<String,String>} containing options and their argument values.
+ * The presence of a flag can be tested using {@code containsKey}, while
+ * argument values can be retrieved using {@code get(optionName)}. The
+ * names used for keys are the option name parameter prefixed by '--'.
+ *
+ * @see #parseArguments(String[], boolean, boolean) -- passes in false, false for the optional args.
+ */
+ public Map<String, List<String>> parseArguments(String[] args) throws IOException {
+ return parseArguments(args, false, false);
+ }
+
+ /**
+ *
+ * @param args The args to parse
+ * @param inputOptional if false, then the input option, if set, need not be present. If true and input is an option
+ * and there is no input, then throw an error
+ * @param outputOptional if false, then the output option, if set, need not be present. If true and output is an
+ * option and there is no output, then throw an error
+ * @return the args parsed into a map.
+ */
+ public Map<String, List<String>> parseArguments(String[] args, boolean inputOptional, boolean outputOptional)
+ throws IOException {
+ Option helpOpt = addOption(DefaultOptionCreator.helpOption());
+ addOption("tempDir", null, "Intermediate output directory", "temp");
+ addOption("startPhase", null, "First phase to run", "0");
+ addOption("endPhase", null, "Last phase to run", String.valueOf(Integer.MAX_VALUE));
+
+ GroupBuilder gBuilder = new GroupBuilder().withName("Job-Specific Options:");
+
+ for (Option opt : options) {
+ gBuilder = gBuilder.withOption(opt);
+ }
+
+ group = gBuilder.create();
+
+ CommandLine cmdLine;
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ parser.setHelpOption(helpOpt);
+ cmdLine = parser.parse(args);
+
+ } catch (OptionException e) {
+ log.error(e.getMessage());
+ CommandLineUtil.printHelpWithGenericOptions(group, e);
+ return null;
+ }
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelpWithGenericOptions(group);
+ return null;
+ }
+
+ try {
+ parseDirectories(cmdLine, inputOptional, outputOptional);
+ } catch (IllegalArgumentException e) {
+ log.error(e.getMessage());
+ CommandLineUtil.printHelpWithGenericOptions(group);
+ return null;
+ }
+
+ argMap = new TreeMap<String, List<String>>();
+ maybePut(argMap, cmdLine, this.options.toArray(new Option[this.options.size()]));
+
+ this.tempPath = new Path(getOption("tempDir"));
+
+ if (!hasOption("quiet")) {
+ log.info("Command line arguments: {}", argMap);
+ }
+ return argMap;
+ }
+
+ /**
+ * Build the option key (--name) from the option name
+ */
+ public static String keyFor(String optionName) {
+ return "--" + optionName;
+ }
+
+ /**
+ * @return the requested option, or null if it has not been specified
+ */
+ public String getOption(String optionName) {
+ List<String> list = argMap.get(keyFor(optionName));
+ if (list != null && !list.isEmpty()) {
+ return list.get(0);
+ }
+ return null;
+ }
+
+ /**
+ * Get the option, else the default
+ * @param optionName The name of the option to look up, without the --
+ * @param defaultVal The default value.
+ * @return The requested option, else the default value if it doesn't exist
+ */
+ public String getOption(String optionName, String defaultVal) {
+ String res = getOption(optionName);
+ if (res == null) {
+ res = defaultVal;
+ }
+ return res;
+ }
+
+ public int getInt(String optionName) {
+ return Integer.parseInt(getOption(optionName));
+ }
+
+ public int getInt(String optionName, int defaultVal) {
+ return Integer.parseInt(getOption(optionName, String.valueOf(defaultVal)));
+ }
+
+ public float getFloat(String optionName) {
+ return Float.parseFloat(getOption(optionName));
+ }
+
+ public float getFloat(String optionName, float defaultVal) {
+ return Float.parseFloat(getOption(optionName, String.valueOf(defaultVal)));
+ }
+
+ /**
+ * Options can occur multiple times, so return the list
+ * @param optionName The unadorned (no "--" prefixing it) option name
+ * @return The values, else null. If the option is present, but has no values, then the result will be an
+ * empty list (Collections.emptyList())
+ */
+ public List<String> getOptions(String optionName) {
+ return argMap.get(keyFor(optionName));
+ }
+
+ /**
+ * @return if the requested option has been specified
+ */
+ public boolean hasOption(String optionName) {
+ return argMap.containsKey(keyFor(optionName));
+ }
+
+
+ /**
+ * Get the cardinality of the input vectors
+ *
+ * @param matrix
+ * @return the cardinality of the vector
+ */
+ public int getDimensions(Path matrix) throws IOException {
+
+ SequenceFile.Reader reader = null;
+ try {
+ reader = new SequenceFile.Reader(FileSystem.get(getConf()), matrix, getConf());
+
+ Writable row = ClassUtils.instantiateAs(reader.getKeyClass().asSubclass(Writable.class), Writable.class);
+
+ Preconditions.checkArgument(reader.getValueClass().equals(VectorWritable.class),
+ "value type of sequencefile must be a VectorWritable");
+
+ VectorWritable vectorWritable = new VectorWritable();
+ boolean hasAtLeastOneRow = reader.next(row, vectorWritable);
+ Preconditions.checkState(hasAtLeastOneRow, "matrix must have at least one row");
+
+ return vectorWritable.get().size();
+
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+
+ /**
+ * Obtain input and output directories from command-line options or hadoop
+ * properties. If {@code addInputOption} or {@code addOutputOption}
+ * has been called, this method will throw an {@code OptionException} if
+ * no source (command-line or property) for that value is present.
+ * Otherwise, {@code inputPath} or {@code outputPath} will be
+ * non-null only if specified as a hadoop property. Command-line options
+ * take precedence over hadoop properties.
+ *
+ * @throws IllegalArgumentException if either inputOption is present,
+ * and neither {@code --input} nor {@code -Dmapred.input dir} are
+ * specified or outputOption is present and neither {@code --output}
+ * nor {@code -Dmapred.output.dir} are specified.
+ */
+ protected void parseDirectories(CommandLine cmdLine, boolean inputOptional, boolean outputOptional) {
+
+ Configuration conf = getConf();
+
+ if (inputOption != null && cmdLine.hasOption(inputOption)) {
+ this.inputPath = new Path(cmdLine.getValue(inputOption).toString());
+ this.inputFile = new File(cmdLine.getValue(inputOption).toString());
+ }
+ if (inputPath == null && conf.get("mapred.input.dir") != null) {
+ this.inputPath = new Path(conf.get("mapred.input.dir"));
+ }
+
+ if (outputOption != null && cmdLine.hasOption(outputOption)) {
+ this.outputPath = new Path(cmdLine.getValue(outputOption).toString());
+ this.outputFile = new File(cmdLine.getValue(outputOption).toString());
+ }
+ if (outputPath == null && conf.get("mapred.output.dir") != null) {
+ this.outputPath = new Path(conf.get("mapred.output.dir"));
+ }
+
+ Preconditions.checkArgument(inputOptional || inputOption == null || inputPath != null,
+ "No input specified or -Dmapred.input.dir must be provided to specify input directory");
+ Preconditions.checkArgument(outputOptional || outputOption == null || outputPath != null,
+ "No output specified: or -Dmapred.output.dir must be provided to specify output directory");
+ }
+
+ protected static void maybePut(Map<String, List<String>> args, CommandLine cmdLine, Option... opt) {
+ for (Option o : opt) {
+
+ // the option appeared on the command-line, or it has a value
+ // (which is likely a default value).
+ if (cmdLine.hasOption(o) || cmdLine.getValue(o) != null
+ || (cmdLine.getValues(o) != null && !cmdLine.getValues(o).isEmpty())) {
+
+ // nulls are ok, for cases where options are simple flags.
+ List<?> vo = cmdLine.getValues(o);
+ if (vo != null && !vo.isEmpty()) {
+ List<String> vals = Lists.newArrayList();
+ for (Object o1 : vo) {
+ vals.add(o1.toString());
+ }
+ args.put(o.getPreferredName(), vals);
+ } else {
+ args.put(o.getPreferredName(), null);
+ }
+ }
+ }
+ }
+
+ /**
+ *
+ * @param args The input argument map
+ * @param optName The adorned (including "--") option name
+ * @return The first value in the match, else null
+ */
+ public static String getOption(Map<String, List<String>> args, String optName) {
+ List<String> res = args.get(optName);
+ if (res != null && !res.isEmpty()) {
+ return res.get(0);
+ }
+ return null;
+ }
+
+
+ protected static boolean shouldRunNextPhase(Map<String, List<String>> args, AtomicInteger currentPhase) {
+ int phase = currentPhase.getAndIncrement();
+ String startPhase = getOption(args, "--startPhase");
+ String endPhase = getOption(args, "--endPhase");
+ boolean phaseSkipped = (startPhase != null && phase < Integer.parseInt(startPhase))
+ || (endPhase != null && phase > Integer.parseInt(endPhase));
+ if (phaseSkipped) {
+ log.info("Skipping phase {}", phase);
+ }
+ return !phaseSkipped;
+ }
+
+ protected Job prepareJob(Path inputPath,
+ Path outputPath,
+ Class<? extends InputFormat> inputFormat,
+ Class<? extends Mapper> mapper,
+ Class<? extends Writable> mapperKey,
+ Class<? extends Writable> mapperValue,
+ Class<? extends OutputFormat> outputFormat) throws IOException {
+ return prepareJob(inputPath, outputPath, inputFormat, mapper, mapperKey, mapperValue, outputFormat, null);
+
+ }
+ protected Job prepareJob(Path inputPath,
+ Path outputPath,
+ Class<? extends InputFormat> inputFormat,
+ Class<? extends Mapper> mapper,
+ Class<? extends Writable> mapperKey,
+ Class<? extends Writable> mapperValue,
+ Class<? extends OutputFormat> outputFormat,
+ String jobname) throws IOException {
+
+ Job job = HadoopUtil.prepareJob(inputPath, outputPath,
+ inputFormat, mapper, mapperKey, mapperValue, outputFormat, getConf());
+
+ String name =
+ jobname != null ? jobname : HadoopUtil.getCustomJobName(getClass().getSimpleName(), job, mapper, Reducer.class);
+
+ job.setJobName(name);
+ return job;
+
+ }
+
+ protected Job prepareJob(Path inputPath, Path outputPath, Class<? extends Mapper> mapper,
+ Class<? extends Writable> mapperKey, Class<? extends Writable> mapperValue, Class<? extends Reducer> reducer,
+ Class<? extends Writable> reducerKey, Class<? extends Writable> reducerValue) throws IOException {
+ return prepareJob(inputPath, outputPath, SequenceFileInputFormat.class, mapper, mapperKey, mapperValue, reducer,
+ reducerKey, reducerValue, SequenceFileOutputFormat.class);
+ }
+
+ protected Job prepareJob(Path inputPath,
+ Path outputPath,
+ Class<? extends InputFormat> inputFormat,
+ Class<? extends Mapper> mapper,
+ Class<? extends Writable> mapperKey,
+ Class<? extends Writable> mapperValue,
+ Class<? extends Reducer> reducer,
+ Class<? extends Writable> reducerKey,
+ Class<? extends Writable> reducerValue,
+ Class<? extends OutputFormat> outputFormat) throws IOException {
+ Job job = HadoopUtil.prepareJob(inputPath, outputPath,
+ inputFormat, mapper, mapperKey, mapperValue, reducer, reducerKey, reducerValue, outputFormat, getConf());
+ job.setJobName(HadoopUtil.getCustomJobName(getClass().getSimpleName(), job, mapper, Reducer.class));
+ return job;
+ }
+
+ /**
+ * necessary to make this job (having a combined input path) work on Amazon S3, hopefully this is
+ * obsolete when MultipleInputs is available again
+ */
+ public static void setS3SafeCombinedInputPath(Job job, Path referencePath, Path inputPathOne, Path inputPathTwo)
+ throws IOException {
+ FileSystem fs = FileSystem.get(referencePath.toUri(), job.getConfiguration());
+ FileInputFormat.setInputPaths(job, inputPathOne.makeQualified(fs), inputPathTwo.makeQualified(fs));
+ }
+
+ protected Class<? extends Analyzer> getAnalyzerClassFromOption() throws ClassNotFoundException {
+ Class<? extends Analyzer> analyzerClass = StandardAnalyzer.class;
+ if (hasOption(DefaultOptionCreator.ANALYZER_NAME_OPTION)) {
+ String className = getOption(DefaultOptionCreator.ANALYZER_NAME_OPTION);
+ analyzerClass = Class.forName(className).asSubclass(Analyzer.class);
+ // try instantiating it, b/c there isn't any point in setting it if
+ // you can't instantiate it
+ //ClassUtils.instantiateAs(analyzerClass, Analyzer.class);
+ AnalyzerUtils.createAnalyzer(analyzerClass);
+ }
+ return analyzerClass;
+ }
+
+ /**
+ * Overrides the base implementation to install the Oozie action configuration resource
+ * into the provided Configuration object; note that ToolRunner calls setConf on the Tool
+ * before it invokes run.
+ */
+ @Override
+ public void setConf(Configuration conf) {
+ super.setConf(conf);
+
+ // If running in an Oozie workflow as a Java action, need to add the
+ // Configuration resource provided by Oozie to this job's config.
+ String oozieActionConfXml = System.getProperty("oozie.action.conf.xml");
+ if (oozieActionConfXml != null && conf != null) {
+ conf.addResource(new Path("file:///", oozieActionConfXml));
+ log.info("Added Oozie action Configuration resource {} to the Hadoop Configuration", oozieActionConfXml);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/ClassUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/ClassUtils.java b/mr/src/main/java/org/apache/mahout/common/ClassUtils.java
new file mode 100644
index 0000000..8052ef1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/ClassUtils.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.lang.reflect.InvocationTargetException;
+
+public final class ClassUtils {
+
+ private ClassUtils() {}
+
+ public static <T> T instantiateAs(String classname, Class<T> asSubclassOfClass) {
+ try {
+ return instantiateAs(Class.forName(classname).asSubclass(asSubclassOfClass), asSubclassOfClass);
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ public static <T> T instantiateAs(String classname, Class<T> asSubclassOfClass, Class<?>[] params, Object[] args) {
+ try {
+ return instantiateAs(Class.forName(classname).asSubclass(asSubclassOfClass), asSubclassOfClass, params, args);
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ public static <T> T instantiateAs(Class<? extends T> clazz,
+ Class<T> asSubclassOfClass,
+ Class<?>[] params,
+ Object[] args) {
+ try {
+ return clazz.asSubclass(asSubclassOfClass).getConstructor(params).newInstance(args);
+ } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException ie) {
+ throw new IllegalStateException(ie);
+ }
+ }
+
+
+ public static <T> T instantiateAs(Class<? extends T> clazz, Class<T> asSubclassOfClass) {
+ try {
+ return clazz.asSubclass(asSubclassOfClass).getConstructor().newInstance();
+ } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException ie) {
+ throw new IllegalStateException(ie);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/CommandLineUtil.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/CommandLineUtil.java b/mr/src/main/java/org/apache/mahout/common/CommandLineUtil.java
new file mode 100644
index 0000000..0cc93ba
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/CommandLineUtil.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+
+import com.google.common.base.Charsets;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.GenericOptionsParser;
+
+public final class CommandLineUtil {
+
+ private CommandLineUtil() { }
+
+ public static void printHelp(Group group) {
+ HelpFormatter formatter = new HelpFormatter();
+ formatter.setGroup(group);
+ formatter.print();
+ }
+
+ /**
+ * Print the options supported by {@code GenericOptionsParser}.
+ * In addition to the options supported by the job, passed in as the
+ * group parameter.
+ *
+ * @param group job-specific command-line options.
+ */
+ public static void printHelpWithGenericOptions(Group group) throws IOException {
+ new GenericOptionsParser(new Configuration(), new org.apache.commons.cli.Options(), new String[0]);
+ PrintWriter pw = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true);
+ HelpFormatter formatter = new HelpFormatter();
+ formatter.setGroup(group);
+ formatter.setPrintWriter(pw);
+ formatter.setFooter("Specify HDFS directories while running on hadoop; else specify local file system directories");
+ formatter.print();
+ }
+
+ public static void printHelpWithGenericOptions(Group group, OptionException oe) throws IOException {
+ new GenericOptionsParser(new Configuration(), new org.apache.commons.cli.Options(), new String[0]);
+ PrintWriter pw = new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true);
+ HelpFormatter formatter = new HelpFormatter();
+ formatter.setGroup(group);
+ formatter.setPrintWriter(pw);
+ formatter.setException(oe);
+ formatter.print();
+ }
+
+}
[09/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java
new file mode 100644
index 0000000..ca4d2b2
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarityTest.java
@@ -0,0 +1,236 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.junit.Test;
+
+/** <p>Tests {@link EuclideanDistanceSimilarity}.</p> */
+public final class EuclideanDistanceSimilarityTest extends SimilarityTestCase {
+
+ @Test
+ public void testFullCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {3.0, -2.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation1Weighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {3.0, -2.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, 3.0},
+ {3.0, 3.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2);
+ assertEquals(1.0, correlation, EPSILON);
+ }
+
+ @Test
+ public void testNoCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {-3.0, 2.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(0.1639607805437114, correlation);
+ }
+
+ @Test
+ public void testNoCorrelation1Weighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {-3.0, 2.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2);
+ assertCorrelationEquals(0.7213202601812372, correlation);
+ }
+
+ @Test
+ public void testNoCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 1.0, null},
+ {null, null, 1.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2);
+ assertTrue(Double.isNaN(correlation));
+ }
+
+ @Test
+ public void testNoCorrelation3() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {90.0, 80.0, 70.0},
+ {70.0, 80.0, 90.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(0.05770363219029305, correlation);
+ }
+
+ @Test
+ public void testSimple() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {2.0, 5.0, 6.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(0.2843646522044218, correlation);
+ }
+
+ @Test
+ public void testSimpleWeighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {2.0, 5.0, 6.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2);
+ assertCorrelationEquals(0.8210911630511055, correlation);
+ }
+
+ @Test
+ public void testFullItemCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, 3.0},
+ {-2.0, -2.0},
+ });
+ double correlation =
+ new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullItemCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, 3.0},
+ {3.0, 3.0},
+ });
+ double correlation =
+ new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1);
+ assertEquals(1.0, correlation, EPSILON);
+ }
+
+ @Test
+ public void testNoItemCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -3.0},
+ {-2.0, 2.0},
+ });
+ double correlation =
+ new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(0.1639607805437114, correlation);
+ }
+
+ @Test
+ public void testNoItemCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 1.0, null},
+ {null, null, 1.0},
+ });
+ double correlation = new EuclideanDistanceSimilarity(dataModel).itemSimilarity(1, 2);
+ assertTrue(Double.isNaN(correlation));
+ }
+
+ @Test
+ public void testNoItemCorrelation3() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {90.0, 70.0},
+ {80.0, 80.0},
+ {70.0, 90.0},
+ });
+ double correlation =
+ new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(0.05770363219029305, correlation);
+ }
+
+ @Test
+ public void testSimpleItem() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {1.0, 2.0},
+ {2.0, 5.0},
+ {3.0, 6.0},
+ });
+ double correlation =
+ new EuclideanDistanceSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(0.2843646522044218, correlation);
+ }
+
+ @Test
+ public void testSimpleItemWeighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {1.0, 2.0},
+ {2.0, 5.0},
+ {3.0, 6.0},
+ });
+ ItemSimilarity itemSimilarity = new EuclideanDistanceSimilarity(dataModel, Weighting.WEIGHTED);
+ double correlation = itemSimilarity.itemSimilarity(0, 1);
+ assertCorrelationEquals(0.8210911630511055, correlation);
+ }
+
+ @Test
+ public void testRefresh() throws TasteException {
+ // Make sure this doesn't throw an exception
+ new EuclideanDistanceSimilarity(getDataModel()).refresh(null);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java
new file mode 100644
index 0000000..5ce255c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarityTest.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** <p>Tests {@link GenericItemSimilarity}.</p> */
+public final class GenericItemSimilarityTest extends SimilarityTestCase {
+
+ @Test
+ public void testSimple() {
+ List<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList();
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 2, 0.5));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(2, 1, 0.6));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 1, 0.5));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 3, 0.3));
+ GenericItemSimilarity itemCorrelation = new GenericItemSimilarity(similarities);
+ assertEquals(1.0, itemCorrelation.itemSimilarity(1, 1), EPSILON);
+ assertEquals(0.6, itemCorrelation.itemSimilarity(1, 2), EPSILON);
+ assertEquals(0.6, itemCorrelation.itemSimilarity(2, 1), EPSILON);
+ assertEquals(0.3, itemCorrelation.itemSimilarity(1, 3), EPSILON);
+ assertTrue(Double.isNaN(itemCorrelation.itemSimilarity(3, 4)));
+ }
+
+ @Test
+ public void testFromCorrelation() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {1.0, 2.0},
+ {2.0, 5.0},
+ {3.0, 6.0},
+ });
+ ItemSimilarity otherSimilarity = new PearsonCorrelationSimilarity(dataModel);
+ ItemSimilarity itemSimilarity = new GenericItemSimilarity(otherSimilarity, dataModel);
+ assertCorrelationEquals(1.0, itemSimilarity.itemSimilarity(0, 0));
+ assertCorrelationEquals(0.960768922830523, itemSimilarity.itemSimilarity(0, 1));
+ }
+
+ @Test
+ public void testAllSimilaritiesWithoutIndex() throws TasteException {
+
+ List<GenericItemSimilarity.ItemItemSimilarity> itemItemSimilarities =
+ Arrays.asList(new GenericItemSimilarity.ItemItemSimilarity(1L, 2L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(1L, 3L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(2L, 1L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(3L, 5L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(3L, 4L, 0.2));
+
+ ItemSimilarity similarity = new GenericItemSimilarity(itemItemSimilarities);
+
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(1L), 2L, 3L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(2L), 1L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(3L), 1L, 5L, 4L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(4L), 3L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(5L), 3L));
+ }
+
+ @Test
+ public void testAllSimilaritiesWithIndex() throws TasteException {
+
+ List<GenericItemSimilarity.ItemItemSimilarity> itemItemSimilarities =
+ Arrays.asList(new GenericItemSimilarity.ItemItemSimilarity(1L, 2L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(1L, 3L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(2L, 1L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(3L, 5L, 0.2),
+ new GenericItemSimilarity.ItemItemSimilarity(3L, 4L, 0.2));
+
+ ItemSimilarity similarity = new GenericItemSimilarity(itemItemSimilarities);
+
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(1L), 2L, 3L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(2L), 1L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(3L), 1L, 5L, 4L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(4L), 3L));
+ assertTrue(containsExactly(similarity.allSimilarItemIDs(5L), 3L));
+ }
+
+ private static boolean containsExactly(long[] allIDs, long... shouldContainID) {
+ return new FastIDSet(allIDs).intersectionSize(new FastIDSet(shouldContainID)) == shouldContainID.length;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
new file mode 100644
index 0000000..ae9df5c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarityTest.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** <p>Tests {@link LogLikelihoodSimilarity}.</p> */
+public final class LogLikelihoodSimilarityTest extends SimilarityTestCase {
+
+ @Test
+ public void testCorrelation() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3, 4, 5},
+ new Double[][] {
+ {1.0, 1.0},
+ {1.0, null, 1.0},
+ {null, null, 1.0, 1.0, 1.0},
+ {1.0, 1.0, 1.0, 1.0, 1.0},
+ {null, 1.0, 1.0, 1.0, 1.0},
+ });
+
+ LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel);
+
+ assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(1, 0));
+ assertCorrelationEquals(0.12160727029227925, similarity.itemSimilarity(0, 1));
+
+ assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(1, 2));
+ assertCorrelationEquals(0.5423213660693732, similarity.itemSimilarity(2, 1));
+
+ assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(2, 3));
+ assertCorrelationEquals(0.6905400104897509, similarity.itemSimilarity(3, 2));
+
+ assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(3, 4));
+ assertCorrelationEquals(0.8706358464330881, similarity.itemSimilarity(4, 3));
+ }
+
+ @Test
+ public void testNoSimilarity() throws Exception {
+
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3, 4},
+ new Double[][] {
+ {1.0, null, 1.0, 1.0},
+ {1.0, null, 1.0, 1.0},
+ {null, 1.0, 1.0, 1.0},
+ {null, 1.0, 1.0, 1.0},
+ });
+
+ LogLikelihoodSimilarity similarity = new LogLikelihoodSimilarity(dataModel);
+
+ assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(1, 0));
+ assertCorrelationEquals(Double.NaN, similarity.itemSimilarity(0, 1));
+
+ assertCorrelationEquals(0.0, similarity.itemSimilarity(2, 3));
+ assertCorrelationEquals(0.0, similarity.itemSimilarity(3, 2));
+ }
+
+ @Test
+ public void testRefresh() {
+ // Make sure this doesn't throw an exception
+ new LogLikelihoodSimilarity(getDataModel()).refresh(null);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java
new file mode 100644
index 0000000..bb3ad3e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarityTest.java
@@ -0,0 +1,265 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.junit.Test;
+
+/** <p>Tests {@link PearsonCorrelationSimilarity}.</p> */
+public final class PearsonCorrelationSimilarityTest extends SimilarityTestCase {
+
+ @Test
+ public void testFullCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {3.0, -2.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation1Weighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {3.0, -2.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, 3.0},
+ {3.0, 3.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ // Yeah, undefined in this case
+ assertTrue(Double.isNaN(correlation));
+ }
+
+ @Test
+ public void testNoCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {-3.0, 2.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(-1.0, correlation);
+ }
+
+ @Test
+ public void testNoCorrelation1Weighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -2.0},
+ {-3.0, 2.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2);
+ assertCorrelationEquals(-1.0, correlation);
+ }
+
+ @Test
+ public void testNoCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 1.0, null},
+ {null, null, 1.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertTrue(Double.isNaN(correlation));
+ }
+
+ @Test
+ public void testNoCorrelation3() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {90.0, 80.0, 70.0},
+ {70.0, 80.0, 90.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(-1.0, correlation);
+ }
+
+ @Test
+ public void testSimple() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {2.0, 5.0, 6.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(0.9607689228305227, correlation);
+ }
+
+ @Test
+ public void testSimpleWeighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {2.0, 5.0, 6.0},
+ });
+ double correlation = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED).userSimilarity(1, 2);
+ assertCorrelationEquals(0.9901922307076306, correlation);
+ }
+
+ @Test
+ public void testFullItemCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, 3.0},
+ {-2.0, -2.0},
+ });
+ double correlation =
+ new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullItemCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, 3.0},
+ {3.0, 3.0},
+ });
+ double correlation =
+ new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1);
+ // Yeah, undefined in this case
+ assertTrue(Double.isNaN(correlation));
+ }
+
+ @Test
+ public void testNoItemCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {3.0, -3.0},
+ {2.0, -2.0},
+ });
+ double correlation =
+ new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(-1.0, correlation);
+ }
+
+ @Test
+ public void testNoItemCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 1.0, null},
+ {null, null, 1.0},
+ });
+ double correlation =
+ new PearsonCorrelationSimilarity(dataModel).itemSimilarity(1, 2);
+ assertTrue(Double.isNaN(correlation));
+ }
+
+ @Test
+ public void testNoItemCorrelation3() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {90.0, 70.0},
+ {80.0, 80.0},
+ {70.0, 90.0},
+ });
+ double correlation =
+ new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(-1.0, correlation);
+ }
+
+ @Test
+ public void testSimpleItem() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {1.0, 2.0},
+ {2.0, 5.0},
+ {3.0, 6.0},
+ });
+ double correlation =
+ new PearsonCorrelationSimilarity(dataModel).itemSimilarity(0, 1);
+ assertCorrelationEquals(0.9607689228305227, correlation);
+ }
+
+ @Test
+ public void testSimpleItemWeighted() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {1.0, 2.0},
+ {2.0, 5.0},
+ {3.0, 6.0},
+ });
+ ItemSimilarity itemSimilarity = new PearsonCorrelationSimilarity(dataModel, Weighting.WEIGHTED);
+ double correlation = itemSimilarity.itemSimilarity(0, 1);
+ assertCorrelationEquals(0.9901922307076306, correlation);
+ }
+
+ @Test
+ public void testRefresh() throws Exception {
+ // Make sure this doesn't throw an exception
+ new PearsonCorrelationSimilarity(getDataModel()).refresh(null);
+ }
+
+ @Test
+ public void testInferrer() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 1.0, 2.0, null, null, 6.0},
+ {1.0, 8.0, null, 3.0, 4.0, null},
+ });
+ UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel);
+ similarity.setPreferenceInferrer(new PreferenceInferrer() {
+ @Override
+ public float inferPreference(long userID, long itemID) {
+ return 1.0f;
+ }
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ }
+ });
+
+ assertEquals(-0.435285750066007, similarity.userSimilarity(1L, 2L), EPSILON);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java
new file mode 100644
index 0000000..ad1e4b7
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SimilarityTestCase.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+
+abstract class SimilarityTestCase extends TasteTestCase {
+
+ static void assertCorrelationEquals(double expected, double actual) {
+ if (Double.isNaN(expected)) {
+ assertTrue("Correlation is not NaN", Double.isNaN(actual));
+ } else {
+ assertTrue("Correlation is NaN", !Double.isNaN(actual));
+ assertTrue("Correlation > 1.0", actual <= 1.0);
+ assertTrue("Correlation < -1.0", actual >= -1.0);
+ assertEquals(expected, actual, EPSILON);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java
new file mode 100644
index 0000000..6034f4b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarityTest.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** <p>Tests {@link SpearmanCorrelationSimilarity}.</p> */
+public final class SpearmanCorrelationSimilarityTest extends SimilarityTestCase {
+
+ @Test
+ public void testFullCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {1.0, 2.0, 3.0},
+ });
+ double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0},
+ });
+ double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testAnticorrelation() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {3.0, 2.0, 1.0},
+ });
+ double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(-1.0, correlation);
+ }
+
+ @Test
+ public void testSimple() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {2.0, 3.0, 1.0},
+ });
+ double correlation = new SpearmanCorrelationSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(-0.5, correlation);
+ }
+
+ @Test
+ public void testRefresh() {
+ // Make sure this doesn't throw an exception
+ new SpearmanCorrelationSimilarity(getDataModel()).refresh(null);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java
new file mode 100644
index 0000000..87f82b9
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarityTest.java
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** <p>Tests {@link TanimotoCoefficientSimilarity}.</p> */
+public final class TanimotoCoefficientSimilarityTest extends SimilarityTestCase {
+
+ @Test
+ public void testNoCorrelation() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 2.0, 3.0},
+ {1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(Double.NaN, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0},
+ {1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(1.0, correlation);
+ }
+
+ @Test
+ public void testFullCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {1.0, 2.0, 3.0},
+ {1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertCorrelationEquals(0.3333333333333333, correlation);
+ }
+
+ @Test
+ public void testCorrelation1() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 2.0, 3.0},
+ {1.0, 1.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertEquals(0.3333333333333333, correlation, EPSILON);
+ }
+
+ @Test
+ public void testCorrelation2() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, 2.0, 3.0, 1.0},
+ {1.0, 1.0, null, 0.0},
+ });
+ double correlation = new TanimotoCoefficientSimilarity(dataModel).userSimilarity(1, 2);
+ assertEquals(0.5, correlation, EPSILON);
+ }
+
+ @Test
+ public void testRefresh() {
+ // Make sure this doesn't throw an exception
+ new TanimotoCoefficientSimilarity(getDataModel()).refresh(null);
+ }
+
+ @Test
+ public void testReturnNaNDoubleWhenNoSimilaritiesForTwoItems() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {null, null, 3.0},
+ {1.0, 1.0, null},
+ });
+ Double similarity = new TanimotoCoefficientSimilarity(dataModel).itemSimilarity(1, 2);
+ assertEquals(Double.NaN, similarity, EPSILON);
+ }
+
+ @Test
+ public void testItemsSimilarities() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2},
+ new Double[][] {
+ {2.0, null, 2.0},
+ {1.0, 1.0, 1.0},
+ });
+ TanimotoCoefficientSimilarity tCS = new TanimotoCoefficientSimilarity(dataModel);
+ assertEquals(0.5, tCS.itemSimilarity(0, 1), EPSILON);
+ assertEquals(1, tCS.itemSimilarity(0, 2), EPSILON);
+
+ double[] similarities = tCS.itemSimilarities(0, new long [] {1, 2});
+ assertEquals(0.5, similarities[0], EPSILON);
+ assertEquals(1, similarities[1], EPSILON);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java
new file mode 100644
index 0000000..d9d28ab
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarityTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.file;
+
+import java.io.File;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity.ItemItemSimilarity;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.junit.Before;
+import org.junit.Test;
+
+/** <p>Tests {@link FileItemSimilarity}.</p> */
+public final class FileItemSimilarityTest extends TasteTestCase {
+
+ private static final String[] data = {
+ "1,5,0.125",
+ "1,7,0.5" };
+
+ private static final String[] changedData = {
+ "1,5,0.125",
+ "1,7,0.9",
+ "7,8,0.112" };
+
+ private File testFile;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ testFile = getTestTempFile("test.txt");
+ writeLines(testFile, data);
+ }
+
+ @Test
+ public void testLoadFromFile() throws Exception {
+ ItemSimilarity similarity = new FileItemSimilarity(testFile);
+
+ assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON);
+ assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON);
+ assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON);
+ assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON);
+
+ assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L)));
+
+ double[] valuesForOne = similarity.itemSimilarities(1L, new long[] { 5L, 7L });
+ assertNotNull(valuesForOne);
+ assertEquals(2, valuesForOne.length);
+ assertEquals(0.125, valuesForOne[0], EPSILON);
+ assertEquals(0.5, valuesForOne[1], EPSILON);
+ }
+
+ @Test
+ public void testNoRefreshAfterFileUpdate() throws Exception {
+ ItemSimilarity similarity = new FileItemSimilarity(testFile, 0L);
+
+ /* call a method to make sure the original file is loaded*/
+ similarity.itemSimilarity(1L, 5L);
+
+ /* change the underlying file,
+ * we have to wait at least a second to see the change in the file's lastModified timestamp */
+ Thread.sleep(2000L);
+ writeLines(testFile, changedData);
+
+ /* we shouldn't see any changes in the data as we have not yet refreshed */
+ assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON);
+ assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON);
+ assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L)));
+ }
+
+ @Test
+ public void testRefreshAfterFileUpdate() throws Exception {
+ ItemSimilarity similarity = new FileItemSimilarity(testFile, 0L);
+
+ /* call a method to make sure the original file is loaded */
+ similarity.itemSimilarity(1L, 5L);
+
+ /* change the underlying file,
+ * we have to wait at least a second to see the change in the file's lastModified timestamp */
+ Thread.sleep(2000L);
+ writeLines(testFile, changedData);
+
+ similarity.refresh(null);
+
+ /* we should now see the changes in the data */
+ assertEquals(0.9, similarity.itemSimilarity(1L, 7L), EPSILON);
+ assertEquals(0.9, similarity.itemSimilarity(7L, 1L), EPSILON);
+ assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON);
+ assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON);
+
+ assertFalse(Double.isNaN(similarity.itemSimilarity(7L, 8L)));
+ assertEquals(0.112, similarity.itemSimilarity(7L, 8L), EPSILON);
+ assertEquals(0.112, similarity.itemSimilarity(8L, 7L), EPSILON);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testFileNotFoundExceptionForNonExistingFile() throws Exception {
+ new FileItemSimilarity(new File("xKsdfksdfsdf"));
+ }
+
+ @Test
+ public void testFileItemItemSimilarityIterable() throws Exception {
+ Iterable<ItemItemSimilarity> similarityIterable = new FileItemItemSimilarityIterable(testFile);
+ GenericItemSimilarity similarity = new GenericItemSimilarity(similarityIterable);
+
+ assertEquals(0.125, similarity.itemSimilarity(1L, 5L), EPSILON);
+ assertEquals(0.125, similarity.itemSimilarity(5L, 1L), EPSILON);
+ assertEquals(0.5, similarity.itemSimilarity(1L, 7L), EPSILON);
+ assertEquals(0.5, similarity.itemSimilarity(7L, 1L), EPSILON);
+
+ assertTrue(Double.isNaN(similarity.itemSimilarity(7L, 8L)));
+
+ double[] valuesForOne = similarity.itemSimilarities(1L, new long[] { 5L, 7L });
+ assertNotNull(valuesForOne);
+ assertEquals(2, valuesForOne.length);
+ assertEquals(0.125, valuesForOne[0], EPSILON);
+ assertEquals(0.5, valuesForOne[1], EPSILON);
+ }
+
+ @Test
+ public void testToString() throws Exception {
+ ItemSimilarity similarity = new FileItemSimilarity(testFile);
+ assertTrue(!similarity.toString().isEmpty());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java
new file mode 100644
index 0000000..67cc2f1
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilaritiesTest.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.precompute;
+
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.TanimotoCoefficientSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItemsWriter;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+public class MultithreadedBatchItemSimilaritiesTest {
+
+ @Test
+ public void lessItemsThanBatchSize() throws Exception {
+
+ FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+ userData.put(1, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1, 1, 1),
+ new GenericPreference(1, 2, 1), new GenericPreference(1, 3, 1))));
+ userData.put(2, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2, 1, 1),
+ new GenericPreference(2, 2, 1), new GenericPreference(2, 4, 1))));
+
+ DataModel dataModel = new GenericDataModel(userData);
+ ItemBasedRecommender recommender =
+ new GenericItemBasedRecommender(dataModel, new TanimotoCoefficientSimilarity(dataModel));
+
+ BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(recommender, 10);
+
+ batchSimilarities.computeItemSimilarities(1, 1, mock(SimilarItemsWriter.class));
+ }
+
+ @Test
+ public void higherDegreeOfParallelismThanBatches() throws Exception {
+
+ FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+ userData.put(1, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1, 1, 1),
+ new GenericPreference(1, 2, 1), new GenericPreference(1, 3, 1))));
+ userData.put(2, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2, 1, 1),
+ new GenericPreference(2, 2, 1), new GenericPreference(2, 4, 1))));
+
+ DataModel dataModel = new GenericDataModel(userData);
+ ItemBasedRecommender recommender =
+ new GenericItemBasedRecommender(dataModel, new TanimotoCoefficientSimilarity(dataModel));
+
+ BatchItemSimilarities batchSimilarities = new MultithreadedBatchItemSimilarities(recommender, 10);
+
+ try {
+ // Batch size is 100, so we only get 1 batch from 3 items, but we use a degreeOfParallelism of 2
+ batchSimilarities.computeItemSimilarities(2, 1, mock(SimilarItemsWriter.class));
+ fail();
+ } catch (IOException e) {}
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java b/mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java
new file mode 100644
index 0000000..f037209
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/ClassifierData.java
@@ -0,0 +1,102 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+/**
+ * Class containing sample docs from ASF websites under mahout, lucene and spamassasin projects
+ *
+ */
+public final class ClassifierData {
+
+ public static final String[][] DATA = {
+ {
+ "mahout",
+ "Mahout's goal is to build scalable machine learning libraries. With scalable we mean: "
+ + "Scalable to reasonably large data sets. Our core algorithms for clustering,"
+ + " classfication and batch based collaborative filtering are implemented on top "
+ + "of Apache Hadoop using the map/reduce paradigm. However we do not restrict "
+ + "contributions to Hadoop based implementations: Contributions that run on"},
+ {
+ "mahout",
+ " a single node or on a non-Hadoop cluster are welcome as well. The core"
+ + " libraries are highly optimized to allow for good performance also for"
+ + " non-distributed algorithms. Scalable to support your business case. "
+ + "Mahout is distributed under a commercially friendly Apache Software license. "
+ + "Scalable community. The goal of Mahout is to build a vibrant, responsive, "},
+ {
+ "mahout",
+ "diverse community to facilitate discussions not only on the project itself"
+ + " but also on potential use cases. Come to the mailing lists to find out more."
+ + " Currently Mahout supports mainly four use cases: Recommendation mining takes "
+ + "users' behavior and from that tries to find items users might like. Clustering "},
+ {
+ "mahout",
+ "takes e.g. text documents and groups them into groups of topically related documents."
+ + " Classification learns from exisiting categorized documents what documents of"
+ + " a specific category look like and is able to assign unlabelled documents to "
+ + "the (hopefully) correct category. Frequent itemset mining takes a set of item"
+ + " groups (terms in a query session, shopping cart content) and identifies, which"
+ + " individual items usually appear together."},
+ {
+ "lucene",
+ "Apache Lucene is a high-performance, full-featured text search engine library"
+ + " written entirely in Java. It is a technology suitable for nearly any application "
+ + "that requires full-text search, especially cross-platform. Apache Lucene is an open source"
+ + " project available for free download. Please use the links on the left to access Lucene. "
+ + "The new version is mostly a cleanup release without any new features. "},
+ {
+ "lucene",
+ "All deprecations targeted to be removed in version 3.0 were removed. If you "
+ + "are upgrading from version 2.9.1 of Lucene, you have to fix all deprecation warnings"
+ + " in your code base to be able to recompile against this version. This is the first Lucene"},
+ {
+ "lucene",
+ " release with Java 5 as a minimum requirement. The API was cleaned up to make use of Java 5's "
+ + "generics, varargs, enums, and autoboxing. New users of Lucene are advised to use this version "
+ + "for new developments, because it has a clean, type safe new API. Upgrading users can now remove"},
+ {
+ "lucene",
+ " unnecessary casts and add generics to their code, too. If you have not upgraded your installation "
+ + "to Java 5, please read the file JRE_VERSION_MIGRATION.txt (please note that this is not related to"
+ + " Lucene 3.0, it will also happen with any previous release when you upgrade your Java environment)."},
+ {
+ "spamassasin",
+ "SpamAssassin is a mail filter to identify spam. It is an intelligent email filter which uses a diverse "
+ + "range of tests to identify unsolicited bulk email, more commonly known as Spam. These tests are applied "
+ + "to email headers and content to classify email using advanced statistical methods. In addition, "},
+ {
+ "spamassasin",
+ "SpamAssassin has a modular architecture that allows other technologies to be quickly wielded against spam"
+ + " and is designed for easy integration into virtually any email system."
+ + "SpamAssassin's practical multi-technique approach, modularity, and extensibility continue to give it an "},
+ {
+ "spamassasin",
+ "advantage over other anti-spam systems. Due to these advantages, SpamAssassin is widely used in all aspects "
+ + "of email management. You can readily find SpamAssassin in use in both email clients and servers, on many "
+ + "different operating systems, filtering incoming as well as outgoing email, and implementing a "
+ + "very broad range "},
+ {
+ "spamassasin",
+ "of policy actions. These installations include service providers, businesses, not-for-profit and "
+ + "educational organizations, and end-user systems. SpamAssassin also forms the basis for numerous "
+ + "commercial anti-spam products available on the market today."}};
+
+
+ private ClassifierData() { }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java b/mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
new file mode 100644
index 0000000..3ffff85
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
@@ -0,0 +1,119 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+public final class ConfusionMatrixTest extends MahoutTestCase {
+
+ private static final int[][] VALUES = {{2, 3}, {10, 20}};
+ private static final String[] LABELS = {"Label1", "Label2"};
+ private static final int[] OTHER = {3, 6};
+ private static final String DEFAULT_LABEL = "other";
+
+ @Test
+ public void testBuild() {
+ ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+ checkValues(confusionMatrix);
+ checkAccuracy(confusionMatrix);
+ }
+
+ @Test
+ public void testGetMatrix() {
+ ConfusionMatrix confusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL);
+ Matrix m = confusionMatrix.getMatrix();
+ Map<String, Integer> rowLabels = m.getRowLabelBindings();
+
+ assertEquals(confusionMatrix.getLabels().size(), m.numCols());
+ assertTrue(rowLabels.keySet().contains(LABELS[0]));
+ assertTrue(rowLabels.keySet().contains(LABELS[1]));
+ assertTrue(rowLabels.keySet().contains(DEFAULT_LABEL));
+ assertEquals(2, confusionMatrix.getCorrect(LABELS[0]));
+ assertEquals(20, confusionMatrix.getCorrect(LABELS[1]));
+ assertEquals(0, confusionMatrix.getCorrect(DEFAULT_LABEL));
+ }
+
+ /**
+ * Example taken from
+ * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
+ */
+ @Test
+ public void testPrecisionRecallAndF1ScoreAsScikitLearn() {
+ Collection<String> labelList = Arrays.asList("0", "1", "2");
+
+ ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, "DEFAULT");
+ confusionMatrix.putCount("0", "0", 2);
+ confusionMatrix.putCount("1", "0", 1);
+ confusionMatrix.putCount("1", "2", 1);
+ confusionMatrix.putCount("2", "1", 2);
+
+ double delta = 0.001;
+ assertEquals(0.222, confusionMatrix.getWeightedPrecision(), delta);
+ assertEquals(0.333, confusionMatrix.getWeightedRecall(), delta);
+ assertEquals(0.266, confusionMatrix.getWeightedF1score(), delta);
+ }
+
+ private static void checkValues(ConfusionMatrix cm) {
+ int[][] counts = cm.getConfusionMatrix();
+ cm.toString();
+ assertEquals(counts.length, counts[0].length);
+ assertEquals(3, counts.length);
+ assertEquals(VALUES[0][0], counts[0][0]);
+ assertEquals(VALUES[0][1], counts[0][1]);
+ assertEquals(VALUES[1][0], counts[1][0]);
+ assertEquals(VALUES[1][1], counts[1][1]);
+ assertTrue(Arrays.equals(new int[3], counts[2])); // zeros
+ assertEquals(OTHER[0], counts[0][2]);
+ assertEquals(OTHER[1], counts[1][2]);
+ assertEquals(3, cm.getLabels().size());
+ assertTrue(cm.getLabels().contains(LABELS[0]));
+ assertTrue(cm.getLabels().contains(LABELS[1]));
+ assertTrue(cm.getLabels().contains(DEFAULT_LABEL));
+ }
+
+ private static void checkAccuracy(ConfusionMatrix cm) {
+ Collection<String> labelstrs = cm.getLabels();
+ assertEquals(3, labelstrs.size());
+ assertEquals(25.0, cm.getAccuracy("Label1"), EPSILON);
+ assertEquals(55.5555555, cm.getAccuracy("Label2"), EPSILON);
+ assertTrue(Double.isNaN(cm.getAccuracy("other")));
+ }
+
+ private static ConfusionMatrix fillConfusionMatrix(int[][] values, String[] labels, String defaultLabel) {
+ Collection<String> labelList = Lists.newArrayList();
+ labelList.add(labels[0]);
+ labelList.add(labels[1]);
+ ConfusionMatrix confusionMatrix = new ConfusionMatrix(labelList, defaultLabel);
+
+ confusionMatrix.putCount("Label1", "Label1", values[0][0]);
+ confusionMatrix.putCount("Label1", "Label2", values[0][1]);
+ confusionMatrix.putCount("Label2", "Label1", values[1][0]);
+ confusionMatrix.putCount("Label2", "Label2", values[1][1]);
+ confusionMatrix.putCount("Label1", DEFAULT_LABEL, OTHER[0]);
+ confusionMatrix.putCount("Label2", DEFAULT_LABEL, OTHER[1]);
+ return confusionMatrix;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java b/mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
new file mode 100644
index 0000000..86234f8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/RegressionResultAnalyzerTest.java
@@ -0,0 +1,128 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public class RegressionResultAnalyzerTest extends MahoutTestCase {
+
+ private static final Pattern p1 = Pattern.compile("Correlation coefficient *: *(.*)\n");
+ private static final Pattern p2 = Pattern.compile("Mean absolute error *: *(.*)\n");
+ private static final Pattern p3 = Pattern.compile("Root mean squared error *: *(.*)\n");
+ private static final Pattern p4 = Pattern.compile("Predictable Instances *: *(.*)\n");
+ private static final Pattern p5 = Pattern.compile("Unpredictable Instances *: *(.*)\n");
+ private static final Pattern p6 = Pattern.compile("Total Regressed Instances *: *(.*)\n");
+
+ private static double[] parseAnalysis(CharSequence analysis) {
+ double[] results = new double[3];
+ Matcher m = p1.matcher(analysis);
+ if (m.find()) {
+ results[0] = Double.parseDouble(m.group(1));
+ } else {
+ return null;
+ }
+ m = p2.matcher(analysis);
+ if (m.find()) {
+ results[1] = Double.parseDouble(m.group(1));
+ } else {
+ return null;
+ }
+ m = p3.matcher(analysis);
+ if (m.find()) {
+ results[2] = Double.parseDouble(m.group(1));
+ } else {
+ return null;
+ }
+ return results;
+ }
+
+ private static int[] parseAnalysisCount(CharSequence analysis) {
+ int[] results = new int[3];
+ Matcher m = p4.matcher(analysis);
+ if (m.find()) {
+ results[0] = Integer.parseInt(m.group(1));
+ }
+ m = p5.matcher(analysis);
+ if (m.find()) {
+ results[1] = Integer.parseInt(m.group(1));
+ }
+ m = p6.matcher(analysis);
+ if (m.find()) {
+ results[2] = Integer.parseInt(m.group(1));
+ }
+ return results;
+ }
+
+ @Test
+ public void testAnalyze() {
+ double[][] results = new double[10][2];
+
+ for (int i = 0; i < results.length; i++) {
+ results[i][0] = i;
+ results[i][1] = i + 1;
+ }
+ RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ String analysis = analyzer.toString();
+ assertArrayEquals(new double[]{1.0, 1.0, 1.0}, parseAnalysis(analysis), 0);
+
+ for (int i = 0; i < results.length; i++) {
+ results[i][1] = Math.sqrt(i);
+ }
+ analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ analysis = analyzer.toString();
+ assertArrayEquals(new double[]{0.9573, 2.5694, 3.2848}, parseAnalysis(analysis), 0);
+
+ for (int i = 0; i < results.length; i++) {
+ results[i][0] = results.length - i;
+ }
+ analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ analysis = analyzer.toString();
+ assertArrayEquals(new double[]{-0.9573, 4.1351, 5.1573}, parseAnalysis(analysis), 0);
+ }
+
+ @Test
+ public void testUnpredictable() {
+ double[][] results = new double[10][2];
+
+ for (int i = 0; i < results.length; i++) {
+ results[i][0] = i;
+ results[i][1] = Double.NaN;
+ }
+ RegressionResultAnalyzer analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ String analysis = analyzer.toString();
+ assertNull(parseAnalysis(analysis));
+ assertArrayEquals(new int[]{0, 10, 10}, parseAnalysisCount(analysis));
+
+ for (int i = 0; i < results.length - 3; i++) {
+ results[i][1] = Math.sqrt(i);
+ }
+ analyzer = new RegressionResultAnalyzer();
+ analyzer.setInstances(results);
+ analysis = analyzer.toString();
+ assertArrayEquals(new double[]{0.9552, 1.4526, 1.9345}, parseAnalysis(analysis), 0);
+ assertArrayEquals(new int[]{7, 3, 10}, parseAnalysisCount(analysis));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
new file mode 100644
index 0000000..f1ec07f
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/DecisionForestTest.java
@@ -0,0 +1,206 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public final class DecisionForestTest extends MahoutTestCase {
+
+ private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
+ "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
+ "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
+ "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
+ "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
+ "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
+ "rainy,71,91,TRUE,no"};
+
+ private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
+ "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
+
+ private Random rng;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+ }
+
+ private static Data[] generateTrainingDataA() throws DescriptorException {
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
+
+ // Training data
+ Data data = DataLoader.loadData(dataset, TRAIN_DATA);
+ @SuppressWarnings("unchecked")
+ List<Instance>[] instances = new List[3];
+ for (int i = 0; i < instances.length; i++) {
+ instances[i] = Lists.newArrayList();
+ }
+ for (int i = 0; i < data.size(); i++) {
+ if (data.get(i).get(0) == 0.0d) {
+ instances[0].add(data.get(i));
+ } else {
+ instances[1].add(data.get(i));
+ }
+ }
+ Data[] datas = new Data[instances.length];
+ for (int i = 0; i < datas.length; i++) {
+ datas[i] = new Data(dataset, instances[i]);
+ }
+
+ return datas;
+ }
+
+ private static Data[] generateTrainingDataB() throws DescriptorException {
+
+ // Training data
+ String[] trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 3 == 0) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ } else if (i % 3 == 1) {
+ trainData[i] = "B," + (i + 20) + ',' + (40 - i);
+ } else {
+ trainData[i] = "C," + (i + 20) + ',' + (i + 20);
+ }
+ }
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
+ Data[] datas = new Data[3];
+ datas[0] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 2 == 0) {
+ trainData[i] = "A," + (50 - i) + ',' + (i + 10);
+ } else {
+ trainData[i] = "B," + (i + 10) + ',' + (50 - i);
+ }
+ }
+ datas[1] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[10];
+ for (int i = 0; i < trainData.length; i++) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ }
+ datas[2] = DataLoader.loadData(dataset, trainData);
+
+ return datas;
+ }
+
+ private DecisionForest buildForest(Data[] datas) {
+ List<Node> trees = Lists.newArrayList();
+ for (Data data : datas) {
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ builder.setM(data.getDataset().nbAttributes() - 1);
+ builder.setMinSplitNum(0);
+ builder.setComplemented(false);
+ trees.add(builder.build(rng, data));
+ }
+ return new DecisionForest(trees);
+ }
+
+ @Test
+ public void testClassify() throws DescriptorException {
+ // Training data
+ Data[] datas = generateTrainingDataA();
+ // Build Forest
+ DecisionForest forest = buildForest(datas);
+ // Test data
+ Dataset dataset = datas[0].getDataset();
+ Data testData = DataLoader.loadData(dataset, TEST_DATA);
+
+ double noValue = dataset.valueOf(4, "no");
+ double yesValue = dataset.valueOf(4, "yes");
+ assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(0)), EPSILON);
+ // This one is tie-broken -- 1 is OK too
+ //assertEquals(yesValue, forest.classify(testData.getDataset(), rng, testData.get(1)), EPSILON);
+ assertEquals(noValue, forest.classify(testData.getDataset(), rng, testData.get(2)), EPSILON);
+ }
+
+ @Test
+ public void testClassifyData() throws DescriptorException {
+ // Training data
+ Data[] datas = generateTrainingDataA();
+ // Build Forest
+ DecisionForest forest = buildForest(datas);
+ // Test data
+ Dataset dataset = datas[0].getDataset();
+ Data testData = DataLoader.loadData(dataset, TEST_DATA);
+
+ double[][] predictions = new double[testData.size()][];
+ forest.classify(testData, predictions);
+ double noValue = dataset.valueOf(4, "no");
+ double yesValue = dataset.valueOf(4, "yes");
+ assertArrayEquals(new double[][]{{noValue, Double.NaN, Double.NaN},
+ {noValue, yesValue, Double.NaN}, {noValue, noValue, Double.NaN}}, predictions);
+ }
+
+ @Test
+ public void testRegression() throws DescriptorException {
+ Data[] datas = generateTrainingDataB();
+ DecisionForest[] forests = new DecisionForest[datas.length];
+ for (int i = 0; i < datas.length; i++) {
+ Data[] subDatas = new Data[datas.length - 1];
+ int k = 0;
+ for (int j = 0; j < datas.length; j++) {
+ if (j != i) {
+ subDatas[k] = datas[j];
+ k++;
+ }
+ }
+ forests[i] = buildForest(subDatas);
+ }
+
+ double[][] predictions = new double[datas[0].size()][];
+ forests[0].classify(datas[0], predictions);
+ assertArrayEquals(new double[]{20.0, 20.0}, predictions[0], EPSILON);
+ assertArrayEquals(new double[]{39.0, 29.0}, predictions[1], EPSILON);
+ assertArrayEquals(new double[]{Double.NaN, 29.0}, predictions[2], EPSILON);
+ assertArrayEquals(new double[]{Double.NaN, 23.0}, predictions[17], EPSILON);
+
+ predictions = new double[datas[1].size()][];
+ forests[1].classify(datas[1], predictions);
+ assertArrayEquals(new double[]{30.0, 29.0}, predictions[19], EPSILON);
+
+ predictions = new double[datas[2].size()][];
+ forests[2].classify(datas[2], predictions);
+ assertArrayEquals(new double[]{29.0, 28.0}, predictions[9], EPSILON);
+
+ assertEquals(20.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(0)), EPSILON);
+ assertEquals(34.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(1)), EPSILON);
+ assertEquals(29.0, forests[0].classify(datas[0].getDataset(), rng, datas[0].get(2)), EPSILON);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java
new file mode 100644
index 0000000..85244c8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilderTest.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import java.lang.reflect.Method;
+import java.util.Random;
+import java.util.Arrays;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+public final class DecisionTreeBuilderTest extends MahoutTestCase {
+
+ /**
+ * make sure that DecisionTreeBuilder.randomAttributes() returns the correct number of attributes, that have not been
+ * selected yet
+ */
+ @Test
+ public void testRandomAttributes() throws Exception {
+ Random rng = RandomUtils.getRandom();
+ int nbAttributes = rng.nextInt(100) + 1;
+ boolean[] selected = new boolean[nbAttributes];
+
+ for (int nloop = 0; nloop < 100; nloop++) {
+ Arrays.fill(selected, false);
+
+ // randomly select some attributes
+ int nbSelected = rng.nextInt(nbAttributes - 1);
+ for (int index = 0; index < nbSelected; index++) {
+ int attr;
+ do {
+ attr = rng.nextInt(nbAttributes);
+ } while (selected[attr]);
+
+ selected[attr] = true;
+ }
+
+ int m = rng.nextInt(nbAttributes);
+
+ Method randomAttributes = DecisionTreeBuilder.class.getDeclaredMethod("randomAttributes",
+ Random.class, boolean[].class, int.class);
+ randomAttributes.setAccessible(true);
+ int[] attrs = (int[]) randomAttributes.invoke(null, rng, selected, m);
+
+ assertNotNull(attrs);
+ assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length);
+
+ for (int attr : attrs) {
+ // the attribute should not be already selected
+ assertFalse("an attribute has already been selected", selected[attr]);
+
+ // each attribute should be in the range [0, nbAttributes[
+ assertTrue(attr >= 0);
+ assertTrue(attr < nbAttributes);
+
+ // each attribute should appear only once
+ assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr));
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java
new file mode 100644
index 0000000..78fe65f
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilderTest.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import java.util.Random;
+import java.util.Arrays;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+public final class DefaultTreeBuilderTest extends MahoutTestCase {
+
+ /**
+ * make sure that DefaultTreeBuilder.randomAttributes() returns the correct number of attributes, that have not been
+ * selected yet
+ */
+ @Test
+ public void testRandomAttributes() throws Exception {
+ Random rng = RandomUtils.getRandom();
+ int nbAttributes = rng.nextInt(100) + 1;
+ boolean[] selected = new boolean[nbAttributes];
+
+ for (int nloop = 0; nloop < 100; nloop++) {
+ Arrays.fill(selected, false);
+
+ // randomly select some attributes
+ int nbSelected = rng.nextInt(nbAttributes - 1);
+ for (int index = 0; index < nbSelected; index++) {
+ int attr;
+ do {
+ attr = rng.nextInt(nbAttributes);
+ } while (selected[attr]);
+
+ selected[attr] = true;
+ }
+
+ int m = rng.nextInt(nbAttributes);
+
+ int[] attrs = DefaultTreeBuilder.randomAttributes(rng, selected, m);
+
+ assertNotNull(attrs);
+ assertEquals(Math.min(m, nbAttributes - nbSelected), attrs.length);
+
+ for (int attr : attrs) {
+ // the attribute should not be already selected
+ assertFalse("an attribute has already been selected", selected[attr]);
+
+ // each attribute should be in the range [0, nbAttributes[
+ assertTrue(attr >= 0);
+ assertTrue(attr < nbAttributes);
+
+ // each attribute should appear only once
+ assertEquals(ArrayUtils.indexOf(attrs, attr), ArrayUtils.lastIndexOf(attrs, attr));
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
new file mode 100644
index 0000000..16e7499
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Utils;
+import org.junit.Test;
+
+import java.util.Random;
+
+public final class InfiniteRecursionTest extends MahoutTestCase {
+
+ private static final double[][] dData = {
+ { 0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 4 },
+ { 0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 3 },
+ { 0.25, 0.0, 0.0, 0.0010504411519893459, 0.005462138323171171, 0.0026130744829756746, 0.0, 0.4964857142857143, 4 },
+ { 0.25, 0.0, 0.0, 5.143998668220409E-4, 0.019847102289905324, 3.5216524641879855E-4, 0.0, 0.6225857142857143, 3 }
+ };
+
+ /**
+ * make sure DecisionTreeBuilder.build() does not throw a StackOverflowException
+ */
+ @Test
+ public void testBuild() throws Exception {
+ Random rng = RandomUtils.getRandom();
+
+ String[] source = Utils.double2String(dData);
+ String descriptor = "N N N N N N N N L";
+
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, source);
+ Data data = DataLoader.loadData(dataset, source);
+ TreeBuilder builder = new DecisionTreeBuilder();
+ builder.build(rng, data);
+
+ // regression
+ dataset = DataLoader.generateDataset(descriptor, true, source);
+ data = DataLoader.loadData(dataset, source);
+ builder = new DecisionTreeBuilder();
+ builder.build(rng, data);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java
new file mode 100644
index 0000000..39858cf
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+public final class DataConverterTest extends MahoutTestCase {
+
+ private static final int ATTRIBUTE_COUNT = 10;
+
+ private static final int INSTANCE_COUNT = 100;
+
+ @Test
+ public void testConvert() throws Exception {
+ Random rng = RandomUtils.getRandom();
+
+ String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, INSTANCE_COUNT);
+ String[] sData = Utils.double2String(source);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+ Data data = DataLoader.loadData(dataset, sData);
+
+ DataConverter converter = new DataConverter(dataset);
+
+ for (int index = 0; index < data.size(); index++) {
+ assertEquals(data.get(index), converter.convert(sData[index]));
+ }
+
+ // regression
+ source = Utils.randomDoubles(rng, descriptor, true, INSTANCE_COUNT);
+ sData = Utils.double2String(source);
+ dataset = DataLoader.generateDataset(descriptor, true, sData);
+ data = DataLoader.loadData(dataset, sData);
+
+ converter = new DataConverter(dataset);
+
+ for (int index = 0; index < data.size(); index++) {
+ assertEquals(data.get(index), converter.convert(sData[index]));
+ }
+ }
+}
[24/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/HadoopUtil.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/HadoopUtil.java b/mr/src/main/java/org/apache/mahout/common/HadoopUtil.java
new file mode 100644
index 0000000..f693821
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/HadoopUtil.java
@@ -0,0 +1,442 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalFileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.OutputFormat;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class HadoopUtil {
+
+ private static final Logger log = LoggerFactory.getLogger(HadoopUtil.class);
+
+ private HadoopUtil() { }
+
+ /**
+ * Create a map-only Hadoop Job out of the passed in parameters. Does not set the
+ * Job name.
+ *
+ * @see #getCustomJobName(String, org.apache.hadoop.mapreduce.JobContext, Class, Class)
+ */
+ public static Job prepareJob(Path inputPath,
+ Path outputPath,
+ Class<? extends InputFormat> inputFormat,
+ Class<? extends Mapper> mapper,
+ Class<? extends Writable> mapperKey,
+ Class<? extends Writable> mapperValue,
+ Class<? extends OutputFormat> outputFormat, Configuration conf) throws IOException {
+
+ Job job = new Job(new Configuration(conf));
+ Configuration jobConf = job.getConfiguration();
+
+ if (mapper.equals(Mapper.class)) {
+ throw new IllegalStateException("Can't figure out the user class jar file from mapper/reducer");
+ }
+ job.setJarByClass(mapper);
+
+ job.setInputFormatClass(inputFormat);
+ jobConf.set("mapred.input.dir", inputPath.toString());
+
+ job.setMapperClass(mapper);
+ job.setMapOutputKeyClass(mapperKey);
+ job.setMapOutputValueClass(mapperValue);
+ job.setOutputKeyClass(mapperKey);
+ job.setOutputValueClass(mapperValue);
+ jobConf.setBoolean("mapred.compress.map.output", true);
+ job.setNumReduceTasks(0);
+
+ job.setOutputFormatClass(outputFormat);
+ jobConf.set("mapred.output.dir", outputPath.toString());
+
+ return job;
+ }
+
+ /**
+ * Create a map and reduce Hadoop job. Does not set the name on the job.
+ * @param inputPath The input {@link org.apache.hadoop.fs.Path}
+ * @param outputPath The output {@link org.apache.hadoop.fs.Path}
+ * @param inputFormat The {@link org.apache.hadoop.mapreduce.InputFormat}
+ * @param mapper The {@link org.apache.hadoop.mapreduce.Mapper} class to use
+ * @param mapperKey The {@link org.apache.hadoop.io.Writable} key class. If the Mapper is a no-op,
+ * this value may be null
+ * @param mapperValue The {@link org.apache.hadoop.io.Writable} value class. If the Mapper is a no-op,
+ * this value may be null
+ * @param reducer The {@link org.apache.hadoop.mapreduce.Reducer} to use
+ * @param reducerKey The reducer key class.
+ * @param reducerValue The reducer value class.
+ * @param outputFormat The {@link org.apache.hadoop.mapreduce.OutputFormat}.
+ * @param conf The {@link org.apache.hadoop.conf.Configuration} to use.
+ * @return The {@link org.apache.hadoop.mapreduce.Job}.
+ * @throws IOException if there is a problem with the IO.
+ *
+ * @see #getCustomJobName(String, org.apache.hadoop.mapreduce.JobContext, Class, Class)
+ * @see #prepareJob(org.apache.hadoop.fs.Path, org.apache.hadoop.fs.Path, Class, Class, Class, Class, Class,
+ * org.apache.hadoop.conf.Configuration)
+ */
+ public static Job prepareJob(Path inputPath,
+ Path outputPath,
+ Class<? extends InputFormat> inputFormat,
+ Class<? extends Mapper> mapper,
+ Class<? extends Writable> mapperKey,
+ Class<? extends Writable> mapperValue,
+ Class<? extends Reducer> reducer,
+ Class<? extends Writable> reducerKey,
+ Class<? extends Writable> reducerValue,
+ Class<? extends OutputFormat> outputFormat,
+ Configuration conf) throws IOException {
+
+ Job job = new Job(new Configuration(conf));
+ Configuration jobConf = job.getConfiguration();
+
+ if (reducer.equals(Reducer.class)) {
+ if (mapper.equals(Mapper.class)) {
+ throw new IllegalStateException("Can't figure out the user class jar file from mapper/reducer");
+ }
+ job.setJarByClass(mapper);
+ } else {
+ job.setJarByClass(reducer);
+ }
+
+ job.setInputFormatClass(inputFormat);
+ jobConf.set("mapred.input.dir", inputPath.toString());
+
+ job.setMapperClass(mapper);
+ if (mapperKey != null) {
+ job.setMapOutputKeyClass(mapperKey);
+ }
+ if (mapperValue != null) {
+ job.setMapOutputValueClass(mapperValue);
+ }
+
+ jobConf.setBoolean("mapred.compress.map.output", true);
+
+ job.setReducerClass(reducer);
+ job.setOutputKeyClass(reducerKey);
+ job.setOutputValueClass(reducerValue);
+
+ job.setOutputFormatClass(outputFormat);
+ jobConf.set("mapred.output.dir", outputPath.toString());
+
+ return job;
+ }
+
+
+ public static String getCustomJobName(String className, JobContext job,
+ Class<? extends Mapper> mapper,
+ Class<? extends Reducer> reducer) {
+ StringBuilder name = new StringBuilder(100);
+ String customJobName = job.getJobName();
+ if (customJobName == null || customJobName.trim().isEmpty()) {
+ name.append(className);
+ } else {
+ name.append(customJobName);
+ }
+ name.append('-').append(mapper.getSimpleName());
+ name.append('-').append(reducer.getSimpleName());
+ return name.toString();
+ }
+
+
+ public static void delete(Configuration conf, Iterable<Path> paths) throws IOException {
+ if (conf == null) {
+ conf = new Configuration();
+ }
+ for (Path path : paths) {
+ FileSystem fs = path.getFileSystem(conf);
+ if (fs.exists(path)) {
+ log.info("Deleting {}", path);
+ fs.delete(path, true);
+ }
+ }
+ }
+
+ public static void delete(Configuration conf, Path... paths) throws IOException {
+ delete(conf, Arrays.asList(paths));
+ }
+
+ public static long countRecords(Path path, Configuration conf) throws IOException {
+ long count = 0;
+ Iterator<?> iterator = new SequenceFileValueIterator<Writable>(path, true, conf);
+ while (iterator.hasNext()) {
+ iterator.next();
+ count++;
+ }
+ return count;
+ }
+
+ /**
+ * Count all the records in a directory using a
+ * {@link org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator}
+ *
+ * @param path The {@link org.apache.hadoop.fs.Path} to count
+ * @param pt The {@link org.apache.mahout.common.iterator.sequencefile.PathType}
+ * @param filter Apply the {@link org.apache.hadoop.fs.PathFilter}. May be null
+ * @param conf The Hadoop {@link org.apache.hadoop.conf.Configuration}
+ * @return The number of records
+ * @throws IOException if there was an IO error
+ */
+ public static long countRecords(Path path, PathType pt, PathFilter filter, Configuration conf) throws IOException {
+ long count = 0;
+ Iterator<?> iterator = new SequenceFileDirValueIterator<Writable>(path, pt, filter, null, true, conf);
+ while (iterator.hasNext()) {
+ iterator.next();
+ count++;
+ }
+ return count;
+ }
+
+ public static InputStream openStream(Path path, Configuration conf) throws IOException {
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ return fs.open(path.makeQualified(path.toUri(), path));
+ }
+
+ public static FileStatus[] getFileStatus(Path path, PathType pathType, PathFilter filter,
+ Comparator<FileStatus> ordering, Configuration conf) throws IOException {
+ FileStatus[] statuses;
+ FileSystem fs = path.getFileSystem(conf);
+ if (filter == null) {
+ statuses = pathType == PathType.GLOB ? fs.globStatus(path) : listStatus(fs, path);
+ } else {
+ statuses = pathType == PathType.GLOB ? fs.globStatus(path, filter) : listStatus(fs, path, filter);
+ }
+ if (ordering != null) {
+ Arrays.sort(statuses, ordering);
+ }
+ return statuses;
+ }
+
+ public static FileStatus[] listStatus(FileSystem fs, Path path) throws IOException {
+ try {
+ return fs.listStatus(path);
+ } catch (FileNotFoundException e) {
+ return new FileStatus[0];
+ }
+ }
+
+ public static FileStatus[] listStatus(FileSystem fs, Path path, PathFilter filter) throws IOException {
+ try {
+ return fs.listStatus(path, filter);
+ } catch (FileNotFoundException e) {
+ return new FileStatus[0];
+ }
+ }
+
+ public static void cacheFiles(Path fileToCache, Configuration conf) {
+ DistributedCache.setCacheFiles(new URI[]{fileToCache.toUri()}, conf);
+ }
+
+ /**
+ * Return the first cached file in the list, else null if thre are no cached files.
+ * @param conf - MapReduce Configuration
+ * @return Path of Cached file
+ * @throws IOException - IO Exception
+ */
+ public static Path getSingleCachedFile(Configuration conf) throws IOException {
+ return getCachedFiles(conf)[0];
+ }
+
+ /**
+ * Retrieves paths to cached files.
+ * @param conf - MapReduce Configuration
+ * @return Path[] of Cached Files
+ * @throws IOException - IO Exception
+ * @throws IllegalStateException if no cache files are found
+ */
+ public static Path[] getCachedFiles(Configuration conf) throws IOException {
+ LocalFileSystem localFs = FileSystem.getLocal(conf);
+ Path[] cacheFiles = DistributedCache.getLocalCacheFiles(conf);
+
+ URI[] fallbackFiles = DistributedCache.getCacheFiles(conf);
+
+ // fallback for local execution
+ if (cacheFiles == null) {
+
+ Preconditions.checkState(fallbackFiles != null, "Unable to find cached files!");
+
+ cacheFiles = new Path[fallbackFiles.length];
+ for (int n = 0; n < fallbackFiles.length; n++) {
+ cacheFiles[n] = new Path(fallbackFiles[n].getPath());
+ }
+ } else {
+
+ for (int n = 0; n < cacheFiles.length; n++) {
+ cacheFiles[n] = localFs.makeQualified(cacheFiles[n]);
+ // fallback for local execution
+ if (!localFs.exists(cacheFiles[n])) {
+ cacheFiles[n] = new Path(fallbackFiles[n].getPath());
+ }
+ }
+ }
+
+ Preconditions.checkState(cacheFiles.length > 0, "Unable to find cached files!");
+
+ return cacheFiles;
+ }
+
+ public static void setSerializations(Configuration configuration) {
+ configuration.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ }
+
+ public static void writeInt(int value, Path path, Configuration configuration) throws IOException {
+ FileSystem fs = FileSystem.get(path.toUri(), configuration);
+ FSDataOutputStream out = fs.create(path);
+ try {
+ out.writeInt(value);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ public static int readInt(Path path, Configuration configuration) throws IOException {
+ FileSystem fs = FileSystem.get(path.toUri(), configuration);
+ FSDataInputStream in = fs.open(path);
+ try {
+ return in.readInt();
+ } finally {
+ Closeables.close(in, true);
+ }
+ }
+
+ /**
+ * Builds a comma-separated list of input splits
+ * @param fs - File System
+ * @param fileStatus - File Status
+ * @return list of directories as a comma-separated String
+ * @throws IOException - IO Exception
+ */
+ public static String buildDirList(FileSystem fs, FileStatus fileStatus) throws IOException {
+ boolean containsFiles = false;
+ List<String> directoriesList = Lists.newArrayList();
+ for (FileStatus childFileStatus : fs.listStatus(fileStatus.getPath())) {
+ if (childFileStatus.isDir()) {
+ String subDirectoryList = buildDirList(fs, childFileStatus);
+ directoriesList.add(subDirectoryList);
+ } else {
+ containsFiles = true;
+ }
+ }
+
+ if (containsFiles) {
+ directoriesList.add(fileStatus.getPath().toUri().getPath());
+ }
+ return Joiner.on(',').skipNulls().join(directoriesList.iterator());
+ }
+
+ /**
+ * Builds a comma-separated list of input splits
+ * @param fs - File System
+ * @param fileStatus - File Status
+ * @param pathFilter - path filter
+ * @return list of directories as a comma-separated String
+ * @throws IOException - IO Exception
+ */
+ public static String buildDirList(FileSystem fs, FileStatus fileStatus, PathFilter pathFilter) throws IOException {
+ boolean containsFiles = false;
+ List<String> directoriesList = Lists.newArrayList();
+ for (FileStatus childFileStatus : fs.listStatus(fileStatus.getPath(), pathFilter)) {
+ if (childFileStatus.isDir()) {
+ String subDirectoryList = buildDirList(fs, childFileStatus);
+ directoriesList.add(subDirectoryList);
+ } else {
+ containsFiles = true;
+ }
+ }
+
+ if (containsFiles) {
+ directoriesList.add(fileStatus.getPath().toUri().getPath());
+ }
+ return Joiner.on(',').skipNulls().join(directoriesList.iterator());
+ }
+
+ /**
+ *
+ * @param configuration - configuration
+ * @param filePath - Input File Path
+ * @return relative file Path
+ * @throws IOException - IO Exception
+ */
+ public static String calcRelativeFilePath(Configuration configuration, Path filePath) throws IOException {
+ FileSystem fs = filePath.getFileSystem(configuration);
+ FileStatus fst = fs.getFileStatus(filePath);
+ String currentPath = fst.getPath().toString().replaceFirst("file:", "");
+
+ String basePath = configuration.get("baseinputpath");
+ if (!basePath.endsWith("/")) {
+ basePath += "/";
+ }
+ basePath = basePath.replaceFirst("file:", "");
+ String[] parts = currentPath.split(basePath);
+
+ if (parts.length == 2) {
+ return parts[1];
+ } else if (parts.length == 1) {
+ return parts[0];
+ }
+ return currentPath;
+ }
+
+ /**
+ * Finds a file in the DistributedCache
+ *
+ * @param partOfFilename a substring of the file name
+ * @param localFiles holds references to files stored in distributed cache
+ * @return Path to first matched file or null if nothing was found
+ **/
+ public static Path findInCacheByPartOfFilename(String partOfFilename, URI[] localFiles) {
+ for (URI distCacheFile : localFiles) {
+ log.info("trying find a file in distributed cache containing [{}] in its name", partOfFilename);
+ if (distCacheFile != null && distCacheFile.toString().contains(partOfFilename)) {
+ log.info("found file [{}] containing [{}]", distCacheFile.toString(), partOfFilename);
+ return new Path(distCacheFile.getPath());
+ }
+ }
+ return null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/IntPairWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/IntPairWritable.java b/mr/src/main/java/org/apache/mahout/common/IntPairWritable.java
new file mode 100644
index 0000000..dacd66f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/IntPairWritable.java
@@ -0,0 +1,270 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.apache.hadoop.io.BinaryComparable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.Arrays;
+
+/**
+ * A {@link WritableComparable} which encapsulates an ordered pair of signed integers.
+ */
+public final class IntPairWritable extends BinaryComparable
+ implements WritableComparable<BinaryComparable>, Cloneable {
+
+ static final int INT_BYTE_LENGTH = 4;
+ static final int INT_PAIR_BYTE_LENGTH = 2 * INT_BYTE_LENGTH;
+ private byte[] b = new byte[INT_PAIR_BYTE_LENGTH];
+
+ public IntPairWritable() {
+ setFirst(0);
+ setSecond(0);
+ }
+
+ public IntPairWritable(IntPairWritable pair) {
+ b = Arrays.copyOf(pair.getBytes(), INT_PAIR_BYTE_LENGTH);
+ }
+
+ public IntPairWritable(int x, int y) {
+ putInt(x, b, 0);
+ putInt(y, b, INT_BYTE_LENGTH);
+ }
+
+ public void set(int x, int y) {
+ putInt(x, b, 0);
+ putInt(y, b, INT_BYTE_LENGTH);
+ }
+
+ public void setFirst(int x) {
+ putInt(x, b, 0);
+ }
+
+ public int getFirst() {
+ return getInt(b, 0);
+ }
+
+ public void setSecond(int y) {
+ putInt(y, b, INT_BYTE_LENGTH);
+ }
+
+ public int getSecond() {
+ return getInt(b, INT_BYTE_LENGTH);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ in.readFully(b);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.write(b);
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(b);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!super.equals(obj)) {
+ return false;
+ }
+ if (!(obj instanceof IntPairWritable)) {
+ return false;
+ }
+ IntPairWritable other = (IntPairWritable) obj;
+ return Arrays.equals(b, other.b);
+ }
+
+ @Override
+ public int compareTo(BinaryComparable other) {
+ return Comparator.doCompare(b, 0, ((IntPairWritable) other).b, 0);
+ }
+
+ @Override
+ public Object clone() {
+ return new IntPairWritable(this);
+ }
+
+ @Override
+ public String toString() {
+ return "(" + getFirst() + ", " + getSecond() + ')';
+ }
+
+ @Override
+ public byte[] getBytes() {
+ return b;
+ }
+
+ @Override
+ public int getLength() {
+ return INT_PAIR_BYTE_LENGTH;
+ }
+
+ private static void putInt(int value, byte[] b, int offset) {
+ for (int i = offset, j = 24; j >= 0; i++, j -= 8) {
+ b[i] = (byte) (value >> j);
+ }
+ }
+
+ private static int getInt(byte[] b, int offset) {
+ int value = 0;
+ for (int i = offset, j = 24; j >= 0; i++, j -= 8) {
+ value |= (b[i] & 0xFF) << j;
+ }
+ return value;
+ }
+
+ static {
+ WritableComparator.define(IntPairWritable.class, new Comparator());
+ }
+
+ public static final class Comparator extends WritableComparator implements Serializable {
+ public Comparator() {
+ super(IntPairWritable.class);
+ }
+
+ @Override
+ public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
+ return doCompare(b1, s1, b2, s2);
+ }
+
+ static int doCompare(byte[] b1, int s1, byte[] b2, int s2) {
+ int compare1 = compareInts(b1, s1, b2, s2);
+ if (compare1 != 0) {
+ return compare1;
+ }
+ return compareInts(b1, s1 + INT_BYTE_LENGTH, b2, s2 + INT_BYTE_LENGTH);
+ }
+
+ private static int compareInts(byte[] b1, int s1, byte[] b2, int s2) {
+ // Like WritableComparator.compareBytes(), but treats first byte as signed value
+ int end1 = s1 + INT_BYTE_LENGTH;
+ for (int i = s1, j = s2; i < end1; i++, j++) {
+ int a = b1[i];
+ int b = b2[j];
+ if (i > s1) {
+ a &= 0xff;
+ b &= 0xff;
+ }
+ if (a != b) {
+ return a - b;
+ }
+ }
+ return 0;
+ }
+ }
+
+ /**
+ * Compare only the first part of the pair, so that reduce is called once for each value of the first part.
+ */
+ public static class FirstGroupingComparator extends WritableComparator implements Serializable {
+
+ public FirstGroupingComparator() {
+ super(IntPairWritable.class);
+ }
+
+ @Override
+ public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
+ int firstb1 = WritableComparator.readInt(b1, s1);
+ int firstb2 = WritableComparator.readInt(b2, s2);
+ if (firstb1 < firstb2) {
+ return -1;
+ } else if (firstb1 > firstb2) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ @Override
+ public int compare(Object o1, Object o2) {
+ int firstb1 = ((IntPairWritable) o1).getFirst();
+ int firstb2 = ((IntPairWritable) o2).getFirst();
+ if (firstb1 < firstb2) {
+ return -1;
+ }
+ if (firstb1 > firstb2) {
+ return 1;
+ }
+ return 0;
+ }
+
+ }
+
+ /** A wrapper class that associates pairs with frequency (Occurrences) */
+ public static class Frequency implements Comparable<Frequency>, Serializable {
+
+ private final IntPairWritable pair;
+ private final double frequency;
+
+ public Frequency(IntPairWritable bigram, double frequency) {
+ this.pair = new IntPairWritable(bigram);
+ this.frequency = frequency;
+ }
+
+ public double getFrequency() {
+ return frequency;
+ }
+
+ public IntPairWritable getPair() {
+ return pair;
+ }
+
+ @Override
+ public int hashCode() {
+ return pair.hashCode() + RandomUtils.hashDouble(frequency);
+ }
+
+ @Override
+ public boolean equals(Object right) {
+ if (!(right instanceof Frequency)) {
+ return false;
+ }
+ Frequency that = (Frequency) right;
+ return pair.equals(that.pair) && frequency == that.frequency;
+ }
+
+ @Override
+ public int compareTo(Frequency that) {
+ if (frequency < that.frequency) {
+ return -1;
+ }
+ if (frequency > that.frequency) {
+ return 1;
+ }
+ return 0;
+ }
+
+ @Override
+ public String toString() {
+ return pair + "\t" + frequency;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/IntegerTuple.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/IntegerTuple.java b/mr/src/main/java/org/apache/mahout/common/IntegerTuple.java
new file mode 100644
index 0000000..f456d4d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/IntegerTuple.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.WritableComparable;
+
+/**
+ * An Ordered List of Integers which can be used in a Hadoop Map/Reduce Job
+ *
+ *
+ */
+public final class IntegerTuple implements WritableComparable<IntegerTuple> {
+
+ private List<Integer> tuple = Lists.newArrayList();
+
+ public IntegerTuple() { }
+
+ public IntegerTuple(Integer firstEntry) {
+ add(firstEntry);
+ }
+
+ public IntegerTuple(Iterable<Integer> entries) {
+ for (Integer entry : entries) {
+ add(entry);
+ }
+ }
+
+ public IntegerTuple(Integer[] entries) {
+ for (Integer entry : entries) {
+ add(entry);
+ }
+ }
+
+ /**
+ * add an entry to the end of the list
+ *
+ * @param entry
+ * @return true if the items get added
+ */
+ public boolean add(Integer entry) {
+ return tuple.add(entry);
+ }
+
+ /**
+ * Fetches the string at the given location
+ *
+ * @param index
+ * @return String value at the given location in the tuple list
+ */
+ public Integer integerAt(int index) {
+ return tuple.get(index);
+ }
+
+ /**
+ * Replaces the string at the given index with the given newString
+ *
+ * @param index
+ * @param newInteger
+ * @return The previous value at that location
+ */
+ public Integer replaceAt(int index, Integer newInteger) {
+ return tuple.set(index, newInteger);
+ }
+
+ /**
+ * Fetch the list of entries from the tuple
+ *
+ * @return a List containing the strings in the order of insertion
+ */
+ public List<Integer> getEntries() {
+ return Collections.unmodifiableList(this.tuple);
+ }
+
+ /**
+ * Returns the length of the tuple
+ *
+ * @return length
+ */
+ public int length() {
+ return this.tuple.size();
+ }
+
+ @Override
+ public String toString() {
+ return tuple.toString();
+ }
+
+ @Override
+ public int hashCode() {
+ return tuple.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ IntegerTuple other = (IntegerTuple) obj;
+ if (tuple == null) {
+ if (other.tuple != null) {
+ return false;
+ }
+ } else if (!tuple.equals(other.tuple)) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int len = in.readInt();
+ tuple = Lists.newArrayListWithCapacity(len);
+ for (int i = 0; i < len; i++) {
+ int data = in.readInt();
+ tuple.add(data);
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(tuple.size());
+ for (Integer entry : tuple) {
+ out.writeInt(entry);
+ }
+ }
+
+ @Override
+ public int compareTo(IntegerTuple otherTuple) {
+ int thisLength = length();
+ int otherLength = otherTuple.length();
+ int min = Math.min(thisLength, otherLength);
+ for (int i = 0; i < min; i++) {
+ int ret = this.tuple.get(i).compareTo(otherTuple.integerAt(i));
+ if (ret == 0) {
+ continue;
+ }
+ return ret;
+ }
+ if (thisLength < otherLength) {
+ return -1;
+ } else if (thisLength > otherLength) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/LongPair.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/LongPair.java b/mr/src/main/java/org/apache/mahout/common/LongPair.java
new file mode 100644
index 0000000..5215e3a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/LongPair.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.Serializable;
+
+import com.google.common.primitives.Longs;
+
+/** A simple (ordered) pair of longs. */
+public final class LongPair implements Comparable<LongPair>, Serializable {
+
+ private final long first;
+ private final long second;
+
+ public LongPair(long first, long second) {
+ this.first = first;
+ this.second = second;
+ }
+
+ public long getFirst() {
+ return first;
+ }
+
+ public long getSecond() {
+ return second;
+ }
+
+ public LongPair swap() {
+ return new LongPair(second, first);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof LongPair)) {
+ return false;
+ }
+ LongPair otherPair = (LongPair) obj;
+ return first == otherPair.getFirst() && second == otherPair.getSecond();
+ }
+
+ @Override
+ public int hashCode() {
+ int firstHash = Longs.hashCode(first);
+ // Flip top and bottom 16 bits; this makes the hash function probably different
+ // for (a,b) versus (b,a)
+ return (firstHash >>> 16 | firstHash << 16) ^ Longs.hashCode(second);
+ }
+
+ @Override
+ public String toString() {
+ return '(' + String.valueOf(first) + ',' + second + ')';
+ }
+
+ @Override
+ public int compareTo(LongPair o) {
+ if (first < o.getFirst()) {
+ return -1;
+ } else if (first > o.getFirst()) {
+ return 1;
+ } else {
+ return second < o.getSecond() ? -1 : second > o.getSecond() ? 1 : 0;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/MemoryUtil.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/MemoryUtil.java b/mr/src/main/java/org/apache/mahout/common/MemoryUtil.java
new file mode 100644
index 0000000..f241b53
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/MemoryUtil.java
@@ -0,0 +1,99 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.common;
+
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Memory utilities.
+ */
+public final class MemoryUtil {
+
+ private static final Logger log = LoggerFactory.getLogger(MemoryUtil.class);
+
+ private MemoryUtil() {
+ }
+
+ /**
+ * Logs current heap memory statistics.
+ *
+ * @see Runtime
+ */
+ public static void logMemoryStatistics() {
+ Runtime runtime = Runtime.getRuntime();
+ long freeBytes = runtime.freeMemory();
+ long maxBytes = runtime.maxMemory();
+ long totalBytes = runtime.totalMemory();
+ long usedBytes = totalBytes - freeBytes;
+ log.info("Memory (bytes): {} used, {} heap, {} max", usedBytes, totalBytes,
+ maxBytes);
+ }
+
+ private static volatile ScheduledExecutorService scheduler;
+
+ /**
+ * Constructs and starts a memory logger thread.
+ *
+ * @param rateInMillis how often memory info should be logged.
+ */
+ public static void startMemoryLogger(long rateInMillis) {
+ stopMemoryLogger();
+ scheduler = Executors.newScheduledThreadPool(1, new ThreadFactory() {
+ private final ThreadFactory delegate = Executors.defaultThreadFactory();
+
+ @Override
+ public Thread newThread(Runnable r) {
+ Thread t = delegate.newThread(r);
+ t.setDaemon(true);
+ return t;
+ }
+ });
+ Runnable memoryLoogerRunnable = new Runnable() {
+ @Override
+ public void run() {
+ logMemoryStatistics();
+ }
+ };
+ scheduler.scheduleAtFixedRate(memoryLoogerRunnable, rateInMillis, rateInMillis,
+ TimeUnit.MILLISECONDS);
+ }
+
+ /**
+ * Constructs and starts a memory logger thread with a logging rate of 1000 milliseconds.
+ */
+ public static void startMemoryLogger() {
+ startMemoryLogger(1000);
+ }
+
+ /**
+ * Stops the memory logger, if any, started via {@link #startMemoryLogger(long)} or
+ * {@link #startMemoryLogger()}.
+ */
+ public static void stopMemoryLogger() {
+ if (scheduler != null) {
+ scheduler.shutdownNow();
+ scheduler = null;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/Pair.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/Pair.java b/mr/src/main/java/org/apache/mahout/common/Pair.java
new file mode 100644
index 0000000..d2ad6a1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/Pair.java
@@ -0,0 +1,99 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.Serializable;
+
+/** A simple (ordered) pair of two objects. Elements may be null. */
+public final class Pair<A,B> implements Comparable<Pair<A,B>>, Serializable {
+
+ private final A first;
+ private final B second;
+
+ public Pair(A first, B second) {
+ this.first = first;
+ this.second = second;
+ }
+
+ public A getFirst() {
+ return first;
+ }
+
+ public B getSecond() {
+ return second;
+ }
+
+ public Pair<B, A> swap() {
+ return new Pair<>(second, first);
+ }
+
+ public static <A,B> Pair<A,B> of(A a, B b) {
+ return new Pair<>(a, b);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof Pair<?, ?>)) {
+ return false;
+ }
+ Pair<?, ?> otherPair = (Pair<?, ?>) obj;
+ return isEqualOrNulls(first, otherPair.getFirst())
+ && isEqualOrNulls(second, otherPair.getSecond());
+ }
+
+ private static boolean isEqualOrNulls(Object obj1, Object obj2) {
+ return obj1 == null ? obj2 == null : obj1.equals(obj2);
+ }
+
+ @Override
+ public int hashCode() {
+ int firstHash = hashCodeNull(first);
+ // Flip top and bottom 16 bits; this makes the hash function probably different
+ // for (a,b) versus (b,a)
+ return (firstHash >>> 16 | firstHash << 16) ^ hashCodeNull(second);
+ }
+
+ private static int hashCodeNull(Object obj) {
+ return obj == null ? 0 : obj.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return '(' + String.valueOf(first) + ',' + second + ')';
+ }
+
+ /**
+ * Defines an ordering on pairs that sorts by first value's natural ordering, ascending,
+ * and then by second value's natural ordering.
+ *
+ * @throws ClassCastException if types are not actually {@link Comparable}
+ */
+ @Override
+ public int compareTo(Pair<A,B> other) {
+ Comparable<A> thisFirst = (Comparable<A>) first;
+ A thatFirst = other.getFirst();
+ int compare = thisFirst.compareTo(thatFirst);
+ if (compare != 0) {
+ return compare;
+ }
+ Comparable<B> thisSecond = (Comparable<B>) second;
+ B thatSecond = other.getSecond();
+ return thisSecond.compareTo(thatSecond);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/Parameters.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/Parameters.java b/mr/src/main/java/org/apache/mahout/common/Parameters.java
new file mode 100644
index 0000000..e74c534
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/Parameters.java
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DefaultStringifier;
+import org.apache.hadoop.util.GenericsUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class Parameters {
+
+ private static final Logger log = LoggerFactory.getLogger(Parameters.class);
+
+ private Map<String,String> params = Maps.newHashMap();
+
+ public Parameters() {
+
+ }
+
+ public Parameters(String serializedString) throws IOException {
+ this(parseParams(serializedString));
+ }
+
+ protected Parameters(Map<String,String> params) {
+ this.params = params;
+ }
+
+ public String get(String key) {
+ return params.get(key);
+ }
+
+ public String get(String key, String defaultValue) {
+ String ret = params.get(key);
+ return ret == null ? defaultValue : ret;
+ }
+
+ public void set(String key, String value) {
+ params.put(key, value);
+ }
+
+ public int getInt(String key, int defaultValue) {
+ String ret = params.get(key);
+ return ret == null ? defaultValue : Integer.parseInt(ret);
+ }
+
+ @Override
+ public String toString() {
+ Configuration conf = new Configuration();
+ conf.set("io.serializations",
+ "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ DefaultStringifier<Map<String,String>> mapStringifier = new DefaultStringifier<>(conf,
+ GenericsUtil.getClass(params));
+ try {
+ return mapStringifier.toString(params);
+ } catch (IOException e) {
+ log.info("Encountered IOException while deserializing returning empty string", e);
+ return "";
+ }
+
+ }
+
+ public String print() {
+ return params.toString();
+ }
+
+ public static Map<String,String> parseParams(String serializedString) throws IOException {
+ Configuration conf = new Configuration();
+ conf.set("io.serializations",
+ "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ Map<String,String> params = Maps.newHashMap();
+ DefaultStringifier<Map<String,String>> mapStringifier = new DefaultStringifier<>(conf,
+ GenericsUtil.getClass(params));
+ return mapStringifier.fromString(serializedString);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/StringTuple.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/StringTuple.java b/mr/src/main/java/org/apache/mahout/common/StringTuple.java
new file mode 100644
index 0000000..0de1a4a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/StringTuple.java
@@ -0,0 +1,177 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+
+/**
+ * An Ordered List of Strings which can be used in a Hadoop Map/Reduce Job
+ */
+public final class StringTuple implements WritableComparable<StringTuple> {
+
+ private List<String> tuple = Lists.newArrayList();
+
+ public StringTuple() { }
+
+ public StringTuple(String firstEntry) {
+ add(firstEntry);
+ }
+
+ public StringTuple(Iterable<String> entries) {
+ for (String entry : entries) {
+ add(entry);
+ }
+ }
+
+ public StringTuple(String[] entries) {
+ for (String entry : entries) {
+ add(entry);
+ }
+ }
+
+ /**
+ * add an entry to the end of the list
+ *
+ * @param entry
+ * @return true if the items get added
+ */
+ public boolean add(String entry) {
+ return tuple.add(entry);
+ }
+
+ /**
+ * Fetches the string at the given location
+ *
+ * @param index
+ * @return String value at the given location in the tuple list
+ */
+ public String stringAt(int index) {
+ return tuple.get(index);
+ }
+
+ /**
+ * Replaces the string at the given index with the given newString
+ *
+ * @param index
+ * @param newString
+ * @return The previous value at that location
+ */
+ public String replaceAt(int index, String newString) {
+ return tuple.set(index, newString);
+ }
+
+ /**
+ * Fetch the list of entries from the tuple
+ *
+ * @return a List containing the strings in the order of insertion
+ */
+ public List<String> getEntries() {
+ return Collections.unmodifiableList(this.tuple);
+ }
+
+ /**
+ * Returns the length of the tuple
+ *
+ * @return length
+ */
+ public int length() {
+ return this.tuple.size();
+ }
+
+ @Override
+ public String toString() {
+ return tuple.toString();
+ }
+
+ @Override
+ public int hashCode() {
+ return tuple.hashCode();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ StringTuple other = (StringTuple) obj;
+ if (tuple == null) {
+ if (other.tuple != null) {
+ return false;
+ }
+ } else if (!tuple.equals(other.tuple)) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int len = in.readInt();
+ tuple = Lists.newArrayListWithCapacity(len);
+ Text value = new Text();
+ for (int i = 0; i < len; i++) {
+ value.readFields(in);
+ tuple.add(value.toString());
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(tuple.size());
+ Text value = new Text();
+ for (String entry : tuple) {
+ value.set(entry);
+ value.write(out);
+ }
+ }
+
+ @Override
+ public int compareTo(StringTuple otherTuple) {
+ int thisLength = length();
+ int otherLength = otherTuple.length();
+ int min = Math.min(thisLength, otherLength);
+ for (int i = 0; i < min; i++) {
+ int ret = this.tuple.get(i).compareTo(otherTuple.stringAt(i));
+ if (ret != 0) {
+ return ret;
+ }
+ }
+ if (thisLength < otherLength) {
+ return -1;
+ } else if (thisLength > otherLength) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/StringUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/StringUtils.java b/mr/src/main/java/org/apache/mahout/common/StringUtils.java
new file mode 100644
index 0000000..a064596
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/StringUtils.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.util.regex.Pattern;
+
+import com.thoughtworks.xstream.XStream;
+
+/**
+ * Offers two methods to convert an object to a string representation and restore the object given its string
+ * representation. Should use Hadoop Stringifier whenever available.
+ */
+public final class StringUtils {
+
+ private static final XStream XSTREAM = new XStream();
+ private static final Pattern NEWLINE_PATTERN = Pattern.compile("\n");
+ private static final Pattern XMLRESERVED = Pattern.compile("\"|\\&|\\<|\\>|\'");
+
+ private StringUtils() {
+ // do nothing
+ }
+
+ /**
+ * Converts the object to a one-line string representation
+ *
+ * @param obj
+ * the object to convert
+ * @return the string representation of the object
+ */
+ public static String toString(Object obj) {
+ return NEWLINE_PATTERN.matcher(XSTREAM.toXML(obj)).replaceAll("");
+ }
+
+ /**
+ * Restores the object from its string representation.
+ *
+ * @param str
+ * the string representation of the object
+ * @return restored object
+ */
+ public static <T> T fromString(String str) {
+ return (T) XSTREAM.fromXML(str);
+ }
+
+ public static String escapeXML(CharSequence input) {
+ return XMLRESERVED.matcher(input).replaceAll("_");
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/TimingStatistics.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/TimingStatistics.java b/mr/src/main/java/org/apache/mahout/common/TimingStatistics.java
new file mode 100644
index 0000000..5ee2066
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/TimingStatistics.java
@@ -0,0 +1,154 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.Serializable;
+import java.text.DecimalFormat;
+
+public final class TimingStatistics implements Serializable {
+ private static final DecimalFormat DF = new DecimalFormat("#.##");
+ private int nCalls;
+ private long minTime;
+ private long maxTime;
+ private long sumTime;
+ private long leadSumTime;
+ private double sumSquaredTime;
+
+
+ /** Creates a new instance of CallStats */
+ public TimingStatistics() { }
+
+ public TimingStatistics(int nCalls, long minTime, long maxTime, long sumTime, double sumSquaredTime) {
+ this.nCalls = nCalls;
+ this.minTime = minTime;
+ this.maxTime = maxTime;
+ this.sumTime = sumTime;
+ this.sumSquaredTime = sumSquaredTime;
+ }
+
+ public synchronized int getNCalls() {
+ return nCalls;
+ }
+
+ public synchronized long getMinTime() {
+ return Math.max(0, minTime);
+ }
+
+ public synchronized long getMaxTime() {
+ return maxTime;
+ }
+
+ public synchronized long getSumTime() {
+ return sumTime;
+ }
+
+ public synchronized double getSumSquaredTime() {
+ return sumSquaredTime;
+ }
+
+ public synchronized long getMeanTime() {
+ return nCalls == 0 ? 0 : sumTime / nCalls;
+ }
+
+ public synchronized long getStdDevTime() {
+ if (nCalls == 0) {
+ return 0;
+ }
+ double mean = getMeanTime();
+ double meanSquared = mean * mean;
+ double meanOfSquares = sumSquaredTime / nCalls;
+ double variance = meanOfSquares - meanSquared;
+ if (variance < 0) {
+ return 0; // might happen due to rounding error
+ }
+ return (long) Math.sqrt(variance);
+ }
+
+ @Override
+ public synchronized String toString() {
+ return '\n'
+ + "nCalls = " + nCalls + ";\n"
+ + "sum = " + DF.format(sumTime / 1000000000.0) + "s;\n"
+ + "min = " + DF.format(minTime / 1000000.0) + "ms;\n"
+ + "max = " + DF.format(maxTime / 1000000.0) + "ms;\n"
+ + "mean = " + DF.format(getMeanTime() / 1000.0) + "us;\n"
+ + "stdDev = " + DF.format(getStdDevTime() / 1000.0) + "us;";
+ }
+
+ /** Ignores counting the performance metrics until leadTimeIsFinished The caller should enough time for the JIT
+ * to warm up. */
+ public Call newCall(long leadTimeUsec) {
+ if (leadSumTime > leadTimeUsec) {
+ return new Call();
+ } else {
+ return new LeadTimeCall();
+ }
+ }
+
+ /** Ignores counting the performance metrics. The caller should enough time for the JIT to warm up. */
+ public final class LeadTimeCall extends Call {
+
+ private LeadTimeCall() { }
+
+ @Override
+ public void end() {
+ long elapsed = System.nanoTime() - startTime;
+ synchronized (TimingStatistics.this) {
+ leadSumTime += elapsed;
+ }
+ }
+
+ @Override
+ public boolean end(long sumMaxUsec) {
+ end();
+ return false;
+ }
+ }
+
+ /**
+ * A call object that can update performance metrics.
+ */
+ public class Call {
+ protected final long startTime = System.nanoTime();
+
+ private Call() { }
+
+ public void end() {
+ long elapsed = System.nanoTime() - startTime;
+ synchronized (TimingStatistics.this) {
+ nCalls++;
+ if (elapsed < minTime || nCalls == 1) {
+ minTime = elapsed;
+ }
+ if (elapsed > maxTime) {
+ maxTime = elapsed;
+ }
+ sumTime += elapsed;
+ sumSquaredTime += elapsed * elapsed;
+ }
+ }
+
+ /**
+ * Returns true if the sumTime as reached this limit;
+ */
+ public boolean end(long sumMaxUsec) {
+ end();
+ return sumMaxUsec < sumTime;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java b/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
new file mode 100644
index 0000000..0e7ee96
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/commandline/DefaultOptionCreator.java
@@ -0,0 +1,417 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.commandline;
+
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.clustering.kernel.TriangularKernelProfile;
+
+
+public final class DefaultOptionCreator {
+
+ public static final String CLUSTERING_OPTION = "clustering";
+
+ public static final String CLUSTERS_IN_OPTION = "clusters";
+
+ public static final String CONVERGENCE_DELTA_OPTION = "convergenceDelta";
+
+ public static final String DISTANCE_MEASURE_OPTION = "distanceMeasure";
+
+ public static final String EMIT_MOST_LIKELY_OPTION = "emitMostLikely";
+
+ public static final String INPUT_OPTION = "input";
+
+ public static final String MAX_ITERATIONS_OPTION = "maxIter";
+
+ public static final String MAX_REDUCERS_OPTION = "maxRed";
+
+ public static final String METHOD_OPTION = "method";
+
+ public static final String NUM_CLUSTERS_OPTION = "numClusters";
+
+ public static final String OUTPUT_OPTION = "output";
+
+ public static final String OVERWRITE_OPTION = "overwrite";
+
+ public static final String T1_OPTION = "t1";
+
+ public static final String T2_OPTION = "t2";
+
+ public static final String T3_OPTION = "t3";
+
+ public static final String T4_OPTION = "t4";
+
+ public static final String OUTLIER_THRESHOLD = "outlierThreshold";
+
+ public static final String CLUSTER_FILTER_OPTION = "clusterFilter";
+
+ public static final String THRESHOLD_OPTION = "threshold";
+
+ public static final String SEQUENTIAL_METHOD = "sequential";
+
+ public static final String MAPREDUCE_METHOD = "mapreduce";
+
+ public static final String KERNEL_PROFILE_OPTION = "kernelProfile";
+
+ public static final String ANALYZER_NAME_OPTION = "analyzerName";
+
+ public static final String RANDOM_SEED = "randomSeed";
+
+ private DefaultOptionCreator() {}
+
+ /**
+ * Returns a default command line option for help. Used by all clustering jobs
+ * and many others
+ * */
+ public static Option helpOption() {
+ return new DefaultOptionBuilder().withLongName("help")
+ .withDescription("Print out help").withShortName("h").create();
+ }
+
+ /**
+ * Returns a default command line option for input directory specification.
+ * Used by all clustering jobs plus others
+ */
+ public static DefaultOptionBuilder inputOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(INPUT_OPTION)
+ .withRequired(false)
+ .withShortName("i")
+ .withArgument(
+ new ArgumentBuilder().withName(INPUT_OPTION).withMinimum(1)
+ .withMaximum(1).create())
+ .withDescription("Path to job input directory.");
+ }
+
+ /**
+ * Returns a default command line option for clusters input directory
+ * specification. Used by FuzzyKmeans, Kmeans
+ */
+ public static DefaultOptionBuilder clustersInOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(CLUSTERS_IN_OPTION)
+ .withRequired(true)
+ .withArgument(
+ new ArgumentBuilder().withName(CLUSTERS_IN_OPTION).withMinimum(1)
+ .withMaximum(1).create())
+ .withDescription(
+ "The path to the initial clusters directory. Must be a SequenceFile of some type of Cluster")
+ .withShortName("c");
+ }
+
+ /**
+ * Returns a default command line option for output directory specification.
+ * Used by all clustering jobs plus others
+ */
+ public static DefaultOptionBuilder outputOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(OUTPUT_OPTION)
+ .withRequired(false)
+ .withShortName("o")
+ .withArgument(
+ new ArgumentBuilder().withName(OUTPUT_OPTION).withMinimum(1)
+ .withMaximum(1).create())
+ .withDescription("The directory pathname for output.");
+ }
+
+ /**
+ * Returns a default command line option for output directory overwriting.
+ * Used by all clustering jobs
+ */
+ public static DefaultOptionBuilder overwriteOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(OVERWRITE_OPTION)
+ .withRequired(false)
+ .withDescription(
+ "If present, overwrite the output directory before running job")
+ .withShortName("ow");
+ }
+
+ /**
+ * Returns a default command line option for specification of distance measure
+ * class to use. Used by Canopy, FuzzyKmeans, Kmeans, MeanShift
+ */
+ public static DefaultOptionBuilder distanceMeasureOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(DISTANCE_MEASURE_OPTION)
+ .withRequired(false)
+ .withShortName("dm")
+ .withArgument(
+ new ArgumentBuilder().withName(DISTANCE_MEASURE_OPTION)
+ .withDefault(SquaredEuclideanDistanceMeasure.class.getName())
+ .withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "The classname of the DistanceMeasure. Default is SquaredEuclidean");
+ }
+
+ /**
+ * Returns a default command line option for specification of sequential or
+ * parallel operation. Used by Canopy, FuzzyKmeans, Kmeans, MeanShift,
+ * Dirichlet
+ */
+ public static DefaultOptionBuilder methodOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(METHOD_OPTION)
+ .withRequired(false)
+ .withShortName("xm")
+ .withArgument(
+ new ArgumentBuilder().withName(METHOD_OPTION)
+ .withDefault(MAPREDUCE_METHOD).withMinimum(1).withMaximum(1)
+ .create())
+ .withDescription(
+ "The execution method to use: sequential or mapreduce. Default is mapreduce");
+ }
+
+ /**
+ * Returns a default command line option for specification of T1. Used by
+ * Canopy, MeanShift
+ */
+ public static DefaultOptionBuilder t1Option() {
+ return new DefaultOptionBuilder()
+ .withLongName(T1_OPTION)
+ .withRequired(true)
+ .withArgument(
+ new ArgumentBuilder().withName(T1_OPTION).withMinimum(1)
+ .withMaximum(1).create()).withDescription("T1 threshold value")
+ .withShortName(T1_OPTION);
+ }
+
+ /**
+ * Returns a default command line option for specification of T2. Used by
+ * Canopy, MeanShift
+ */
+ public static DefaultOptionBuilder t2Option() {
+ return new DefaultOptionBuilder()
+ .withLongName(T2_OPTION)
+ .withRequired(true)
+ .withArgument(
+ new ArgumentBuilder().withName(T2_OPTION).withMinimum(1)
+ .withMaximum(1).create()).withDescription("T2 threshold value")
+ .withShortName(T2_OPTION);
+ }
+
+ /**
+ * Returns a default command line option for specification of T3 (Reducer T1).
+ * Used by Canopy
+ */
+ public static DefaultOptionBuilder t3Option() {
+ return new DefaultOptionBuilder()
+ .withLongName(T3_OPTION)
+ .withRequired(false)
+ .withArgument(
+ new ArgumentBuilder().withName(T3_OPTION).withMinimum(1)
+ .withMaximum(1).create())
+ .withDescription("T3 (Reducer T1) threshold value")
+ .withShortName(T3_OPTION);
+ }
+
+ /**
+ * Returns a default command line option for specification of T4 (Reducer T2).
+ * Used by Canopy
+ */
+ public static DefaultOptionBuilder t4Option() {
+ return new DefaultOptionBuilder()
+ .withLongName(T4_OPTION)
+ .withRequired(false)
+ .withArgument(
+ new ArgumentBuilder().withName(T4_OPTION).withMinimum(1)
+ .withMaximum(1).create())
+ .withDescription("T4 (Reducer T2) threshold value")
+ .withShortName(T4_OPTION);
+ }
+
+ /**
+ * @return a DefaultOptionBuilder for the clusterFilter option
+ */
+ public static DefaultOptionBuilder clusterFilterOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(CLUSTER_FILTER_OPTION)
+ .withShortName("cf")
+ .withRequired(false)
+ .withArgument(
+ new ArgumentBuilder().withName(CLUSTER_FILTER_OPTION).withMinimum(1)
+ .withMaximum(1).create())
+ .withDescription("Cluster filter suppresses small canopies from mapper")
+ .withShortName(CLUSTER_FILTER_OPTION);
+ }
+
+ /**
+ * Returns a default command line option for specification of max number of
+ * iterations. Used by Dirichlet, FuzzyKmeans, Kmeans, LDA
+ */
+ public static DefaultOptionBuilder maxIterationsOption() {
+ // default value used by LDA which overrides withRequired(false)
+ return new DefaultOptionBuilder()
+ .withLongName(MAX_ITERATIONS_OPTION)
+ .withRequired(true)
+ .withShortName("x")
+ .withArgument(
+ new ArgumentBuilder().withName(MAX_ITERATIONS_OPTION)
+ .withDefault("-1").withMinimum(1).withMaximum(1).create())
+ .withDescription("The maximum number of iterations.");
+ }
+
+ /**
+ * Returns a default command line option for specification of numbers of
+ * clusters to create. Used by Dirichlet, FuzzyKmeans, Kmeans
+ */
+ public static DefaultOptionBuilder numClustersOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(NUM_CLUSTERS_OPTION)
+ .withRequired(false)
+ .withArgument(
+ new ArgumentBuilder().withName("k").withMinimum(1).withMaximum(1)
+ .create()).withDescription("The number of clusters to create")
+ .withShortName("k");
+ }
+
+ public static DefaultOptionBuilder useSetRandomSeedOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(RANDOM_SEED)
+ .withRequired(false)
+ .withArgument(new ArgumentBuilder().withName(RANDOM_SEED).create())
+ .withDescription("Seed to initaize Random Number Generator with")
+ .withShortName("rs");
+ }
+
+ /**
+ * Returns a default command line option for convergence delta specification.
+ * Used by FuzzyKmeans, Kmeans, MeanShift
+ */
+ public static DefaultOptionBuilder convergenceOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(CONVERGENCE_DELTA_OPTION)
+ .withRequired(false)
+ .withShortName("cd")
+ .withArgument(
+ new ArgumentBuilder().withName(CONVERGENCE_DELTA_OPTION)
+ .withDefault("0.5").withMinimum(1).withMaximum(1).create())
+ .withDescription("The convergence delta value. Default is 0.5");
+ }
+
+ /**
+ * Returns a default command line option for specifying the max number of
+ * reducers. Used by Dirichlet, FuzzyKmeans, Kmeans and LDA
+ *
+ * @deprecated
+ */
+ @Deprecated
+ public static DefaultOptionBuilder numReducersOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(MAX_REDUCERS_OPTION)
+ .withRequired(false)
+ .withShortName("r")
+ .withArgument(
+ new ArgumentBuilder().withName(MAX_REDUCERS_OPTION)
+ .withDefault("2").withMinimum(1).withMaximum(1).create())
+ .withDescription("The number of reduce tasks. Defaults to 2");
+ }
+
+ /**
+ * Returns a default command line option for clustering specification. Used by
+ * all clustering except LDA
+ */
+ public static DefaultOptionBuilder clusteringOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(CLUSTERING_OPTION)
+ .withRequired(false)
+ .withDescription(
+ "If present, run clustering after the iterations have taken place")
+ .withShortName("cl");
+ }
+
+ /**
+ * Returns a default command line option for specifying a Lucene analyzer class
+ * @return {@link DefaultOptionBuilder}
+ */
+ public static DefaultOptionBuilder analyzerOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(ANALYZER_NAME_OPTION)
+ .withRequired(false)
+ .withDescription("If present, the name of a Lucene analyzer class to use")
+ .withArgument(new ArgumentBuilder().withName(ANALYZER_NAME_OPTION).withDefault(StandardAnalyzer.class.getName())
+ .withMinimum(1).withMaximum(1).create())
+ .withShortName("an");
+ }
+
+
+ /**
+ * Returns a default command line option for specifying the emitMostLikely
+ * flag. Used by Dirichlet and FuzzyKmeans
+ */
+ public static DefaultOptionBuilder emitMostLikelyOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(EMIT_MOST_LIKELY_OPTION)
+ .withRequired(false)
+ .withShortName("e")
+ .withArgument(
+ new ArgumentBuilder().withName(EMIT_MOST_LIKELY_OPTION)
+ .withDefault("true").withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "True if clustering should emit the most likely point only, "
+ + "false for threshold clustering. Default is true");
+ }
+
+ /**
+ * Returns a default command line option for specifying the clustering
+ * threshold value. Used by Dirichlet and FuzzyKmeans
+ */
+ public static DefaultOptionBuilder thresholdOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(THRESHOLD_OPTION)
+ .withRequired(false)
+ .withShortName("t")
+ .withArgument(
+ new ArgumentBuilder().withName(THRESHOLD_OPTION).withDefault("0")
+ .withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "The pdf threshold used for cluster determination. Default is 0");
+ }
+
+ public static DefaultOptionBuilder kernelProfileOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(KERNEL_PROFILE_OPTION)
+ .withRequired(false)
+ .withShortName("kp")
+ .withArgument(
+ new ArgumentBuilder()
+ .withName(KERNEL_PROFILE_OPTION)
+ .withDefault(TriangularKernelProfile.class.getName())
+ .withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "The classname of the IKernelProfile. Default is TriangularKernelProfile");
+ }
+
+ /**
+ * Returns a default command line option for specification of OUTLIER THRESHOLD value. Used for
+ * Cluster Classification.
+ */
+ public static DefaultOptionBuilder outlierThresholdOption() {
+ return new DefaultOptionBuilder()
+ .withLongName(OUTLIER_THRESHOLD)
+ .withRequired(false)
+ .withArgument(
+ new ArgumentBuilder().withName(OUTLIER_THRESHOLD).withMinimum(1)
+ .withMaximum(1).create()).withDescription("Outlier threshold value")
+ .withShortName(OUTLIER_THRESHOLD);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java
new file mode 100644
index 0000000..61aa9a5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/ChebyshevDistanceMeasure.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * This class implements a "Chebyshev distance" metric by finding the maximum difference
+ * between each coordinate. Also 'chessboard distance' due to the moves a king can make.
+ */
+public class ChebyshevDistanceMeasure implements DistanceMeasure {
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.size() != v2.size()) {
+ throw new CardinalityException(v1.size(), v2.size());
+ }
+ return v1.aggregate(v2, Functions.MAX_ABS, Functions.MINUS);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
new file mode 100644
index 0000000..37265eb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/CosineDistanceMeasure.java
@@ -0,0 +1,119 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+
+/**
+ * This class implements a cosine distance metric by dividing the dot product of two vectors by the product of their
+ * lengths. That gives the cosine of the angle between the two vectors. To convert this to a usable distance,
+ * 1-cos(angle) is what is actually returned.
+ */
+public class CosineDistanceMeasure implements DistanceMeasure {
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ public static double distance(double[] p1, double[] p2) {
+ double dotProduct = 0.0;
+ double lengthSquaredp1 = 0.0;
+ double lengthSquaredp2 = 0.0;
+ for (int i = 0; i < p1.length; i++) {
+ lengthSquaredp1 += p1[i] * p1[i];
+ lengthSquaredp2 += p2[i] * p2[i];
+ dotProduct += p1[i] * p2[i];
+ }
+ double denominator = Math.sqrt(lengthSquaredp1) * Math.sqrt(lengthSquaredp2);
+
+ // correct for floating-point rounding errors
+ if (denominator < dotProduct) {
+ denominator = dotProduct;
+ }
+
+ // correct for zero-vector corner case
+ if (denominator == 0 && dotProduct == 0) {
+ return 0;
+ }
+
+ return 1.0 - dotProduct / denominator;
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.size() != v2.size()) {
+ throw new CardinalityException(v1.size(), v2.size());
+ }
+ double lengthSquaredv1 = v1.getLengthSquared();
+ double lengthSquaredv2 = v2.getLengthSquared();
+
+ double dotProduct = v2.dot(v1);
+ double denominator = Math.sqrt(lengthSquaredv1) * Math.sqrt(lengthSquaredv2);
+
+ // correct for floating-point rounding errors
+ if (denominator < dotProduct) {
+ denominator = dotProduct;
+ }
+
+ // correct for zero-vector corner case
+ if (denominator == 0 && dotProduct == 0) {
+ return 0;
+ }
+
+ return 1.0 - dotProduct / denominator;
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+
+ double lengthSquaredv = v.getLengthSquared();
+
+ double dotProduct = v.dot(centroid);
+ double denominator = Math.sqrt(centroidLengthSquare) * Math.sqrt(lengthSquaredv);
+
+ // correct for floating-point rounding errors
+ if (denominator < dotProduct) {
+ denominator = dotProduct;
+ }
+
+ // correct for zero-vector corner case
+ if (denominator == 0 && dotProduct == 0) {
+ return 0;
+ }
+
+ return 1.0 - dotProduct / denominator;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java
new file mode 100644
index 0000000..696e79c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/DistanceMeasure.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.parameters.Parametered;
+import org.apache.mahout.math.Vector;
+
+/** This interface is used for objects which can determine a distance metric between two points */
+public interface DistanceMeasure extends Parametered {
+
+ /**
+ * Returns the distance metric applied to the arguments
+ *
+ * @param v1
+ * a Vector defining a multidimensional point in some feature space
+ * @param v2
+ * a Vector defining a multidimensional point in some feature space
+ * @return a scalar doubles of the distance
+ */
+ double distance(Vector v1, Vector v2);
+
+ /**
+ * Optimized version of distance metric for sparse vectors. This distance computation requires operations
+ * proportional to the number of non-zero elements in the vector instead of the cardinality of the vector.
+ *
+ * @param centroidLengthSquare
+ * Square of the length of centroid
+ * @param centroid
+ * Centroid vector
+ */
+ double distance(double centroidLengthSquare, Vector centroid, Vector v);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java
new file mode 100644
index 0000000..665678d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/EuclideanDistanceMeasure.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * This class implements a Euclidean distance metric by summing the square root of the squared differences
+ * between each coordinate.
+ * <p/>
+ * If you don't care about the true distance and only need the values for comparison, then the base class,
+ * {@link SquaredEuclideanDistanceMeasure}, will be faster since it doesn't do the actual square root of the
+ * squared differences.
+ */
+public class EuclideanDistanceMeasure extends SquaredEuclideanDistanceMeasure {
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ return Math.sqrt(super.distance(v1, v2));
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return Math.sqrt(super.distance(centroidLengthSquare, centroid, v));
+ }
+}
[33/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
new file mode 100644
index 0000000..c1d328e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmAlgorithms.java
@@ -0,0 +1,306 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Class containing implementations of the three major HMM algorithms: forward,
+ * backward and Viterbi
+ */
+public final class HmmAlgorithms {
+
+
+ /**
+ * No public constructors for utility classes.
+ */
+ private HmmAlgorithms() {
+ // nothing to do here really
+ }
+
+ /**
+ * External function to compute a matrix of alpha factors
+ *
+ * @param model model to run forward algorithm for.
+ * @param observations observation sequence to train on.
+ * @param scaled Should log-scaled beta factors be computed?
+ * @return matrix of alpha factors.
+ */
+ public static Matrix forwardAlgorithm(HmmModel model, int[] observations, boolean scaled) {
+ Matrix alpha = new DenseMatrix(observations.length, model.getNrOfHiddenStates());
+ forwardAlgorithm(alpha, model, observations, scaled);
+
+ return alpha;
+ }
+
+ /**
+ * Internal function to compute the alpha factors
+ *
+ * @param alpha matrix to store alpha factors in.
+ * @param model model to use for alpha factor computation.
+ * @param observations observation sequence seen.
+ * @param scaled set to true if log-scaled beta factors should be computed.
+ */
+ static void forwardAlgorithm(Matrix alpha, HmmModel model, int[] observations, boolean scaled) {
+
+ // fetch references to the model parameters
+ Vector ip = model.getInitialProbabilities();
+ Matrix b = model.getEmissionMatrix();
+ Matrix a = model.getTransitionMatrix();
+
+ if (scaled) { // compute log scaled alpha values
+ // Initialization
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ alpha.setQuick(0, i, Math.log(ip.getQuick(i) * b.getQuick(i, observations[0])));
+ }
+
+ // Induction
+ for (int t = 1; t < observations.length; t++) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ double sum = Double.NEGATIVE_INFINITY; // log(0)
+ for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
+ double tmp = alpha.getQuick(t - 1, j) + Math.log(a.getQuick(j, i));
+ if (tmp > Double.NEGATIVE_INFINITY) {
+ // make sure we handle log(0) correctly
+ sum = tmp + Math.log1p(Math.exp(sum - tmp));
+ }
+ }
+ alpha.setQuick(t, i, sum + Math.log(b.getQuick(i, observations[t])));
+ }
+ }
+ } else {
+
+ // Initialization
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ alpha.setQuick(0, i, ip.getQuick(i) * b.getQuick(i, observations[0]));
+ }
+
+ // Induction
+ for (int t = 1; t < observations.length; t++) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ double sum = 0.0;
+ for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
+ sum += alpha.getQuick(t - 1, j) * a.getQuick(j, i);
+ }
+ alpha.setQuick(t, i, sum * b.getQuick(i, observations[t]));
+ }
+ }
+ }
+ }
+
+ /**
+ * External function to compute a matrix of beta factors
+ *
+ * @param model model to use for estimation.
+ * @param observations observation sequence seen.
+ * @param scaled Set to true if log-scaled beta factors should be computed.
+ * @return beta factors based on the model and observation sequence.
+ */
+ public static Matrix backwardAlgorithm(HmmModel model, int[] observations, boolean scaled) {
+ // initialize the matrix
+ Matrix beta = new DenseMatrix(observations.length, model.getNrOfHiddenStates());
+ // compute the beta factors
+ backwardAlgorithm(beta, model, observations, scaled);
+
+ return beta;
+ }
+
+ /**
+ * Internal function to compute the beta factors
+ *
+ * @param beta Matrix to store resulting factors in.
+ * @param model model to use for factor estimation.
+ * @param observations sequence of observations to estimate.
+ * @param scaled set to true to compute log-scaled parameters.
+ */
+ static void backwardAlgorithm(Matrix beta, HmmModel model, int[] observations, boolean scaled) {
+ // fetch references to the model parameters
+ Matrix b = model.getEmissionMatrix();
+ Matrix a = model.getTransitionMatrix();
+
+ if (scaled) { // compute log-scaled factors
+ // initialization
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ beta.setQuick(observations.length - 1, i, 0);
+ }
+
+ // induction
+ for (int t = observations.length - 2; t >= 0; t--) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ double sum = Double.NEGATIVE_INFINITY; // log(0)
+ for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
+ double tmp = beta.getQuick(t + 1, j) + Math.log(a.getQuick(i, j))
+ + Math.log(b.getQuick(j, observations[t + 1]));
+ if (tmp > Double.NEGATIVE_INFINITY) {
+ // handle log(0)
+ sum = tmp + Math.log1p(Math.exp(sum - tmp));
+ }
+ }
+ beta.setQuick(t, i, sum);
+ }
+ }
+ } else {
+ // initialization
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ beta.setQuick(observations.length - 1, i, 1);
+ }
+ // induction
+ for (int t = observations.length - 2; t >= 0; t--) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ double sum = 0;
+ for (int j = 0; j < model.getNrOfHiddenStates(); j++) {
+ sum += beta.getQuick(t + 1, j) * a.getQuick(i, j) * b.getQuick(j, observations[t + 1]);
+ }
+ beta.setQuick(t, i, sum);
+ }
+ }
+ }
+ }
+
+ /**
+ * Viterbi algorithm to compute the most likely hidden sequence for a given
+ * model and observed sequence
+ *
+ * @param model HmmModel for which the Viterbi path should be computed
+ * @param observations Sequence of observations
+ * @param scaled Use log-scaled computations, this requires higher computational
+ * effort but is numerically more stable for large observation
+ * sequences
+ * @return nrOfObservations 1D int array containing the most likely hidden
+ * sequence
+ */
+ public static int[] viterbiAlgorithm(HmmModel model, int[] observations, boolean scaled) {
+
+ // probability that the most probable hidden states ends at state i at
+ // time t
+ double[][] delta = new double[observations.length][model
+ .getNrOfHiddenStates()];
+
+ // previous hidden state in the most probable state leading up to state
+ // i at time t
+ int[][] phi = new int[observations.length - 1][model.getNrOfHiddenStates()];
+
+ // initialize the return array
+ int[] sequence = new int[observations.length];
+
+ viterbiAlgorithm(sequence, delta, phi, model, observations, scaled);
+
+ return sequence;
+ }
+
+ /**
+ * Internal version of the viterbi algorithm, allowing to reuse existing
+ * arrays instead of allocating new ones
+ *
+ * @param sequence NrOfObservations 1D int array for storing the viterbi sequence
+ * @param delta NrOfObservations x NrHiddenStates 2D double array for storing the
+ * delta factors
+ * @param phi NrOfObservations-1 x NrHiddenStates 2D int array for storing the
+ * phi values
+ * @param model HmmModel for which the viterbi path should be computed
+ * @param observations Sequence of observations
+ * @param scaled Use log-scaled computations, this requires higher computational
+ * effort but is numerically more stable for large observation
+ * sequences
+ */
+ static void viterbiAlgorithm(int[] sequence, double[][] delta, int[][] phi, HmmModel model, int[] observations,
+ boolean scaled) {
+ // fetch references to the model parameters
+ Vector ip = model.getInitialProbabilities();
+ Matrix b = model.getEmissionMatrix();
+ Matrix a = model.getTransitionMatrix();
+
+ // Initialization
+ if (scaled) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ delta[0][i] = Math.log(ip.getQuick(i) * b.getQuick(i, observations[0]));
+ }
+ } else {
+
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ delta[0][i] = ip.getQuick(i) * b.getQuick(i, observations[0]);
+ }
+ }
+
+ // Induction
+ // iterate over the time
+ if (scaled) {
+ for (int t = 1; t < observations.length; t++) {
+ // iterate over the hidden states
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ // find the maximum probability and most likely state
+ // leading up
+ // to this
+ int maxState = 0;
+ double maxProb = delta[t - 1][0] + Math.log(a.getQuick(0, i));
+ for (int j = 1; j < model.getNrOfHiddenStates(); j++) {
+ double prob = delta[t - 1][j] + Math.log(a.getQuick(j, i));
+ if (prob > maxProb) {
+ maxProb = prob;
+ maxState = j;
+ }
+ }
+ delta[t][i] = maxProb + Math.log(b.getQuick(i, observations[t]));
+ phi[t - 1][i] = maxState;
+ }
+ }
+ } else {
+ for (int t = 1; t < observations.length; t++) {
+ // iterate over the hidden states
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ // find the maximum probability and most likely state
+ // leading up
+ // to this
+ int maxState = 0;
+ double maxProb = delta[t - 1][0] * a.getQuick(0, i);
+ for (int j = 1; j < model.getNrOfHiddenStates(); j++) {
+ double prob = delta[t - 1][j] * a.getQuick(j, i);
+ if (prob > maxProb) {
+ maxProb = prob;
+ maxState = j;
+ }
+ }
+ delta[t][i] = maxProb * b.getQuick(i, observations[t]);
+ phi[t - 1][i] = maxState;
+ }
+ }
+ }
+
+ // find the most likely end state for initialization
+ double maxProb;
+ if (scaled) {
+ maxProb = Double.NEGATIVE_INFINITY;
+ } else {
+ maxProb = 0.0;
+ }
+ for (int i = 0; i < model.getNrOfHiddenStates(); i++) {
+ if (delta[observations.length - 1][i] > maxProb) {
+ maxProb = delta[observations.length - 1][i];
+ sequence[observations.length - 1] = i;
+ }
+ }
+
+ // now backtrack to find the most likely hidden sequence
+ for (int t = observations.length - 2; t >= 0; t--) {
+ sequence[t] = phi[t][sequence[t + 1]];
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java
new file mode 100644
index 0000000..6e2def6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmEvaluator.java
@@ -0,0 +1,194 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Random;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * The HMMEvaluator class offers several methods to evaluate an HMM Model. The
+ * following use-cases are covered: 1) Generate a sequence of output states from
+ * a given model (prediction). 2) Compute the likelihood that a given model
+ * generated a given sequence of output states (model likelihood). 3) Compute
+ * the most likely hidden sequence for a given model and a given observed
+ * sequence (decoding).
+ */
+public final class HmmEvaluator {
+
+ /**
+ * No constructor for utility classes.
+ */
+ private HmmEvaluator() {}
+
+ /**
+ * Predict a sequence of steps output states for the given HMM model
+ *
+ * @param model The Hidden Markov model used to generate the output sequence
+ * @param steps Size of the generated output sequence
+ * @return integer array containing a sequence of steps output state IDs,
+ * generated by the specified model
+ */
+ public static int[] predict(HmmModel model, int steps) {
+ return predict(model, steps, RandomUtils.getRandom());
+ }
+
+ /**
+ * Predict a sequence of steps output states for the given HMM model
+ *
+ * @param model The Hidden Markov model used to generate the output sequence
+ * @param steps Size of the generated output sequence
+ * @param seed seed to use for the RNG
+ * @return integer array containing a sequence of steps output state IDs,
+ * generated by the specified model
+ */
+ public static int[] predict(HmmModel model, int steps, long seed) {
+ return predict(model, steps, RandomUtils.getRandom(seed));
+ }
+ /**
+ * Predict a sequence of steps output states for the given HMM model using the
+ * given seed for probabilistic experiments
+ *
+ * @param model The Hidden Markov model used to generate the output sequence
+ * @param steps Size of the generated output sequence
+ * @param rand RNG to use
+ * @return integer array containing a sequence of steps output state IDs,
+ * generated by the specified model
+ */
+ private static int[] predict(HmmModel model, int steps, Random rand) {
+ // fetch the cumulative distributions
+ Vector cip = HmmUtils.getCumulativeInitialProbabilities(model);
+ Matrix ctm = HmmUtils.getCumulativeTransitionMatrix(model);
+ Matrix com = HmmUtils.getCumulativeOutputMatrix(model);
+ // allocate the result IntArrayList
+ int[] result = new int[steps];
+ // choose the initial state
+ int hiddenState = 0;
+
+ double randnr = rand.nextDouble();
+ while (cip.get(hiddenState) < randnr) {
+ hiddenState++;
+ }
+
+ // now draw steps output states according to the cumulative
+ // distributions
+ for (int step = 0; step < steps; ++step) {
+ // choose output state to given hidden state
+ randnr = rand.nextDouble();
+ int outputState = 0;
+ while (com.get(hiddenState, outputState) < randnr) {
+ outputState++;
+ }
+ result[step] = outputState;
+ // choose the next hidden state
+ randnr = rand.nextDouble();
+ int nextHiddenState = 0;
+ while (ctm.get(hiddenState, nextHiddenState) < randnr) {
+ nextHiddenState++;
+ }
+ hiddenState = nextHiddenState;
+ }
+ return result;
+ }
+
+ /**
+ * Returns the likelihood that a given output sequence was produced by the
+ * given model. Internally, this function calls the forward algorithm to
+ * compute the alpha values and then uses the overloaded function to compute
+ * the actual model likelihood.
+ *
+ * @param model Model to base the likelihood on.
+ * @param outputSequence Sequence to compute likelihood for.
+ * @param scaled Use log-scaled parameters for computation. This is computationally
+ * more expensive, but offers better numerically stability in case of
+ * long output sequences
+ * @return Likelihood that the given model produced the given sequence
+ */
+ public static double modelLikelihood(HmmModel model, int[] outputSequence, boolean scaled) {
+ return modelLikelihood(HmmAlgorithms.forwardAlgorithm(model, outputSequence, scaled), scaled);
+ }
+
+ /**
+ * Computes the likelihood that a given output sequence was computed by a
+ * given model using the alpha values computed by the forward algorithm.
+ * // TODO I am a bit confused here - where is the output sequence referenced in the comment above in the code?
+ * @param alpha Matrix of alpha values
+ * @param scaled Set to true if the alpha values are log-scaled.
+ * @return model likelihood.
+ */
+ public static double modelLikelihood(Matrix alpha, boolean scaled) {
+ double likelihood = 0;
+ if (scaled) {
+ for (int i = 0; i < alpha.numCols(); ++i) {
+ likelihood += Math.exp(alpha.getQuick(alpha.numRows() - 1, i));
+ }
+ } else {
+ for (int i = 0; i < alpha.numCols(); ++i) {
+ likelihood += alpha.getQuick(alpha.numRows() - 1, i);
+ }
+ }
+ return likelihood;
+ }
+
+ /**
+ * Computes the likelihood that a given output sequence was computed by a
+ * given model.
+ *
+ * @param model model to compute sequence likelihood for.
+ * @param outputSequence sequence to base computation on.
+ * @param beta beta parameters.
+ * @param scaled set to true if betas are log-scaled.
+ * @return likelihood of the outputSequence given the model.
+ */
+ public static double modelLikelihood(HmmModel model, int[] outputSequence, Matrix beta, boolean scaled) {
+ double likelihood = 0;
+ // fetch the emission probabilities
+ Matrix e = model.getEmissionMatrix();
+ Vector pi = model.getInitialProbabilities();
+ int firstOutput = outputSequence[0];
+ if (scaled) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ likelihood += pi.getQuick(i) * Math.exp(beta.getQuick(0, i)) * e.getQuick(i, firstOutput);
+ }
+ } else {
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ likelihood += pi.getQuick(i) * beta.getQuick(0, i) * e.getQuick(i, firstOutput);
+ }
+ }
+ return likelihood;
+ }
+
+ /**
+ * Returns the most likely sequence of hidden states for the given model and
+ * observation
+ *
+ * @param model model to use for decoding.
+ * @param observations integer Array containing a sequence of observed state IDs
+ * @param scaled Use log-scaled computations, this requires higher computational
+ * effort but is numerically more stable for large observation
+ * sequences
+ * @return integer array containing the most likely sequence of hidden state
+ * IDs
+ */
+ public static int[] decode(HmmModel model, int[] observations, boolean scaled) {
+ return HmmAlgorithms.viterbiAlgorithm(model, observations, scaled);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
new file mode 100644
index 0000000..bc24884
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmModel.java
@@ -0,0 +1,383 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Map;
+import java.util.Random;
+
+import com.google.common.collect.BiMap;
+import com.google.common.collect.HashBiMap;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Main class defining a Hidden Markov Model
+ */
+public class HmmModel implements Cloneable {
+
+ /** Bi-directional Map for storing the observed state names */
+ private BiMap<String,Integer> outputStateNames;
+
+ /** Bi-Directional Map for storing the hidden state names */
+ private BiMap<String,Integer> hiddenStateNames;
+
+ /* Number of hidden states */
+ private int nrOfHiddenStates;
+
+ /** Number of output states */
+ private int nrOfOutputStates;
+
+ /**
+ * Transition matrix containing the transition probabilities between hidden
+ * states. TransitionMatrix(i,j) is the probability that we change from hidden
+ * state i to hidden state j In general: P(h(t+1)=h_j | h(t) = h_i) =
+ * transitionMatrix(i,j) Since we have to make sure that each hidden state can
+ * be "left", the following normalization condition has to hold:
+ * sum(transitionMatrix(i,j),j=1..hiddenStates) = 1
+ */
+ private Matrix transitionMatrix;
+
+ /**
+ * Output matrix containing the probabilities that we observe a given output
+ * state given a hidden state. outputMatrix(i,j) is the probability that we
+ * observe output state j if we are in hidden state i Formally: P(o(t)=o_j |
+ * h(t)=h_i) = outputMatrix(i,j) Since we always have an observation for each
+ * hidden state, the following normalization condition has to hold:
+ * sum(outputMatrix(i,j),j=1..outputStates) = 1
+ */
+ private Matrix emissionMatrix;
+
+ /**
+ * Vector containing the initial hidden state probabilities. That is
+ * P(h(0)=h_i) = initialProbabilities(i). Since we are dealing with
+ * probabilities the following normalization condition has to hold:
+ * sum(initialProbabilities(i),i=1..hiddenStates) = 1
+ */
+ private Vector initialProbabilities;
+
+
+ /**
+ * Get a copy of this model
+ */
+ @Override
+ public HmmModel clone() {
+ HmmModel model = new HmmModel(transitionMatrix.clone(), emissionMatrix.clone(), initialProbabilities.clone());
+ if (hiddenStateNames != null) {
+ model.hiddenStateNames = HashBiMap.create(hiddenStateNames);
+ }
+ if (outputStateNames != null) {
+ model.outputStateNames = HashBiMap.create(outputStateNames);
+ }
+ return model;
+ }
+
+ /**
+ * Assign the content of another HMM model to this one
+ *
+ * @param model The HmmModel that will be assigned to this one
+ */
+ public void assign(HmmModel model) {
+ this.nrOfHiddenStates = model.nrOfHiddenStates;
+ this.nrOfOutputStates = model.nrOfOutputStates;
+ this.hiddenStateNames = model.hiddenStateNames;
+ this.outputStateNames = model.outputStateNames;
+ // for now clone the matrix/vectors
+ this.initialProbabilities = model.initialProbabilities.clone();
+ this.emissionMatrix = model.emissionMatrix.clone();
+ this.transitionMatrix = model.transitionMatrix.clone();
+ }
+
+ /**
+ * Construct a valid random Hidden-Markov parameter set with the given number
+ * of hidden and output states using a given seed.
+ *
+ * @param nrOfHiddenStates Number of hidden states
+ * @param nrOfOutputStates Number of output states
+ * @param seed Seed for the random initialization, if set to 0 the current time
+ * is used
+ */
+ public HmmModel(int nrOfHiddenStates, int nrOfOutputStates, long seed) {
+ this.nrOfHiddenStates = nrOfHiddenStates;
+ this.nrOfOutputStates = nrOfOutputStates;
+ this.transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
+ this.emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
+ this.initialProbabilities = new DenseVector(nrOfHiddenStates);
+ // initialize a random, valid parameter set
+ initRandomParameters(seed);
+ }
+
+ /**
+ * Construct a valid random Hidden-Markov parameter set with the given number
+ * of hidden and output states.
+ *
+ * @param nrOfHiddenStates Number of hidden states
+ * @param nrOfOutputStates Number of output states
+ */
+ public HmmModel(int nrOfHiddenStates, int nrOfOutputStates) {
+ this(nrOfHiddenStates, nrOfOutputStates, 0);
+ }
+
+ /**
+ * Generates a Hidden Markov model using the specified parameters
+ *
+ * @param transitionMatrix transition probabilities.
+ * @param emissionMatrix emission probabilities.
+ * @param initialProbabilities initial start probabilities.
+ * @throws IllegalArgumentException If the given parameter set is invalid
+ */
+ public HmmModel(Matrix transitionMatrix, Matrix emissionMatrix, Vector initialProbabilities) {
+ this.nrOfHiddenStates = initialProbabilities.size();
+ this.nrOfOutputStates = emissionMatrix.numCols();
+ this.transitionMatrix = transitionMatrix;
+ this.emissionMatrix = emissionMatrix;
+ this.initialProbabilities = initialProbabilities;
+ }
+
+ /**
+ * Initialize a valid random set of HMM parameters
+ *
+ * @param seed seed to use for Random initialization. Use 0 to use Java-built-in-version.
+ */
+ private void initRandomParameters(long seed) {
+ Random rand;
+ // initialize the random number generator
+ if (seed == 0) {
+ rand = RandomUtils.getRandom();
+ } else {
+ rand = RandomUtils.getRandom(seed);
+ }
+ // initialize the initial Probabilities
+ double sum = 0; // used for normalization
+ for (int i = 0; i < nrOfHiddenStates; i++) {
+ double nextRand = rand.nextDouble();
+ initialProbabilities.set(i, nextRand);
+ sum += nextRand;
+ }
+ // "normalize" the vector to generate probabilities
+ initialProbabilities = initialProbabilities.divide(sum);
+
+ // initialize the transition matrix
+ double[] values = new double[nrOfHiddenStates];
+ for (int i = 0; i < nrOfHiddenStates; i++) {
+ sum = 0;
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ values[j] = rand.nextDouble();
+ sum += values[j];
+ }
+ // normalize the random values to obtain probabilities
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ values[j] /= sum;
+ }
+ // set this row of the transition matrix
+ transitionMatrix.set(i, values);
+ }
+
+ // initialize the output matrix
+ values = new double[nrOfOutputStates];
+ for (int i = 0; i < nrOfHiddenStates; i++) {
+ sum = 0;
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ values[j] = rand.nextDouble();
+ sum += values[j];
+ }
+ // normalize the random values to obtain probabilities
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ values[j] /= sum;
+ }
+ // set this row of the output matrix
+ emissionMatrix.set(i, values);
+ }
+ }
+
+ /**
+ * Getter Method for the number of hidden states
+ *
+ * @return Number of hidden states
+ */
+ public int getNrOfHiddenStates() {
+ return nrOfHiddenStates;
+ }
+
+ /**
+ * Getter Method for the number of output states
+ *
+ * @return Number of output states
+ */
+ public int getNrOfOutputStates() {
+ return nrOfOutputStates;
+ }
+
+ /**
+ * Getter function to get the hidden state transition matrix
+ *
+ * @return returns the model's transition matrix.
+ */
+ public Matrix getTransitionMatrix() {
+ return transitionMatrix;
+ }
+
+ /**
+ * Getter function to get the output state probability matrix
+ *
+ * @return returns the models emission matrix.
+ */
+ public Matrix getEmissionMatrix() {
+ return emissionMatrix;
+ }
+
+ /**
+ * Getter function to return the vector of initial hidden state probabilities
+ *
+ * @return returns the model's init probabilities.
+ */
+ public Vector getInitialProbabilities() {
+ return initialProbabilities;
+ }
+
+ /**
+ * Getter method for the hidden state Names map
+ *
+ * @return hidden state names.
+ */
+ public Map<String, Integer> getHiddenStateNames() {
+ return hiddenStateNames;
+ }
+
+ /**
+ * Register an array of hidden state Names. We assume that the state name at
+ * position i has the ID i
+ *
+ * @param stateNames names of hidden states.
+ */
+ public void registerHiddenStateNames(String[] stateNames) {
+ if (stateNames != null) {
+ hiddenStateNames = HashBiMap.create();
+ for (int i = 0; i < stateNames.length; ++i) {
+ hiddenStateNames.put(stateNames[i], i);
+ }
+ }
+ }
+
+ /**
+ * Register a map of hidden state Names/state IDs
+ *
+ * @param stateNames <String,Integer> Map that assigns each state name an integer ID
+ */
+ public void registerHiddenStateNames(Map<String, Integer> stateNames) {
+ if (stateNames != null) {
+ hiddenStateNames = HashBiMap.create(stateNames);
+ }
+ }
+
+ /**
+ * Lookup the name for the given hidden state ID
+ *
+ * @param id Integer id of the hidden state
+ * @return String containing the name for the given ID, null if this ID is not
+ * known or no hidden state names were specified
+ */
+ public String getHiddenStateName(int id) {
+ if (hiddenStateNames == null) {
+ return null;
+ }
+ return hiddenStateNames.inverse().get(id);
+ }
+
+ /**
+ * Lookup the ID for the given hidden state name
+ *
+ * @param name Name of the hidden state
+ * @return int containing the ID for the given name, -1 if this name is not
+ * known or no hidden state names were specified
+ */
+ public int getHiddenStateID(String name) {
+ if (hiddenStateNames == null) {
+ return -1;
+ }
+ Integer tmp = hiddenStateNames.get(name);
+ return tmp == null ? -1 : tmp;
+ }
+
+ /**
+ * Getter method for the output state Names map
+ *
+ * @return names of output states.
+ */
+ public Map<String, Integer> getOutputStateNames() {
+ return outputStateNames;
+ }
+
+ /**
+ * Register an array of hidden state Names. We assume that the state name at
+ * position i has the ID i
+ *
+ * @param stateNames state names to register.
+ */
+ public void registerOutputStateNames(String[] stateNames) {
+ if (stateNames != null) {
+ outputStateNames = HashBiMap.create();
+ for (int i = 0; i < stateNames.length; ++i) {
+ outputStateNames.put(stateNames[i], i);
+ }
+ }
+ }
+
+ /**
+ * Register a map of hidden state Names/state IDs
+ *
+ * @param stateNames <String,Integer> Map that assigns each state name an integer ID
+ */
+ public void registerOutputStateNames(Map<String, Integer> stateNames) {
+ if (stateNames != null) {
+ outputStateNames = HashBiMap.create(stateNames);
+ }
+ }
+
+ /**
+ * Lookup the name for the given output state id
+ *
+ * @param id Integer id of the output state
+ * @return String containing the name for the given id, null if this id is not
+ * known or no output state names were specified
+ */
+ public String getOutputStateName(int id) {
+ if (outputStateNames == null) {
+ return null;
+ }
+ return outputStateNames.inverse().get(id);
+ }
+
+ /**
+ * Lookup the ID for the given output state name
+ *
+ * @param name Name of the output state
+ * @return int containing the ID for the given name, -1 if this name is not
+ * known or no output state names were specified
+ */
+ public int getOutputStateID(String name) {
+ if (outputStateNames == null) {
+ return -1;
+ }
+ Integer tmp = outputStateNames.get(name);
+ return tmp == null ? -1 : tmp;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
new file mode 100644
index 0000000..a1cd3e0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmTrainer.java
@@ -0,0 +1,488 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Collection;
+import java.util.Iterator;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Class containing several algorithms used to train a Hidden Markov Model. The
+ * three main algorithms are: supervised learning, unsupervised Viterbi and
+ * unsupervised Baum-Welch.
+ */
+public final class HmmTrainer {
+
+ /**
+ * No public constructor for utility classes.
+ */
+ private HmmTrainer() {
+ // nothing to do here really.
+ }
+
+ /**
+ * Create an supervised initial estimate of an HMM Model based on a sequence
+ * of observed and hidden states.
+ *
+ * @param nrOfHiddenStates The total number of hidden states
+ * @param nrOfOutputStates The total number of output states
+ * @param observedSequence Integer array containing the observed sequence
+ * @param hiddenSequence Integer array containing the hidden sequence
+ * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero
+ * probabilities.
+ * @return An initial model using the estimated parameters
+ */
+ public static HmmModel trainSupervised(int nrOfHiddenStates, int nrOfOutputStates, int[] observedSequence,
+ int[] hiddenSequence, double pseudoCount) {
+ // make sure the pseudo count is not zero
+ pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+ // initialize the parameters
+ DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfHiddenStates);
+ DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates, nrOfOutputStates);
+ // assign a small initial probability that is larger than zero, so
+ // unseen states will not get a zero probability
+ transitionMatrix.assign(pseudoCount);
+ emissionMatrix.assign(pseudoCount);
+ // given no prior knowledge, we have to assume that all initial hidden
+ // states are equally likely
+ DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
+ initialProbabilities.assign(1.0 / nrOfHiddenStates);
+
+ // now loop over the sequences to count the number of transitions
+ countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+ hiddenSequence);
+
+ // make sure that probabilities are normalized
+ for (int i = 0; i < nrOfHiddenStates; i++) {
+ // compute sum of probabilities for current row of transition matrix
+ double sum = 0;
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ sum += transitionMatrix.getQuick(i, j);
+ }
+ // normalize current row of transition matrix
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+ }
+ // compute sum of probabilities for current row of emission matrix
+ sum = 0;
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ sum += emissionMatrix.getQuick(i, j);
+ }
+ // normalize current row of emission matrix
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+ }
+ }
+
+ // return a new model using the parameter estimations
+ return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+ }
+
+ /**
+ * Function that counts the number of state->state and state->output
+ * transitions for the given observed/hidden sequence.
+ *
+ * @param transitionMatrix transition matrix to use.
+ * @param emissionMatrix emission matrix to use for counting.
+ * @param observedSequence observation sequence to use.
+ * @param hiddenSequence sequence of hidden states to use.
+ */
+ private static void countTransitions(Matrix transitionMatrix,
+ Matrix emissionMatrix, int[] observedSequence, int[] hiddenSequence) {
+ emissionMatrix.setQuick(hiddenSequence[0], observedSequence[0],
+ emissionMatrix.getQuick(hiddenSequence[0], observedSequence[0]) + 1);
+ for (int i = 1; i < observedSequence.length; ++i) {
+ transitionMatrix
+ .setQuick(hiddenSequence[i - 1], hiddenSequence[i], transitionMatrix
+ .getQuick(hiddenSequence[i - 1], hiddenSequence[i]) + 1);
+ emissionMatrix.setQuick(hiddenSequence[i], observedSequence[i],
+ emissionMatrix.getQuick(hiddenSequence[i], observedSequence[i]) + 1);
+ }
+ }
+
+ /**
+ * Create an supervised initial estimate of an HMM Model based on a number of
+ * sequences of observed and hidden states.
+ *
+ * @param nrOfHiddenStates The total number of hidden states
+ * @param nrOfOutputStates The total number of output states
+ * @param hiddenSequences Collection of hidden sequences to use for training
+ * @param observedSequences Collection of observed sequences to use for training associated with hidden sequences.
+ * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero
+ * probabilities.
+ * @return An initial model using the estimated parameters
+ */
+ public static HmmModel trainSupervisedSequence(int nrOfHiddenStates,
+ int nrOfOutputStates, Collection<int[]> hiddenSequences,
+ Collection<int[]> observedSequences, double pseudoCount) {
+
+ // make sure the pseudo count is not zero
+ pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+ // initialize parameters
+ DenseMatrix transitionMatrix = new DenseMatrix(nrOfHiddenStates,
+ nrOfHiddenStates);
+ DenseMatrix emissionMatrix = new DenseMatrix(nrOfHiddenStates,
+ nrOfOutputStates);
+ DenseVector initialProbabilities = new DenseVector(nrOfHiddenStates);
+
+ // assign pseudo count to avoid zero probabilities
+ transitionMatrix.assign(pseudoCount);
+ emissionMatrix.assign(pseudoCount);
+ initialProbabilities.assign(pseudoCount);
+
+ // now loop over the sequences to count the number of transitions
+ Iterator<int[]> hiddenSequenceIt = hiddenSequences.iterator();
+ Iterator<int[]> observedSequenceIt = observedSequences.iterator();
+ while (hiddenSequenceIt.hasNext() && observedSequenceIt.hasNext()) {
+ // fetch the current set of sequences
+ int[] hiddenSequence = hiddenSequenceIt.next();
+ int[] observedSequence = observedSequenceIt.next();
+ // increase the count for initial probabilities
+ initialProbabilities.setQuick(hiddenSequence[0], initialProbabilities
+ .getQuick(hiddenSequence[0]) + 1);
+ countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+ hiddenSequence);
+ }
+
+ // make sure that probabilities are normalized
+ double isum = 0; // sum of initial probabilities
+ for (int i = 0; i < nrOfHiddenStates; i++) {
+ isum += initialProbabilities.getQuick(i);
+ // compute sum of probabilities for current row of transition matrix
+ double sum = 0;
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ sum += transitionMatrix.getQuick(i, j);
+ }
+ // normalize current row of transition matrix
+ for (int j = 0; j < nrOfHiddenStates; j++) {
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j) / sum);
+ }
+ // compute sum of probabilities for current row of emission matrix
+ sum = 0;
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ sum += emissionMatrix.getQuick(i, j);
+ }
+ // normalize current row of emission matrix
+ for (int j = 0; j < nrOfOutputStates; j++) {
+ emissionMatrix.setQuick(i, j, emissionMatrix.getQuick(i, j) / sum);
+ }
+ }
+ // normalize the initial probabilities
+ for (int i = 0; i < nrOfHiddenStates; ++i) {
+ initialProbabilities.setQuick(i, initialProbabilities.getQuick(i) / isum);
+ }
+
+ // return a new model using the parameter estimates
+ return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+ }
+
+ /**
+ * Iteratively train the parameters of the given initial model wrt to the
+ * observed sequence using Viterbi training.
+ *
+ * @param initialModel The initial model that gets iterated
+ * @param observedSequence The sequence of observed states
+ * @param pseudoCount Value that is assigned to non-occurring transitions to avoid zero
+ * probabilities.
+ * @param epsilon Convergence criteria
+ * @param maxIterations The maximum number of training iterations
+ * @param scaled Use Log-scaled implementation, this is computationally more
+ * expensive but offers better numerical stability for large observed
+ * sequences
+ * @return The iterated model
+ */
+ public static HmmModel trainViterbi(HmmModel initialModel,
+ int[] observedSequence, double pseudoCount, double epsilon,
+ int maxIterations, boolean scaled) {
+
+ // make sure the pseudo count is not zero
+ pseudoCount = pseudoCount == 0 ? Double.MIN_VALUE : pseudoCount;
+
+ // allocate space for iteration models
+ HmmModel lastIteration = initialModel.clone();
+ HmmModel iteration = initialModel.clone();
+
+ // allocate space for Viterbi path calculation
+ int[] viterbiPath = new int[observedSequence.length];
+ int[][] phi = new int[observedSequence.length - 1][initialModel
+ .getNrOfHiddenStates()];
+ double[][] delta = new double[observedSequence.length][initialModel
+ .getNrOfHiddenStates()];
+
+ // now run the Viterbi training iteration
+ for (int i = 0; i < maxIterations; ++i) {
+ // compute the Viterbi path
+ HmmAlgorithms.viterbiAlgorithm(viterbiPath, delta, phi, lastIteration,
+ observedSequence, scaled);
+ // Viterbi iteration uses the viterbi path to update
+ // the probabilities
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+
+ // first, assign the pseudo count
+ emissionMatrix.assign(pseudoCount);
+ transitionMatrix.assign(pseudoCount);
+
+ // now count the transitions
+ countTransitions(transitionMatrix, emissionMatrix, observedSequence,
+ viterbiPath);
+
+ // and normalize the probabilities
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double sum = 0;
+ // normalize the rows of the transition matrix
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ sum += transitionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ transitionMatrix
+ .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+ }
+ // normalize the rows of the emission matrix
+ sum = 0;
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ sum += emissionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+ }
+ }
+ // check for convergence
+ if (checkConvergence(lastIteration, iteration, epsilon)) {
+ break;
+ }
+ // overwrite the last iterated model by the new iteration
+ lastIteration.assign(iteration);
+ }
+ // we are done :)
+ return iteration;
+ }
+
+ /**
+ * Iteratively train the parameters of the given initial model wrt the
+ * observed sequence using Baum-Welch training.
+ *
+ * @param initialModel The initial model that gets iterated
+ * @param observedSequence The sequence of observed states
+ * @param epsilon Convergence criteria
+ * @param maxIterations The maximum number of training iterations
+ * @param scaled Use log-scaled implementations of forward/backward algorithm. This
+ * is computationally more expensive, but offers better numerical
+ * stability for long output sequences.
+ * @return The iterated model
+ */
+ public static HmmModel trainBaumWelch(HmmModel initialModel,
+ int[] observedSequence, double epsilon, int maxIterations, boolean scaled) {
+ // allocate space for the iterations
+ HmmModel lastIteration = initialModel.clone();
+ HmmModel iteration = initialModel.clone();
+
+ // allocate space for baum-welch factors
+ int hiddenCount = initialModel.getNrOfHiddenStates();
+ int visibleCount = observedSequence.length;
+ Matrix alpha = new DenseMatrix(visibleCount, hiddenCount);
+ Matrix beta = new DenseMatrix(visibleCount, hiddenCount);
+
+ // now run the baum Welch training iteration
+ for (int it = 0; it < maxIterations; ++it) {
+ // fetch emission and transition matrix of current iteration
+ Vector initialProbabilities = iteration.getInitialProbabilities();
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+
+ // compute forward and backward factors
+ HmmAlgorithms.forwardAlgorithm(alpha, iteration, observedSequence, scaled);
+ HmmAlgorithms.backwardAlgorithm(beta, iteration, observedSequence, scaled);
+
+ if (scaled) {
+ logScaledBaumWelch(observedSequence, iteration, alpha, beta);
+ } else {
+ unscaledBaumWelch(observedSequence, iteration, alpha, beta);
+ }
+ // normalize transition/emission probabilities
+ // and normalize the probabilities
+ double isum = 0;
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double sum = 0;
+ // normalize the rows of the transition matrix
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ sum += transitionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfHiddenStates(); ++k) {
+ transitionMatrix
+ .setQuick(j, k, transitionMatrix.getQuick(j, k) / sum);
+ }
+ // normalize the rows of the emission matrix
+ sum = 0;
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ sum += emissionMatrix.getQuick(j, k);
+ }
+ for (int k = 0; k < iteration.getNrOfOutputStates(); ++k) {
+ emissionMatrix.setQuick(j, k, emissionMatrix.getQuick(j, k) / sum);
+ }
+ // normalization parameter for initial probabilities
+ isum += initialProbabilities.getQuick(j);
+ }
+ // normalize initial probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ initialProbabilities.setQuick(i, initialProbabilities.getQuick(i)
+ / isum);
+ }
+ // check for convergence
+ if (checkConvergence(lastIteration, iteration, epsilon)) {
+ break;
+ }
+ // overwrite the last iterated model by the new iteration
+ lastIteration.assign(iteration);
+ }
+ // we are done :)
+ return iteration;
+ }
+
+ private static void unscaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
+ Vector initialProbabilities = iteration.getInitialProbabilities();
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+ double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
+
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ initialProbabilities.setQuick(i, alpha.getQuick(0, i)
+ * beta.getQuick(0, i));
+ }
+
+ // recompute transition probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double temp = 0;
+ for (int t = 0; t < observedSequence.length - 1; ++t) {
+ temp += alpha.getQuick(t, i)
+ * emissionMatrix.getQuick(j, observedSequence[t + 1])
+ * beta.getQuick(t + 1, j);
+ }
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
+ * temp / modelLikelihood);
+ }
+ }
+ // recompute emission probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
+ double temp = 0;
+ for (int t = 0; t < observedSequence.length; ++t) {
+ // delta tensor
+ if (observedSequence[t] == j) {
+ temp += alpha.getQuick(t, i) * beta.getQuick(t, i);
+ }
+ }
+ emissionMatrix.setQuick(i, j, temp / modelLikelihood);
+ }
+ }
+ }
+
+ private static void logScaledBaumWelch(int[] observedSequence, HmmModel iteration, Matrix alpha, Matrix beta) {
+ Vector initialProbabilities = iteration.getInitialProbabilities();
+ Matrix emissionMatrix = iteration.getEmissionMatrix();
+ Matrix transitionMatrix = iteration.getTransitionMatrix();
+ double modelLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
+
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ initialProbabilities.setQuick(i, Math.exp(alpha.getQuick(0, i) + beta.getQuick(0, i)));
+ }
+
+ // recompute transition probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfHiddenStates(); ++j) {
+ double sum = Double.NEGATIVE_INFINITY; // log(0)
+ for (int t = 0; t < observedSequence.length - 1; ++t) {
+ double temp = alpha.getQuick(t, i)
+ + Math.log(emissionMatrix.getQuick(j, observedSequence[t + 1]))
+ + beta.getQuick(t + 1, j);
+ if (temp > Double.NEGATIVE_INFINITY) {
+ // handle 0-probabilities
+ sum = temp + Math.log1p(Math.exp(sum - temp));
+ }
+ }
+ transitionMatrix.setQuick(i, j, transitionMatrix.getQuick(i, j)
+ * Math.exp(sum - modelLikelihood));
+ }
+ }
+ // recompute emission probabilities
+ for (int i = 0; i < iteration.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < iteration.getNrOfOutputStates(); ++j) {
+ double sum = Double.NEGATIVE_INFINITY; // log(0)
+ for (int t = 0; t < observedSequence.length; ++t) {
+ // delta tensor
+ if (observedSequence[t] == j) {
+ double temp = alpha.getQuick(t, i) + beta.getQuick(t, i);
+ if (temp > Double.NEGATIVE_INFINITY) {
+ // handle 0-probabilities
+ sum = temp + Math.log1p(Math.exp(sum - temp));
+ }
+ }
+ }
+ emissionMatrix.setQuick(i, j, Math.exp(sum - modelLikelihood));
+ }
+ }
+ }
+
+ /**
+ * Check convergence of two HMM models by computing a simple distance between
+ * emission / transition matrices
+ *
+ * @param oldModel Old HMM Model
+ * @param newModel New HMM Model
+ * @param epsilon Convergence Factor
+ * @return true if training converged to a stable state.
+ */
+ private static boolean checkConvergence(HmmModel oldModel, HmmModel newModel,
+ double epsilon) {
+ // check convergence of transitionProbabilities
+ Matrix oldTransitionMatrix = oldModel.getTransitionMatrix();
+ Matrix newTransitionMatrix = newModel.getTransitionMatrix();
+ double diff = 0;
+ for (int i = 0; i < oldModel.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < oldModel.getNrOfHiddenStates(); ++j) {
+ double tmp = oldTransitionMatrix.getQuick(i, j)
+ - newTransitionMatrix.getQuick(i, j);
+ diff += tmp * tmp;
+ }
+ }
+ double norm = Math.sqrt(diff);
+ diff = 0;
+ // check convergence of emissionProbabilities
+ Matrix oldEmissionMatrix = oldModel.getEmissionMatrix();
+ Matrix newEmissionMatrix = newModel.getEmissionMatrix();
+ for (int i = 0; i < oldModel.getNrOfHiddenStates(); i++) {
+ for (int j = 0; j < oldModel.getNrOfOutputStates(); j++) {
+
+ double tmp = oldEmissionMatrix.getQuick(i, j)
+ - newEmissionMatrix.getQuick(i, j);
+ diff += tmp * tmp;
+ }
+ }
+ norm += Math.sqrt(diff);
+ // iteration has converged :)
+ return norm < epsilon;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
new file mode 100644
index 0000000..521be09
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/HmmUtils.java
@@ -0,0 +1,361 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * A collection of utilities for handling HMMModel objects.
+ */
+public final class HmmUtils {
+
+ /**
+ * No public constructor for utility classes.
+ */
+ private HmmUtils() {
+ // nothing to do here really.
+ }
+
+ /**
+ * Compute the cumulative transition probability matrix for the given HMM
+ * model. Matrix where each row i is the cumulative distribution of the
+ * transition probability distribution for hidden state i.
+ *
+ * @param model The HMM model for which the cumulative transition matrix should be
+ * computed
+ * @return The computed cumulative transition matrix.
+ */
+ public static Matrix getCumulativeTransitionMatrix(HmmModel model) {
+ // fetch the needed parameters from the model
+ int hiddenStates = model.getNrOfHiddenStates();
+ Matrix transitionMatrix = model.getTransitionMatrix();
+ // now compute the cumulative transition matrix
+ Matrix resultMatrix = new DenseMatrix(hiddenStates, hiddenStates);
+ for (int i = 0; i < hiddenStates; ++i) {
+ double sum = 0;
+ for (int j = 0; j < hiddenStates; ++j) {
+ sum += transitionMatrix.get(i, j);
+ resultMatrix.set(i, j, sum);
+ }
+ resultMatrix.set(i, hiddenStates - 1, 1.0);
+ // make sure the last
+ // state has always a
+ // cumulative
+ // probability of
+ // exactly 1.0
+ }
+ return resultMatrix;
+ }
+
+ /**
+ * Compute the cumulative output probability matrix for the given HMM model.
+ * Matrix where each row i is the cumulative distribution of the output
+ * probability distribution for hidden state i.
+ *
+ * @param model The HMM model for which the cumulative output matrix should be
+ * computed
+ * @return The computed cumulative output matrix.
+ */
+ public static Matrix getCumulativeOutputMatrix(HmmModel model) {
+ // fetch the needed parameters from the model
+ int hiddenStates = model.getNrOfHiddenStates();
+ int outputStates = model.getNrOfOutputStates();
+ Matrix outputMatrix = model.getEmissionMatrix();
+ // now compute the cumulative output matrix
+ Matrix resultMatrix = new DenseMatrix(hiddenStates, outputStates);
+ for (int i = 0; i < hiddenStates; ++i) {
+ double sum = 0;
+ for (int j = 0; j < outputStates; ++j) {
+ sum += outputMatrix.get(i, j);
+ resultMatrix.set(i, j, sum);
+ }
+ resultMatrix.set(i, outputStates - 1, 1.0);
+ // make sure the last
+ // output state has
+ // always a cumulative
+ // probability of 1.0
+ }
+ return resultMatrix;
+ }
+
+ /**
+ * Compute the cumulative distribution of the initial hidden state
+ * probabilities for the given HMM model.
+ *
+ * @param model The HMM model for which the cumulative initial state probabilities
+ * should be computed
+ * @return The computed cumulative initial state probability vector.
+ */
+ public static Vector getCumulativeInitialProbabilities(HmmModel model) {
+ // fetch the needed parameters from the model
+ int hiddenStates = model.getNrOfHiddenStates();
+ Vector initialProbabilities = model.getInitialProbabilities();
+ // now compute the cumulative output matrix
+ Vector resultVector = new DenseVector(initialProbabilities.size());
+ double sum = 0;
+ for (int i = 0; i < hiddenStates; ++i) {
+ sum += initialProbabilities.get(i);
+ resultVector.set(i, sum);
+ }
+ resultVector.set(hiddenStates - 1, 1.0); // make sure the last initial
+ // hidden state probability
+ // has always a cumulative
+ // probability of 1.0
+ return resultVector;
+ }
+
+ /**
+ * Validates an HMM model set
+ *
+ * @param model model to sanity check.
+ */
+ public static void validate(HmmModel model) {
+ if (model == null) {
+ return; // empty models are valid
+ }
+
+ /*
+ * The number of hidden states is positive.
+ */
+ Preconditions.checkArgument(model.getNrOfHiddenStates() > 0,
+ "Error: The number of hidden states has to be greater than 0");
+
+ /*
+ * The number of output states is positive.
+ */
+ Preconditions.checkArgument(model.getNrOfOutputStates() > 0,
+ "Error: The number of output states has to be greater than 0!");
+
+ /*
+ * The size of the vector of initial probabilities is equal to the number of
+ * the hidden states. Each initial probability is non-negative. The sum of
+ * initial probabilities is equal to 1.
+ */
+ Preconditions.checkArgument(model.getInitialProbabilities() != null
+ && model.getInitialProbabilities().size() == model.getNrOfHiddenStates(),
+ "Error: The vector of initial probabilities is not initialized!");
+
+ double sum = 0;
+ for (int i = 0; i < model.getInitialProbabilities().size(); i++) {
+ Preconditions.checkArgument(model.getInitialProbabilities().get(i) >= 0,
+ "Error: Initial probability of state %d is negative", i);
+ sum += model.getInitialProbabilities().get(i);
+ }
+ Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+ "Error: Initial probabilities do not add up to 1");
+ /*
+ * The row size of the output matrix is equal to the number of the hidden
+ * states. The column size is equal to the number of output states. Each
+ * probability of the matrix is non-negative. The sum of each row is equal
+ * to 1.
+ */
+ Preconditions.checkNotNull(model.getEmissionMatrix(), "Error: The output state matrix is not initialized!");
+ Preconditions.checkArgument(model.getEmissionMatrix().numRows() == model.getNrOfHiddenStates()
+ && model.getEmissionMatrix().numCols() == model.getNrOfOutputStates(),
+ "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfOutputStates");
+ for (int i = 0; i < model.getEmissionMatrix().numRows(); i++) {
+ sum = 0;
+ for (int j = 0; j < model.getEmissionMatrix().numCols(); j++) {
+ Preconditions.checkArgument(model.getEmissionMatrix().get(i, j) >= 0,
+ "The output state probability from hidden state " + i + " to output state " + j + " is negative");
+ sum += model.getEmissionMatrix().get(i, j);
+ }
+ Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+ "Error: The output state probabilities for hidden state %d don't add up to 1", i);
+ }
+
+ /*
+ * The size of both dimension of the transition matrix is equal to the
+ * number of the hidden states. Each probability of the matrix is
+ * non-negative. The sum of each row in transition matrix is equal to 1.
+ */
+ Preconditions.checkArgument(model.getTransitionMatrix() != null,
+ "Error: The hidden state matrix is not initialized!");
+ Preconditions.checkArgument(model.getTransitionMatrix().numRows() == model.getNrOfHiddenStates()
+ && model.getTransitionMatrix().numCols() == model.getNrOfHiddenStates(),
+ "Error: The output state matrix is not of the form nrOfHiddenStates x nrOfHiddenStates");
+ for (int i = 0; i < model.getTransitionMatrix().numRows(); i++) {
+ sum = 0;
+ for (int j = 0; j < model.getTransitionMatrix().numCols(); j++) {
+ Preconditions.checkArgument(model.getTransitionMatrix().get(i, j) >= 0,
+ "Error: The transition probability from hidden state %d to hidden state %d is negative", i, j);
+ sum += model.getTransitionMatrix().get(i, j);
+ }
+ Preconditions.checkArgument(Math.abs(sum - 1) <= 0.00001,
+ "Error: The transition probabilities for hidden state " + i + " don't add up to 1.");
+ }
+ }
+
+ /**
+ * Encodes a given collection of state names by the corresponding state IDs
+ * registered in a given model.
+ *
+ * @param model Model to provide the encoding for
+ * @param sequence Collection of state names
+ * @param observed If set, the sequence is encoded as a sequence of observed states,
+ * else it is encoded as sequence of hidden states
+ * @param defaultValue The default value in case a state is not known
+ * @return integer array containing the encoded state IDs
+ */
+ public static int[] encodeStateSequence(HmmModel model,
+ Collection<String> sequence, boolean observed, int defaultValue) {
+ int[] encoded = new int[sequence.size()];
+ Iterator<String> seqIter = sequence.iterator();
+ for (int i = 0; i < sequence.size(); ++i) {
+ String nextState = seqIter.next();
+ int nextID;
+ if (observed) {
+ nextID = model.getOutputStateID(nextState);
+ } else {
+ nextID = model.getHiddenStateID(nextState);
+ }
+ // if the ID is -1, use the default value
+ encoded[i] = nextID < 0 ? defaultValue : nextID;
+ }
+ return encoded;
+ }
+
+ /**
+ * Decodes a given collection of state IDs into the corresponding state names
+ * registered in a given model.
+ *
+ * @param model model to use for retrieving state names
+ * @param sequence int array of state IDs
+ * @param observed If set, the sequence is encoded as a sequence of observed states,
+ * else it is encoded as sequence of hidden states
+ * @param defaultValue The default value in case a state is not known
+ * @return list containing the decoded state names
+ */
+ public static List<String> decodeStateSequence(HmmModel model,
+ int[] sequence,
+ boolean observed,
+ String defaultValue) {
+ List<String> decoded = Lists.newArrayListWithCapacity(sequence.length);
+ for (int position : sequence) {
+ String nextState;
+ if (observed) {
+ nextState = model.getOutputStateName(position);
+ } else {
+ nextState = model.getHiddenStateName(position);
+ }
+ // if null was returned, use the default value
+ decoded.add(nextState == null ? defaultValue : nextState);
+ }
+ return decoded;
+ }
+
+ /**
+ * Function used to normalize the probabilities of a given HMM model
+ *
+ * @param model model to normalize
+ */
+ public static void normalizeModel(HmmModel model) {
+ Vector ip = model.getInitialProbabilities();
+ Matrix emission = model.getEmissionMatrix();
+ Matrix transition = model.getTransitionMatrix();
+ // check normalization for all probabilities
+ double isum = 0;
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ isum += ip.getQuick(i);
+ double sum = 0;
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+ sum += transition.getQuick(i, j);
+ }
+ if (sum != 1.0) {
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+ transition.setQuick(i, j, transition.getQuick(i, j) / sum);
+ }
+ }
+ sum = 0;
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+ sum += emission.getQuick(i, j);
+ }
+ if (sum != 1.0) {
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+ emission.setQuick(i, j, emission.getQuick(i, j) / sum);
+ }
+ }
+ }
+ if (isum != 1.0) {
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ ip.setQuick(i, ip.getQuick(i) / isum);
+ }
+ }
+ }
+
+ /**
+ * Method to reduce the size of an HMMmodel by converting the models
+ * DenseMatrix/DenseVectors to sparse implementations and setting every value
+ * < threshold to 0
+ *
+ * @param model model to truncate
+ * @param threshold minimum value a model entry must have to be retained.
+ * @return Truncated model
+ */
+ public static HmmModel truncateModel(HmmModel model, double threshold) {
+ Vector ip = model.getInitialProbabilities();
+ Matrix em = model.getEmissionMatrix();
+ Matrix tr = model.getTransitionMatrix();
+ // allocate the sparse data structures
+ RandomAccessSparseVector sparseIp = new RandomAccessSparseVector(model
+ .getNrOfHiddenStates());
+ SparseMatrix sparseEm = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfOutputStates());
+ SparseMatrix sparseTr = new SparseMatrix(model.getNrOfHiddenStates(), model.getNrOfHiddenStates());
+ // now transfer the values
+ for (int i = 0; i < model.getNrOfHiddenStates(); ++i) {
+ double value = ip.getQuick(i);
+ if (value > threshold) {
+ sparseIp.setQuick(i, value);
+ }
+ for (int j = 0; j < model.getNrOfHiddenStates(); ++j) {
+ value = tr.getQuick(i, j);
+ if (value > threshold) {
+ sparseTr.setQuick(i, j, value);
+ }
+ }
+
+ for (int j = 0; j < model.getNrOfOutputStates(); ++j) {
+ value = em.getQuick(i, j);
+ if (value > threshold) {
+ sparseEm.setQuick(i, j, value);
+ }
+ }
+ }
+ // create a new model
+ HmmModel sparseModel = new HmmModel(sparseTr, sparseEm, sparseIp);
+ // normalize the model
+ normalizeModel(sparseModel);
+ // register the names
+ sparseModel.registerHiddenStateNames(model.getHiddenStateNames());
+ sparseModel.registerOutputStateNames(model.getOutputStateNames());
+ // and return
+ return sparseModel;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
new file mode 100644
index 0000000..d0ae9c2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/LossyHmmSerializer.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Utils for serializing Writable parts of HmmModel (that means without hidden state names and so on)
+ */
+final class LossyHmmSerializer {
+
+ private LossyHmmSerializer() {
+ }
+
+ static void serialize(HmmModel model, DataOutput output) throws IOException {
+ MatrixWritable matrix = new MatrixWritable(model.getEmissionMatrix());
+ matrix.write(output);
+ matrix.set(model.getTransitionMatrix());
+ matrix.write(output);
+
+ VectorWritable vector = new VectorWritable(model.getInitialProbabilities());
+ vector.write(output);
+ }
+
+ static HmmModel deserialize(DataInput input) throws IOException {
+ MatrixWritable matrix = new MatrixWritable();
+ matrix.readFields(input);
+ Matrix emissionMatrix = matrix.get();
+
+ matrix.readFields(input);
+ Matrix transitionMatrix = matrix.get();
+
+ VectorWritable vector = new VectorWritable();
+ vector.readFields(input);
+ Vector initialProbabilities = vector.get();
+
+ return new HmmModel(transitionMatrix, emissionMatrix, initialProbabilities);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
new file mode 100644
index 0000000..cd2ced1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/RandomSequenceGenerator.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+
+/**
+ * Command-line tool for generating random sequences by given HMM
+ */
+public final class RandomSequenceGenerator {
+
+ private RandomSequenceGenerator() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ Option outputOption = optionBuilder.withLongName("output").
+ withDescription("Output file with sequence of observed states").
+ withShortName("o").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(false).create();
+
+ Option modelOption = optionBuilder.withLongName("model").
+ withDescription("Path to serialized HMM model").
+ withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(true).create();
+
+ Option lengthOption = optionBuilder.withLongName("length").
+ withDescription("Length of generated sequence").
+ withShortName("l").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Group optionGroup = new GroupBuilder().
+ withOption(outputOption).withOption(modelOption).withOption(lengthOption).
+ withName("Options").create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(optionGroup);
+ CommandLine commandLine = parser.parse(args);
+
+ String output = (String) commandLine.getValue(outputOption);
+
+ String modelPath = (String) commandLine.getValue(modelOption);
+
+ int length = Integer.parseInt((String) commandLine.getValue(lengthOption));
+
+ //reading serialized HMM
+ DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath));
+ HmmModel model;
+ try {
+ model = LossyHmmSerializer.deserialize(modelStream);
+ } finally {
+ Closeables.close(modelStream, true);
+ }
+
+ //generating observations
+ int[] observations = HmmEvaluator.predict(model, length, System.currentTimeMillis());
+
+ //writing output
+ PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true);
+ try {
+ for (int observation : observations) {
+ writer.print(observation);
+ writer.print(' ');
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(optionGroup);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
new file mode 100644
index 0000000..fb64385
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ViterbiEvaluator.java
@@ -0,0 +1,127 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataInputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.PrintWriter;
+import java.util.List;
+import java.util.Scanner;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+/**
+ * Command-line tool for Viterbi evaluating
+ */
+public final class ViterbiEvaluator {
+
+ private ViterbiEvaluator() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ Option inputOption = DefaultOptionCreator.inputOption().create();
+
+ Option outputOption = DefaultOptionCreator.outputOption().create();
+
+ Option modelOption = optionBuilder.withLongName("model").
+ withDescription("Path to serialized HMM model").
+ withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("path").create()).withRequired(true).create();
+
+ Option likelihoodOption = optionBuilder.withLongName("likelihood").
+ withDescription("Compute likelihood of observed sequence").
+ withShortName("l").withRequired(false).create();
+
+ Group optionGroup = new GroupBuilder().withOption(inputOption).
+ withOption(outputOption).withOption(modelOption).withOption(likelihoodOption).
+ withName("Options").create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(optionGroup);
+ CommandLine commandLine = parser.parse(args);
+
+ String input = (String) commandLine.getValue(inputOption);
+ String output = (String) commandLine.getValue(outputOption);
+
+ String modelPath = (String) commandLine.getValue(modelOption);
+
+ boolean computeLikelihood = commandLine.hasOption(likelihoodOption);
+
+ //reading serialized HMM
+ DataInputStream modelStream = new DataInputStream(new FileInputStream(modelPath));
+ HmmModel model;
+ try {
+ model = LossyHmmSerializer.deserialize(modelStream);
+ } finally {
+ Closeables.close(modelStream, true);
+ }
+
+ //reading observations
+ List<Integer> observations = Lists.newArrayList();
+ try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
+ while (scanner.hasNextInt()) {
+ observations.add(scanner.nextInt());
+ }
+ }
+
+ int[] observationsArray = new int[observations.size()];
+ for (int i = 0; i < observations.size(); ++i) {
+ observationsArray[i] = observations.get(i);
+ }
+
+ //decoding
+ int[] hiddenStates = HmmEvaluator.decode(model, observationsArray, true);
+
+ //writing output
+ PrintWriter writer = new PrintWriter(new OutputStreamWriter(new FileOutputStream(output), Charsets.UTF_8), true);
+ try {
+ for (int hiddenState : hiddenStates) {
+ writer.print(hiddenState);
+ writer.print(' ');
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ if (computeLikelihood) {
+ System.out.println("Likelihood: " + HmmEvaluator.modelLikelihood(model, observationsArray, true));
+ }
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(optionGroup);
+ }
+ }
+}
[47/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java
new file mode 100644
index 0000000..643b2c3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJob.java
@@ -0,0 +1,337 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.OutputFormat;
+import org.apache.hadoop.mapreduce.lib.input.MultipleInputs;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.EntityEntityWritable;
+import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable;
+import org.apache.mahout.cf.taste.hadoop.preparation.PreparePreferenceMatrixJob;
+import org.apache.mahout.cf.taste.hadoop.similarity.item.ItemSimilarityJob;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.RowSimilarityJob;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * <p>Runs a completely distributed recommender job as a series of mapreduces.</p>
+ * <p/>
+ * <p>Preferences in the input file should look like {@code userID, itemID[, preferencevalue]}</p>
+ * <p/>
+ * <p>
+ * Preference value is optional to accommodate applications that have no notion of a preference value (that is, the user
+ * simply expresses a preference for an item, but no degree of preference).
+ * </p>
+ * <p/>
+ * <p>
+ * The preference value is assumed to be parseable as a {@code double}. The user IDs and item IDs are
+ * parsed as {@code long}s.
+ * </p>
+ * <p/>
+ * <p>Command line arguments specific to this class are:</p>
+ * <p/>
+ * <ol>
+ * <li>--input(path): Directory containing one or more text files with the preference data</li>
+ * <li>--output(path): output path where recommender output should go</li>
+ * <li>--similarityClassname (classname): Name of vector similarity class to instantiate or a predefined similarity
+ * from {@link org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure}</li>
+ * <li>--usersFile (path): only compute recommendations for user IDs contained in this file (optional)</li>
+ * <li>--itemsFile (path): only include item IDs from this file in the recommendations (optional)</li>
+ * <li>--filterFile (path): file containing comma-separated userID,itemID pairs. Used to exclude the item from the
+ * recommendations for that user (optional)</li>
+ * <li>--numRecommendations (integer): Number of recommendations to compute per user (10)</li>
+ * <li>--booleanData (boolean): Treat input data as having no pref values (false)</li>
+ * <li>--maxPrefsPerUser (integer): Maximum number of preferences considered per user in final
+ * recommendation phase (10)</li>
+ * <li>--maxSimilaritiesPerItem (integer): Maximum number of similarities considered per item (100)</li>
+ * <li>--minPrefsPerUser (integer): ignore users with less preferences than this in the similarity computation (1)</li>
+ * <li>--maxPrefsPerUserInItemSimilarity (integer): max number of preferences to consider per user in
+ * the item similarity computation phase,
+ * users with more preferences will be sampled down (1000)</li>
+ * <li>--threshold (double): discard item pairs with a similarity value below this</li>
+ * </ol>
+ * <p/>
+ * <p>General command line options are documented in {@link AbstractJob}.</p>
+ * <p/>
+ * <p>Note that because of how Hadoop parses arguments, all "-D" arguments must appear before all other
+ * arguments.</p>
+ */
+public final class RecommenderJob extends AbstractJob {
+
+ public static final String BOOLEAN_DATA = "booleanData";
+ public static final String DEFAULT_PREPARE_PATH = "preparePreferenceMatrix";
+
+ private static final int DEFAULT_MAX_SIMILARITIES_PER_ITEM = 100;
+ private static final int DEFAULT_MAX_PREFS = 500;
+ private static final int DEFAULT_MIN_PREFS_PER_USER = 1;
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("numRecommendations", "n", "Number of recommendations per user",
+ String.valueOf(AggregateAndRecommendReducer.DEFAULT_NUM_RECOMMENDATIONS));
+ addOption("usersFile", null, "File of users to recommend for", null);
+ addOption("itemsFile", null, "File of items to recommend for", null);
+ addOption("filterFile", "f", "File containing comma-separated userID,itemID pairs. Used to exclude the item from "
+ + "the recommendations for that user (optional)", null);
+ addOption("userItemFile", "uif", "File containing comma-separated userID,itemID pairs (optional). "
+ + "Used to include only these items into recommendations. "
+ + "Cannot be used together with usersFile or itemsFile", null);
+ addOption("booleanData", "b", "Treat input as without pref values", Boolean.FALSE.toString());
+ addOption("maxPrefsPerUser", "mxp",
+ "Maximum number of preferences considered per user in final recommendation phase",
+ String.valueOf(UserVectorSplitterMapper.DEFAULT_MAX_PREFS_PER_USER_CONSIDERED));
+ addOption("minPrefsPerUser", "mp", "ignore users with less preferences than this in the similarity computation "
+ + "(default: " + DEFAULT_MIN_PREFS_PER_USER + ')', String.valueOf(DEFAULT_MIN_PREFS_PER_USER));
+ addOption("maxSimilaritiesPerItem", "m", "Maximum number of similarities considered per item ",
+ String.valueOf(DEFAULT_MAX_SIMILARITIES_PER_ITEM));
+ addOption("maxPrefsInItemSimilarity", "mpiis", "max number of preferences to consider per user or item in the "
+ + "item similarity computation phase, users or items with more preferences will be sampled down (default: "
+ + DEFAULT_MAX_PREFS + ')', String.valueOf(DEFAULT_MAX_PREFS));
+ addOption("similarityClassname", "s", "Name of distributed similarity measures class to instantiate, "
+ + "alternatively use one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')', true);
+ addOption("threshold", "tr", "discard item pairs with a similarity value below this", false);
+ addOption("outputPathForSimilarityMatrix", "opfsm", "write the item similarity matrix to this path (optional)",
+ false);
+ addOption("randomSeed", null, "use this seed for sampling", false);
+ addFlag("sequencefileOutput", null, "write the output into a SequenceFile instead of a text file");
+
+ Map<String, List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Path outputPath = getOutputPath();
+ int numRecommendations = Integer.parseInt(getOption("numRecommendations"));
+ String usersFile = getOption("usersFile");
+ String itemsFile = getOption("itemsFile");
+ String filterFile = getOption("filterFile");
+ String userItemFile = getOption("userItemFile");
+ boolean booleanData = Boolean.valueOf(getOption("booleanData"));
+ int maxPrefsPerUser = Integer.parseInt(getOption("maxPrefsPerUser"));
+ int minPrefsPerUser = Integer.parseInt(getOption("minPrefsPerUser"));
+ int maxPrefsInItemSimilarity = Integer.parseInt(getOption("maxPrefsInItemSimilarity"));
+ int maxSimilaritiesPerItem = Integer.parseInt(getOption("maxSimilaritiesPerItem"));
+ String similarityClassname = getOption("similarityClassname");
+ double threshold = hasOption("threshold")
+ ? Double.parseDouble(getOption("threshold")) : RowSimilarityJob.NO_THRESHOLD;
+ long randomSeed = hasOption("randomSeed")
+ ? Long.parseLong(getOption("randomSeed")) : RowSimilarityJob.NO_FIXED_RANDOM_SEED;
+
+
+ Path prepPath = getTempPath(DEFAULT_PREPARE_PATH);
+ Path similarityMatrixPath = getTempPath("similarityMatrix");
+ Path explicitFilterPath = getTempPath("explicitFilterPath");
+ Path partialMultiplyPath = getTempPath("partialMultiply");
+
+ AtomicInteger currentPhase = new AtomicInteger();
+
+ int numberOfUsers = -1;
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ ToolRunner.run(getConf(), new PreparePreferenceMatrixJob(), new String[]{
+ "--input", getInputPath().toString(),
+ "--output", prepPath.toString(),
+ "--minPrefsPerUser", String.valueOf(minPrefsPerUser),
+ "--booleanData", String.valueOf(booleanData),
+ "--tempDir", getTempPath().toString(),
+ });
+
+ numberOfUsers = HadoopUtil.readInt(new Path(prepPath, PreparePreferenceMatrixJob.NUM_USERS), getConf());
+ }
+
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+
+ /* special behavior if phase 1 is skipped */
+ if (numberOfUsers == -1) {
+ numberOfUsers = (int) HadoopUtil.countRecords(new Path(prepPath, PreparePreferenceMatrixJob.USER_VECTORS),
+ PathType.LIST, null, getConf());
+ }
+
+ //calculate the co-occurrence matrix
+ ToolRunner.run(getConf(), new RowSimilarityJob(), new String[]{
+ "--input", new Path(prepPath, PreparePreferenceMatrixJob.RATING_MATRIX).toString(),
+ "--output", similarityMatrixPath.toString(),
+ "--numberOfColumns", String.valueOf(numberOfUsers),
+ "--similarityClassname", similarityClassname,
+ "--maxObservationsPerRow", String.valueOf(maxPrefsInItemSimilarity),
+ "--maxObservationsPerColumn", String.valueOf(maxPrefsInItemSimilarity),
+ "--maxSimilaritiesPerRow", String.valueOf(maxSimilaritiesPerItem),
+ "--excludeSelfSimilarity", String.valueOf(Boolean.TRUE),
+ "--threshold", String.valueOf(threshold),
+ "--randomSeed", String.valueOf(randomSeed),
+ "--tempDir", getTempPath().toString(),
+ });
+
+ // write out the similarity matrix if the user specified that behavior
+ if (hasOption("outputPathForSimilarityMatrix")) {
+ Path outputPathForSimilarityMatrix = new Path(getOption("outputPathForSimilarityMatrix"));
+
+ Job outputSimilarityMatrix = prepareJob(similarityMatrixPath, outputPathForSimilarityMatrix,
+ SequenceFileInputFormat.class, ItemSimilarityJob.MostSimilarItemPairsMapper.class,
+ EntityEntityWritable.class, DoubleWritable.class, ItemSimilarityJob.MostSimilarItemPairsReducer.class,
+ EntityEntityWritable.class, DoubleWritable.class, TextOutputFormat.class);
+
+ Configuration mostSimilarItemsConf = outputSimilarityMatrix.getConfiguration();
+ mostSimilarItemsConf.set(ItemSimilarityJob.ITEM_ID_INDEX_PATH_STR,
+ new Path(prepPath, PreparePreferenceMatrixJob.ITEMID_INDEX).toString());
+ mostSimilarItemsConf.setInt(ItemSimilarityJob.MAX_SIMILARITIES_PER_ITEM, maxSimilaritiesPerItem);
+ outputSimilarityMatrix.waitForCompletion(true);
+ }
+ }
+
+ //start the multiplication of the co-occurrence matrix by the user vectors
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ Job partialMultiply = new Job(getConf(), "partialMultiply");
+ Configuration partialMultiplyConf = partialMultiply.getConfiguration();
+
+ MultipleInputs.addInputPath(partialMultiply, similarityMatrixPath, SequenceFileInputFormat.class,
+ SimilarityMatrixRowWrapperMapper.class);
+ MultipleInputs.addInputPath(partialMultiply, new Path(prepPath, PreparePreferenceMatrixJob.USER_VECTORS),
+ SequenceFileInputFormat.class, UserVectorSplitterMapper.class);
+ partialMultiply.setJarByClass(ToVectorAndPrefReducer.class);
+ partialMultiply.setMapOutputKeyClass(VarIntWritable.class);
+ partialMultiply.setMapOutputValueClass(VectorOrPrefWritable.class);
+ partialMultiply.setReducerClass(ToVectorAndPrefReducer.class);
+ partialMultiply.setOutputFormatClass(SequenceFileOutputFormat.class);
+ partialMultiply.setOutputKeyClass(VarIntWritable.class);
+ partialMultiply.setOutputValueClass(VectorAndPrefsWritable.class);
+ partialMultiplyConf.setBoolean("mapred.compress.map.output", true);
+ partialMultiplyConf.set("mapred.output.dir", partialMultiplyPath.toString());
+
+ if (usersFile != null) {
+ partialMultiplyConf.set(UserVectorSplitterMapper.USERS_FILE, usersFile);
+ }
+
+ if (userItemFile != null) {
+ partialMultiplyConf.set(IDReader.USER_ITEM_FILE, userItemFile);
+ }
+
+ partialMultiplyConf.setInt(UserVectorSplitterMapper.MAX_PREFS_PER_USER_CONSIDERED, maxPrefsPerUser);
+
+ boolean succeeded = partialMultiply.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ //filter out any users we don't care about
+ /* convert the user/item pairs to filter if a filterfile has been specified */
+ if (filterFile != null) {
+ Job itemFiltering = prepareJob(new Path(filterFile), explicitFilterPath, TextInputFormat.class,
+ ItemFilterMapper.class, VarLongWritable.class, VarLongWritable.class,
+ ItemFilterAsVectorAndPrefsReducer.class, VarIntWritable.class, VectorAndPrefsWritable.class,
+ SequenceFileOutputFormat.class);
+ boolean succeeded = itemFiltering.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ String aggregateAndRecommendInput = partialMultiplyPath.toString();
+ if (filterFile != null) {
+ aggregateAndRecommendInput += "," + explicitFilterPath;
+ }
+
+ Class<? extends OutputFormat> outputFormat = parsedArgs.containsKey("--sequencefileOutput")
+ ? SequenceFileOutputFormat.class : TextOutputFormat.class;
+
+ //extract out the recommendations
+ Job aggregateAndRecommend = prepareJob(
+ new Path(aggregateAndRecommendInput), outputPath, SequenceFileInputFormat.class,
+ PartialMultiplyMapper.class, VarLongWritable.class, PrefAndSimilarityColumnWritable.class,
+ AggregateAndRecommendReducer.class, VarLongWritable.class, RecommendedItemsWritable.class,
+ outputFormat);
+ Configuration aggregateAndRecommendConf = aggregateAndRecommend.getConfiguration();
+ if (itemsFile != null) {
+ aggregateAndRecommendConf.set(AggregateAndRecommendReducer.ITEMS_FILE, itemsFile);
+ }
+
+ if (userItemFile != null) {
+ aggregateAndRecommendConf.set(IDReader.USER_ITEM_FILE, userItemFile);
+ }
+
+ if (filterFile != null) {
+ setS3SafeCombinedInputPath(aggregateAndRecommend, getTempPath(), partialMultiplyPath, explicitFilterPath);
+ }
+ setIOSort(aggregateAndRecommend);
+ aggregateAndRecommendConf.set(AggregateAndRecommendReducer.ITEMID_INDEX_PATH,
+ new Path(prepPath, PreparePreferenceMatrixJob.ITEMID_INDEX).toString());
+ aggregateAndRecommendConf.setInt(AggregateAndRecommendReducer.NUM_RECOMMENDATIONS, numRecommendations);
+ aggregateAndRecommendConf.setBoolean(BOOLEAN_DATA, booleanData);
+ boolean succeeded = aggregateAndRecommend.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ return 0;
+ }
+
+ private static void setIOSort(JobContext job) {
+ Configuration conf = job.getConfiguration();
+ conf.setInt("io.sort.factor", 100);
+ String javaOpts = conf.get("mapred.map.child.java.opts"); // new arg name
+ if (javaOpts == null) {
+ javaOpts = conf.get("mapred.child.java.opts"); // old arg name
+ }
+ int assumedHeapSize = 512;
+ if (javaOpts != null) {
+ Matcher m = Pattern.compile("-Xmx([0-9]+)([mMgG])").matcher(javaOpts);
+ if (m.find()) {
+ assumedHeapSize = Integer.parseInt(m.group(1));
+ String megabyteOrGigabyte = m.group(2);
+ if ("g".equalsIgnoreCase(megabyteOrGigabyte)) {
+ assumedHeapSize *= 1024;
+ }
+ }
+ }
+ // Cap this at 1024MB now; see https://issues.apache.org/jira/browse/MAPREDUCE-2308
+ conf.setInt("io.sort.mb", Math.min(assumedHeapSize / 2, 1024));
+ // For some reason the Merger doesn't report status for a long time; increase
+ // timeout when running these jobs
+ conf.setInt("mapred.task.timeout", 60 * 60 * 1000);
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new RecommenderJob(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/SimilarityMatrixRowWrapperMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/SimilarityMatrixRowWrapperMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/SimilarityMatrixRowWrapperMapper.java
new file mode 100644
index 0000000..8ae8215
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/SimilarityMatrixRowWrapperMapper.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * maps a row of the similarity matrix to a {@link VectorOrPrefWritable}
+ *
+ * actually a column from that matrix has to be used but as the similarity matrix is symmetric,
+ * we can use a row instead of having to transpose it
+ */
+public final class SimilarityMatrixRowWrapperMapper extends
+ Mapper<IntWritable,VectorWritable,VarIntWritable,VectorOrPrefWritable> {
+
+ private final VarIntWritable index = new VarIntWritable();
+ private final VectorOrPrefWritable vectorOrPref = new VectorOrPrefWritable();
+
+ @Override
+ protected void map(IntWritable key,
+ VectorWritable value,
+ Context context) throws IOException, InterruptedException {
+ Vector similarityMatrixRow = value.get();
+ /* remove self similarity */
+ similarityMatrixRow.set(key.get(), Double.NaN);
+
+ index.set(key.get());
+ vectorOrPref.set(similarityMatrixRow);
+
+ context.write(index, vectorOrPref);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducer.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducer.java
new file mode 100644
index 0000000..e6e47fd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducer.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.IOException;
+
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.cf.taste.hadoop.EntityPrefWritable;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * <h1>Input</h1>
+ *
+ * <p>
+ * Takes user IDs as {@link VarLongWritable} mapped to all associated item IDs and preference values, as
+ * {@link EntityPrefWritable}s.
+ * </p>
+ *
+ * <h1>Output</h1>
+ *
+ * <p>
+ * The same user ID mapped to a {@link RandomAccessSparseVector} representation of the same item IDs and
+ * preference values. Item IDs are used as vector indexes; they are hashed into ints to work as indexes with
+ * {@link TasteHadoopUtils#idToIndex(long)}. The mapping is remembered for later with a combination of
+ * {@link ItemIDIndexMapper} and {@link ItemIDIndexReducer}.
+ * </p>
+ */
+public final class ToUserVectorsReducer extends
+ Reducer<VarLongWritable,VarLongWritable,VarLongWritable,VectorWritable> {
+
+ public static final String MIN_PREFERENCES_PER_USER = ToUserVectorsReducer.class.getName()
+ + ".minPreferencesPerUser";
+
+ private int minPreferences;
+
+ public enum Counters { USERS }
+
+ private final VectorWritable userVectorWritable = new VectorWritable();
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ minPreferences = ctx.getConfiguration().getInt(MIN_PREFERENCES_PER_USER, 1);
+ }
+
+ @Override
+ protected void reduce(VarLongWritable userID,
+ Iterable<VarLongWritable> itemPrefs,
+ Context context) throws IOException, InterruptedException {
+ Vector userVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ for (VarLongWritable itemPref : itemPrefs) {
+ int index = TasteHadoopUtils.idToIndex(itemPref.get());
+ float value = itemPref instanceof EntityPrefWritable ? ((EntityPrefWritable) itemPref).getPrefValue() : 1.0f;
+ userVector.set(index, value);
+ }
+
+ if (userVector.getNumNondefaultElements() >= minPreferences) {
+ userVectorWritable.set(userVector);
+ userVectorWritable.setWritesLaxPrecision(true);
+ context.getCounter(Counters.USERS).increment(1);
+ context.write(userID, userVectorWritable);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToVectorAndPrefReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToVectorAndPrefReducer.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToVectorAndPrefReducer.java
new file mode 100644
index 0000000..74d30cb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ToVectorAndPrefReducer.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.Vector;
+
+public final class ToVectorAndPrefReducer extends
+ Reducer<VarIntWritable,VectorOrPrefWritable,VarIntWritable,VectorAndPrefsWritable> {
+
+ private final VectorAndPrefsWritable vectorAndPrefs = new VectorAndPrefsWritable();
+
+ @Override
+ protected void reduce(VarIntWritable key,
+ Iterable<VectorOrPrefWritable> values,
+ Context context) throws IOException, InterruptedException {
+
+ List<Long> userIDs = Lists.newArrayList();
+ List<Float> prefValues = Lists.newArrayList();
+ Vector similarityMatrixColumn = null;
+ for (VectorOrPrefWritable value : values) {
+ if (value.getVector() == null) {
+ // Then this is a user-pref value
+ userIDs.add(value.getUserID());
+ prefValues.add(value.getValue());
+ } else {
+ // Then this is the column vector
+ if (similarityMatrixColumn != null) {
+ throw new IllegalStateException("Found two similarity-matrix columns for item index " + key.get());
+ }
+ similarityMatrixColumn = value.getVector();
+ }
+ }
+
+ if (similarityMatrixColumn == null) {
+ return;
+ }
+
+ vectorAndPrefs.set(similarityMatrixColumn, userIDs, prefValues);
+ context.write(key, vectorAndPrefs);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorSplitterMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorSplitterMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorSplitterMapper.java
new file mode 100644
index 0000000..2290d06
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/UserVectorSplitterMapper.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class UserVectorSplitterMapper extends
+ Mapper<VarLongWritable,VectorWritable, VarIntWritable,VectorOrPrefWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(UserVectorSplitterMapper.class);
+
+ static final String USERS_FILE = "usersFile";
+ static final String MAX_PREFS_PER_USER_CONSIDERED = "maxPrefsPerUserConsidered";
+ static final int DEFAULT_MAX_PREFS_PER_USER_CONSIDERED = 10;
+
+ private int maxPrefsPerUserConsidered;
+ private FastIDSet usersToRecommendFor;
+
+ private final VarIntWritable itemIndexWritable = new VarIntWritable();
+ private final VectorOrPrefWritable vectorOrPref = new VectorOrPrefWritable();
+
+ @Override
+ protected void setup(Context context) throws IOException {
+ Configuration jobConf = context.getConfiguration();
+ maxPrefsPerUserConsidered = jobConf.getInt(MAX_PREFS_PER_USER_CONSIDERED, DEFAULT_MAX_PREFS_PER_USER_CONSIDERED);
+
+ IDReader idReader = new IDReader (jobConf);
+ idReader.readIDs();
+ usersToRecommendFor = idReader.getUserIds();
+ }
+
+ @Override
+ protected void map(VarLongWritable key,
+ VectorWritable value,
+ Context context) throws IOException, InterruptedException {
+ long userID = key.get();
+
+ log.info("UserID = {}", userID);
+
+ if (usersToRecommendFor != null && !usersToRecommendFor.contains(userID)) {
+ return;
+ }
+ Vector userVector = maybePruneUserVector(value.get());
+
+ for (Element e : userVector.nonZeroes()) {
+ itemIndexWritable.set(e.index());
+ vectorOrPref.set(userID, (float) e.get());
+ context.write(itemIndexWritable, vectorOrPref);
+ }
+ }
+
+ private Vector maybePruneUserVector(Vector userVector) {
+ if (userVector.getNumNondefaultElements() <= maxPrefsPerUserConsidered) {
+ return userVector;
+ }
+
+ float smallestLargeValue = findSmallestLargeValue(userVector);
+
+ // "Blank out" small-sized prefs to reduce the amount of partial products
+ // generated later. They're not zeroed, but NaN-ed, so they come through
+ // and can be used to exclude these items from prefs.
+ for (Element e : userVector.nonZeroes()) {
+ float absValue = Math.abs((float) e.get());
+ if (absValue < smallestLargeValue) {
+ e.set(Float.NaN);
+ }
+ }
+
+ return userVector;
+ }
+
+ private float findSmallestLargeValue(Vector userVector) {
+
+ PriorityQueue<Float> topPrefValues = new PriorityQueue<Float>(maxPrefsPerUserConsidered) {
+ @Override
+ protected boolean lessThan(Float f1, Float f2) {
+ return f1 < f2;
+ }
+ };
+
+ for (Element e : userVector.nonZeroes()) {
+ float absValue = Math.abs((float) e.get());
+ topPrefValues.insertWithOverflow(absValue);
+ }
+ return topPrefValues.top();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorAndPrefsWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorAndPrefsWritable.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorAndPrefsWritable.java
new file mode 100644
index 0000000..495a920
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorAndPrefsWritable.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public final class VectorAndPrefsWritable implements Writable {
+
+ private Vector vector;
+ private List<Long> userIDs;
+ private List<Float> values;
+
+ public VectorAndPrefsWritable() {
+ }
+
+ public VectorAndPrefsWritable(Vector vector, List<Long> userIDs, List<Float> values) {
+ set(vector, userIDs, values);
+ }
+
+ public void set(Vector vector, List<Long> userIDs, List<Float> values) {
+ this.vector = vector;
+ this.userIDs = userIDs;
+ this.values = values;
+ }
+
+ public Vector getVector() {
+ return vector;
+ }
+
+ public List<Long> getUserIDs() {
+ return userIDs;
+ }
+
+ public List<Float> getValues() {
+ return values;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ VectorWritable vw = new VectorWritable(vector);
+ vw.setWritesLaxPrecision(true);
+ vw.write(out);
+ Varint.writeUnsignedVarInt(userIDs.size(), out);
+ for (int i = 0; i < userIDs.size(); i++) {
+ Varint.writeSignedVarLong(userIDs.get(i), out);
+ out.writeFloat(values.get(i));
+ }
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ VectorWritable writable = new VectorWritable();
+ writable.readFields(in);
+ vector = writable.get();
+ int size = Varint.readUnsignedVarInt(in);
+ userIDs = Lists.newArrayListWithCapacity(size);
+ values = Lists.newArrayListWithCapacity(size);
+ for (int i = 0; i < size; i++) {
+ userIDs.add(Varint.readSignedVarLong(in));
+ values.add(in.readFloat());
+ }
+ }
+
+ @Override
+ public String toString() {
+ return vector + "\t" + userIDs + '\t' + values;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorOrPrefWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorOrPrefWritable.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorOrPrefWritable.java
new file mode 100644
index 0000000..515d7ea
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/VectorOrPrefWritable.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public final class VectorOrPrefWritable implements Writable {
+
+ private Vector vector;
+ private long userID;
+ private float value;
+
+ public VectorOrPrefWritable() {
+ }
+
+ public VectorOrPrefWritable(Vector vector) {
+ this.vector = vector;
+ }
+
+ public VectorOrPrefWritable(long userID, float value) {
+ this.userID = userID;
+ this.value = value;
+ }
+
+ public Vector getVector() {
+ return vector;
+ }
+
+ public long getUserID() {
+ return userID;
+ }
+
+ public float getValue() {
+ return value;
+ }
+
+ void set(Vector vector) {
+ this.vector = vector;
+ this.userID = Long.MIN_VALUE;
+ this.value = Float.NaN;
+ }
+
+ public void set(long userID, float value) {
+ this.vector = null;
+ this.userID = userID;
+ this.value = value;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ if (vector == null) {
+ out.writeBoolean(false);
+ Varint.writeSignedVarLong(userID, out);
+ out.writeFloat(value);
+ } else {
+ out.writeBoolean(true);
+ VectorWritable vw = new VectorWritable(vector);
+ vw.setWritesLaxPrecision(true);
+ vw.write(out);
+ }
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ boolean hasVector = in.readBoolean();
+ if (hasVector) {
+ VectorWritable writable = new VectorWritable();
+ writable.readFields(in);
+ set(writable.get());
+ } else {
+ long theUserID = Varint.readSignedVarLong(in);
+ float theValue = in.readFloat();
+ set(theUserID, theValue);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return vector == null ? userID + ":" + value : vector.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/PreparePreferenceMatrixJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/PreparePreferenceMatrixJob.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/PreparePreferenceMatrixJob.java
new file mode 100644
index 0000000..c64ee38
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/PreparePreferenceMatrixJob.java
@@ -0,0 +1,115 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.preparation;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.EntityPrefWritable;
+import org.apache.mahout.cf.taste.hadoop.ToEntityPrefsMapper;
+import org.apache.mahout.cf.taste.hadoop.ToItemPrefsMapper;
+import org.apache.mahout.cf.taste.hadoop.item.ItemIDIndexMapper;
+import org.apache.mahout.cf.taste.hadoop.item.ItemIDIndexReducer;
+import org.apache.mahout.cf.taste.hadoop.item.RecommenderJob;
+import org.apache.mahout.cf.taste.hadoop.item.ToUserVectorsReducer;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.VectorWritable;
+
+import java.util.List;
+import java.util.Map;
+
+public class PreparePreferenceMatrixJob extends AbstractJob {
+
+ public static final String NUM_USERS = "numUsers.bin";
+ public static final String ITEMID_INDEX = "itemIDIndex";
+ public static final String USER_VECTORS = "userVectors";
+ public static final String RATING_MATRIX = "ratingMatrix";
+
+ private static final int DEFAULT_MIN_PREFS_PER_USER = 1;
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new PreparePreferenceMatrixJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("minPrefsPerUser", "mp", "ignore users with less preferences than this "
+ + "(default: " + DEFAULT_MIN_PREFS_PER_USER + ')', String.valueOf(DEFAULT_MIN_PREFS_PER_USER));
+ addOption("booleanData", "b", "Treat input as without pref values", Boolean.FALSE.toString());
+ addOption("ratingShift", "rs", "shift ratings by this value", "0.0");
+
+ Map<String, List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ int minPrefsPerUser = Integer.parseInt(getOption("minPrefsPerUser"));
+ boolean booleanData = Boolean.valueOf(getOption("booleanData"));
+ float ratingShift = Float.parseFloat(getOption("ratingShift"));
+ //convert items to an internal index
+ Job itemIDIndex = prepareJob(getInputPath(), getOutputPath(ITEMID_INDEX), TextInputFormat.class,
+ ItemIDIndexMapper.class, VarIntWritable.class, VarLongWritable.class, ItemIDIndexReducer.class,
+ VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class);
+ itemIDIndex.setCombinerClass(ItemIDIndexReducer.class);
+ boolean succeeded = itemIDIndex.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ //convert user preferences into a vector per user
+ Job toUserVectors = prepareJob(getInputPath(),
+ getOutputPath(USER_VECTORS),
+ TextInputFormat.class,
+ ToItemPrefsMapper.class,
+ VarLongWritable.class,
+ booleanData ? VarLongWritable.class : EntityPrefWritable.class,
+ ToUserVectorsReducer.class,
+ VarLongWritable.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class);
+ toUserVectors.getConfiguration().setBoolean(RecommenderJob.BOOLEAN_DATA, booleanData);
+ toUserVectors.getConfiguration().setInt(ToUserVectorsReducer.MIN_PREFERENCES_PER_USER, minPrefsPerUser);
+ toUserVectors.getConfiguration().set(ToEntityPrefsMapper.RATING_SHIFT, String.valueOf(ratingShift));
+ succeeded = toUserVectors.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ //we need the number of users later
+ int numberOfUsers = (int) toUserVectors.getCounters().findCounter(ToUserVectorsReducer.Counters.USERS).getValue();
+ HadoopUtil.writeInt(numberOfUsers, getOutputPath(NUM_USERS), getConf());
+ //build the rating matrix
+ Job toItemVectors = prepareJob(getOutputPath(USER_VECTORS), getOutputPath(RATING_MATRIX),
+ ToItemVectorsMapper.class, IntWritable.class, VectorWritable.class, ToItemVectorsReducer.class,
+ IntWritable.class, VectorWritable.class);
+ toItemVectors.setCombinerClass(ToItemVectorsReducer.class);
+
+ succeeded = toItemVectors.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ return 0;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsMapper.java
new file mode 100644
index 0000000..5a4144c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsMapper.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.preparation;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class ToItemVectorsMapper
+ extends Mapper<VarLongWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private final IntWritable itemID = new IntWritable();
+ private final VectorWritable itemVectorWritable = new VectorWritable();
+
+ @Override
+ protected void map(VarLongWritable rowIndex, VectorWritable vectorWritable, Context ctx)
+ throws IOException, InterruptedException {
+ Vector userRatings = vectorWritable.get();
+
+ int column = TasteHadoopUtils.idToIndex(rowIndex.get());
+
+ itemVectorWritable.setWritesLaxPrecision(true);
+
+ Vector itemVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+ for (Vector.Element elem : userRatings.nonZeroes()) {
+ itemID.set(elem.index());
+ itemVector.setQuick(column, elem.get());
+ itemVectorWritable.set(itemVector);
+ ctx.write(itemID, itemVectorWritable);
+ // reset vector for reuse
+ itemVector.setQuick(elem.index(), 0.0);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java
new file mode 100644
index 0000000..f74511b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/preparation/ToItemVectorsReducer.java
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.preparation;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class ToItemVectorsReducer extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private final VectorWritable merged = new VectorWritable();
+
+ @Override
+ protected void reduce(IntWritable row, Iterable<VectorWritable> vectors, Context ctx)
+ throws IOException, InterruptedException {
+
+ merged.setWritesLaxPrecision(true);
+ merged.set(VectorWritable.mergeToVector(vectors.iterator()));
+ ctx.write(row, merged);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java
new file mode 100644
index 0000000..c50fa20
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJob.java
@@ -0,0 +1,233 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.similarity.item;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.EntityEntityWritable;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.hadoop.preparation.PreparePreferenceMatrixJob;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItem;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.RowSimilarityJob;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+
+/**
+ * <p>Distributed precomputation of the item-item-similarities for Itembased Collaborative Filtering</p>
+ *
+ * <p>Preferences in the input file should look like {@code userID,itemID[,preferencevalue]}</p>
+ *
+ * <p>
+ * Preference value is optional to accommodate applications that have no notion of a preference value (that is, the user
+ * simply expresses a preference for an item, but no degree of preference).
+ * </p>
+ *
+ * <p>
+ * The preference value is assumed to be parseable as a {@code double}. The user IDs and item IDs are
+ * parsed as {@code long}s.
+ * </p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the preference data</li>
+ * <li>--output (path): output path where similarity data should be written</li>
+ * <li>--similarityClassname (classname): Name of distributed similarity measure class to instantiate or a predefined
+ * similarity from {@link org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure}</li>
+ * <li>--maxSimilaritiesPerItem (integer): Maximum number of similarities considered per item (100)</li>
+ * <li>--maxPrefsPerUser (integer): max number of preferences to consider per user, users with more preferences will
+ * be sampled down (1000)</li>
+ * <li>--minPrefsPerUser (integer): ignore users with less preferences than this (1)</li>
+ * <li>--booleanData (boolean): Treat input data as having no pref values (false)</li>
+ * <li>--threshold (double): discard item pairs with a similarity value below this</li>
+ * </ol>
+ *
+ * <p>General command line options are documented in {@link AbstractJob}.</p>
+ *
+ * <p>Note that because of how Hadoop parses arguments, all "-D" arguments must appear before all other arguments.</p>
+ */
+public final class ItemSimilarityJob extends AbstractJob {
+
+ public static final String ITEM_ID_INDEX_PATH_STR = ItemSimilarityJob.class.getName() + ".itemIDIndexPathStr";
+ public static final String MAX_SIMILARITIES_PER_ITEM = ItemSimilarityJob.class.getName() + ".maxSimilarItemsPerItem";
+
+ private static final int DEFAULT_MAX_SIMILAR_ITEMS_PER_ITEM = 100;
+ private static final int DEFAULT_MAX_PREFS = 500;
+ private static final int DEFAULT_MIN_PREFS_PER_USER = 1;
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new ItemSimilarityJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("similarityClassname", "s", "Name of distributed similarity measures class to instantiate, "
+ + "alternatively use one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')');
+ addOption("maxSimilaritiesPerItem", "m", "try to cap the number of similar items per item to this number "
+ + "(default: " + DEFAULT_MAX_SIMILAR_ITEMS_PER_ITEM + ')',
+ String.valueOf(DEFAULT_MAX_SIMILAR_ITEMS_PER_ITEM));
+ addOption("maxPrefs", "mppu", "max number of preferences to consider per user or item, "
+ + "users or items with more preferences will be sampled down (default: " + DEFAULT_MAX_PREFS + ')',
+ String.valueOf(DEFAULT_MAX_PREFS));
+ addOption("minPrefsPerUser", "mp", "ignore users with less preferences than this "
+ + "(default: " + DEFAULT_MIN_PREFS_PER_USER + ')', String.valueOf(DEFAULT_MIN_PREFS_PER_USER));
+ addOption("booleanData", "b", "Treat input as without pref values", String.valueOf(Boolean.FALSE));
+ addOption("threshold", "tr", "discard item pairs with a similarity value below this", false);
+ addOption("randomSeed", null, "use this seed for sampling", false);
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ String similarityClassName = getOption("similarityClassname");
+ int maxSimilarItemsPerItem = Integer.parseInt(getOption("maxSimilaritiesPerItem"));
+ int maxPrefs = Integer.parseInt(getOption("maxPrefs"));
+ int minPrefsPerUser = Integer.parseInt(getOption("minPrefsPerUser"));
+ boolean booleanData = Boolean.valueOf(getOption("booleanData"));
+
+ double threshold = hasOption("threshold")
+ ? Double.parseDouble(getOption("threshold")) : RowSimilarityJob.NO_THRESHOLD;
+ long randomSeed = hasOption("randomSeed")
+ ? Long.parseLong(getOption("randomSeed")) : RowSimilarityJob.NO_FIXED_RANDOM_SEED;
+
+ Path similarityMatrixPath = getTempPath("similarityMatrix");
+ Path prepPath = getTempPath("prepareRatingMatrix");
+
+ AtomicInteger currentPhase = new AtomicInteger();
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ ToolRunner.run(getConf(), new PreparePreferenceMatrixJob(), new String[] {
+ "--input", getInputPath().toString(),
+ "--output", prepPath.toString(),
+ "--minPrefsPerUser", String.valueOf(minPrefsPerUser),
+ "--booleanData", String.valueOf(booleanData),
+ "--tempDir", getTempPath().toString(),
+ });
+ }
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ int numberOfUsers = HadoopUtil.readInt(new Path(prepPath, PreparePreferenceMatrixJob.NUM_USERS), getConf());
+
+ ToolRunner.run(getConf(), new RowSimilarityJob(), new String[] {
+ "--input", new Path(prepPath, PreparePreferenceMatrixJob.RATING_MATRIX).toString(),
+ "--output", similarityMatrixPath.toString(),
+ "--numberOfColumns", String.valueOf(numberOfUsers),
+ "--similarityClassname", similarityClassName,
+ "--maxObservationsPerRow", String.valueOf(maxPrefs),
+ "--maxObservationsPerColumn", String.valueOf(maxPrefs),
+ "--maxSimilaritiesPerRow", String.valueOf(maxSimilarItemsPerItem),
+ "--excludeSelfSimilarity", String.valueOf(Boolean.TRUE),
+ "--threshold", String.valueOf(threshold),
+ "--randomSeed", String.valueOf(randomSeed),
+ "--tempDir", getTempPath().toString(),
+ });
+ }
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ Job mostSimilarItems = prepareJob(similarityMatrixPath, getOutputPath(), SequenceFileInputFormat.class,
+ MostSimilarItemPairsMapper.class, EntityEntityWritable.class, DoubleWritable.class,
+ MostSimilarItemPairsReducer.class, EntityEntityWritable.class, DoubleWritable.class, TextOutputFormat.class);
+ Configuration mostSimilarItemsConf = mostSimilarItems.getConfiguration();
+ mostSimilarItemsConf.set(ITEM_ID_INDEX_PATH_STR,
+ new Path(prepPath, PreparePreferenceMatrixJob.ITEMID_INDEX).toString());
+ mostSimilarItemsConf.setInt(MAX_SIMILARITIES_PER_ITEM, maxSimilarItemsPerItem);
+ boolean succeeded = mostSimilarItems.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ return 0;
+ }
+
+ public static class MostSimilarItemPairsMapper
+ extends Mapper<IntWritable,VectorWritable,EntityEntityWritable,DoubleWritable> {
+
+ private OpenIntLongHashMap indexItemIDMap;
+ private int maxSimilarItemsPerItem;
+
+ @Override
+ protected void setup(Context ctx) {
+ Configuration conf = ctx.getConfiguration();
+ maxSimilarItemsPerItem = conf.getInt(MAX_SIMILARITIES_PER_ITEM, -1);
+ indexItemIDMap = TasteHadoopUtils.readIDIndexMap(conf.get(ITEM_ID_INDEX_PATH_STR), conf);
+
+ Preconditions.checkArgument(maxSimilarItemsPerItem > 0, "maxSimilarItemsPerItem must be greater then 0!");
+ }
+
+ @Override
+ protected void map(IntWritable itemIDIndexWritable, VectorWritable similarityVector, Context ctx)
+ throws IOException, InterruptedException {
+
+ int itemIDIndex = itemIDIndexWritable.get();
+
+ TopSimilarItemsQueue topKMostSimilarItems = new TopSimilarItemsQueue(maxSimilarItemsPerItem);
+
+ for (Vector.Element element : similarityVector.get().nonZeroes()) {
+ SimilarItem top = topKMostSimilarItems.top();
+ double candidateSimilarity = element.get();
+ if (candidateSimilarity > top.getSimilarity()) {
+ top.set(indexItemIDMap.get(element.index()), candidateSimilarity);
+ topKMostSimilarItems.updateTop();
+ }
+ }
+
+ long itemID = indexItemIDMap.get(itemIDIndex);
+ for (SimilarItem similarItem : topKMostSimilarItems.getTopItems()) {
+ long otherItemID = similarItem.getItemID();
+ if (itemID < otherItemID) {
+ ctx.write(new EntityEntityWritable(itemID, otherItemID), new DoubleWritable(similarItem.getSimilarity()));
+ } else {
+ ctx.write(new EntityEntityWritable(otherItemID, itemID), new DoubleWritable(similarItem.getSimilarity()));
+ }
+ }
+ }
+ }
+
+ public static class MostSimilarItemPairsReducer
+ extends Reducer<EntityEntityWritable,DoubleWritable,EntityEntityWritable,DoubleWritable> {
+ @Override
+ protected void reduce(EntityEntityWritable pair, Iterable<DoubleWritable> values, Context ctx)
+ throws IOException, InterruptedException {
+ ctx.write(pair, values.iterator().next());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java
new file mode 100644
index 0000000..b0ba24d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/similarity/item/TopSimilarItemsQueue.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.similarity.item;
+
+import com.google.common.collect.Lists;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItem;
+
+import java.util.Collections;
+import java.util.List;
+
+public class TopSimilarItemsQueue extends PriorityQueue<SimilarItem> {
+
+ private static final long SENTINEL_ID = Long.MIN_VALUE;
+
+ private final int maxSize;
+
+ public TopSimilarItemsQueue(int maxSize) {
+ super(maxSize);
+ this.maxSize = maxSize;
+ }
+
+ public List<SimilarItem> getTopItems() {
+ List<SimilarItem> items = Lists.newArrayListWithCapacity(maxSize);
+ while (size() > 0) {
+ SimilarItem topItem = pop();
+ // filter out "sentinel" objects necessary for maintaining an efficient priority queue
+ if (topItem.getItemID() != SENTINEL_ID) {
+ items.add(topItem);
+ }
+ }
+ Collections.reverse(items);
+ return items;
+ }
+
+ @Override
+ protected boolean lessThan(SimilarItem one, SimilarItem two) {
+ return one.getSimilarity() < two.getSimilarity();
+ }
+
+ @Override
+ protected SimilarItem getSentinelObject() {
+ return new SimilarItem(SENTINEL_ID, Double.MIN_VALUE);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java
new file mode 100644
index 0000000..f46785c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/AbstractLongPrimitiveIterator.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+public abstract class AbstractLongPrimitiveIterator implements LongPrimitiveIterator {
+
+ @Override
+ public Long next() {
+ return nextLong();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java
new file mode 100644
index 0000000..c46b4b6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/BitSet.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+import java.util.Arrays;
+
+/** A simplified and streamlined version of {@link java.util.BitSet}. */
+final class BitSet implements Serializable, Cloneable {
+
+ private final long[] bits;
+
+ BitSet(int numBits) {
+ int numLongs = numBits >>> 6;
+ if ((numBits & 0x3F) != 0) {
+ numLongs++;
+ }
+ bits = new long[numLongs];
+ }
+
+ private BitSet(long[] bits) {
+ this.bits = bits;
+ }
+
+ boolean get(int index) {
+ // skipping range check for speed
+ return (bits[index >>> 6] & 1L << (index & 0x3F)) != 0L;
+ }
+
+ void set(int index) {
+ // skipping range check for speed
+ bits[index >>> 6] |= 1L << (index & 0x3F);
+ }
+
+ void clear(int index) {
+ // skipping range check for speed
+ bits[index >>> 6] &= ~(1L << (index & 0x3F));
+ }
+
+ void clear() {
+ int length = bits.length;
+ for (int i = 0; i < length; i++) {
+ bits[i] = 0L;
+ }
+ }
+
+ @Override
+ public BitSet clone() {
+ return new BitSet(bits.clone());
+ }
+
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(bits);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof BitSet)) {
+ return false;
+ }
+ BitSet other = (BitSet) o;
+ return Arrays.equals(bits, other.bits);
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder result = new StringBuilder(64 * bits.length);
+ for (long l : bits) {
+ for (int j = 0; j < 64; j++) {
+ result.append((l & 1L << j) == 0 ? '0' : '1');
+ }
+ result.append(' ');
+ }
+ return result.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java
new file mode 100755
index 0000000..9f2a30b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Cache.java
@@ -0,0 +1,178 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+import java.util.Iterator;
+
+/**
+ * <p>
+ * An efficient Map-like class which caches values for keys. Values are not "put" into a {@link Cache};
+ * instead the caller supplies the instance with an implementation of {@link Retriever} which can load the
+ * value for a given key.
+ * </p>
+ *
+ * <p>
+ * The cache does not support {@code null} keys.
+ * </p>
+ *
+ * <p>
+ * Thanks to Amila Jayasooriya for helping evaluate performance of the rewrite of this class, as part of a
+ * Google Summer of Code 2007 project.
+ * </p>
+ */
+public final class Cache<K,V> implements Retriever<K,V> {
+
+ private static final Object NULL = new Object();
+
+ private final FastMap<K,V> cache;
+ private final Retriever<? super K,? extends V> retriever;
+
+ /**
+ * <p>
+ * Creates a new cache based on the given {@link Retriever}.
+ * </p>
+ *
+ * @param retriever
+ * object which can retrieve values for keys
+ */
+ public Cache(Retriever<? super K,? extends V> retriever) {
+ this(retriever, FastMap.NO_MAX_SIZE);
+ }
+
+ /**
+ * <p>
+ * Creates a new cache based on the given {@link Retriever} and with given maximum size.
+ * </p>
+ *
+ * @param retriever
+ * object which can retrieve values for keys
+ * @param maxEntries
+ * maximum number of entries the cache will store before evicting some
+ */
+ public Cache(Retriever<? super K,? extends V> retriever, int maxEntries) {
+ Preconditions.checkArgument(retriever != null, "retriever is null");
+ Preconditions.checkArgument(maxEntries >= 1, "maxEntries must be at least 1");
+ cache = new FastMap<K, V>(11, maxEntries);
+ this.retriever = retriever;
+ }
+
+ /**
+ * <p>
+ * Returns cached value for a key. If it does not exist, it is loaded using a {@link Retriever}.
+ * </p>
+ *
+ * @param key
+ * cache key
+ * @return value for that key
+ * @throws TasteException
+ * if an exception occurs while retrieving a new cached value
+ */
+ @Override
+ public V get(K key) throws TasteException {
+ V value;
+ synchronized (cache) {
+ value = cache.get(key);
+ }
+ if (value == null) {
+ return getAndCacheValue(key);
+ }
+ return value == NULL ? null : value;
+ }
+
+ /**
+ * <p>
+ * Uncaches any existing value for a given key.
+ * </p>
+ *
+ * @param key
+ * cache key
+ */
+ public void remove(K key) {
+ synchronized (cache) {
+ cache.remove(key);
+ }
+ }
+
+ /**
+ * Clears all cache entries whose key matches the given predicate.
+ */
+ public void removeKeysMatching(MatchPredicate<K> predicate) {
+ synchronized (cache) {
+ Iterator<K> it = cache.keySet().iterator();
+ while (it.hasNext()) {
+ K key = it.next();
+ if (predicate.matches(key)) {
+ it.remove();
+ }
+ }
+ }
+ }
+
+ /**
+ * Clears all cache entries whose value matches the given predicate.
+ */
+ public void removeValueMatching(MatchPredicate<V> predicate) {
+ synchronized (cache) {
+ Iterator<V> it = cache.values().iterator();
+ while (it.hasNext()) {
+ V value = it.next();
+ if (predicate.matches(value)) {
+ it.remove();
+ }
+ }
+ }
+ }
+
+ /**
+ * <p>
+ * Clears the cache.
+ * </p>
+ */
+ public void clear() {
+ synchronized (cache) {
+ cache.clear();
+ }
+ }
+
+ private V getAndCacheValue(K key) throws TasteException {
+ V value = retriever.get(key);
+ if (value == null) {
+ value = (V) NULL;
+ }
+ synchronized (cache) {
+ cache.put(key, value);
+ }
+ return value;
+ }
+
+ @Override
+ public String toString() {
+ return "Cache[retriever:" + retriever + ']';
+ }
+
+ /**
+ * Used by {#link #removeKeysMatching(Object)} to decide things that are matching.
+ */
+ public interface MatchPredicate<T> {
+ boolean matches(T thing);
+ }
+
+}
[31/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
new file mode 100644
index 0000000..ebb0614
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelDissector.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.Queue;
+import java.util.Set;
+
+/**
+ * Uses sample data to reverse engineer a feature-hashed model.
+ *
+ * The result gives approximate weights for features and interactions
+ * in the original space.
+ *
+ * The idea is that the hashed encoders have the option of having a trace dictionary. This
+ * tells us where each feature is hashed to, or each feature/value combination in the case
+ * of word-like values. Using this dictionary, we can put values into a synthetic feature
+ * vector in just the locations specified by a single feature or interaction. Then we can
+ * push this through a linear part of a model to see the contribution of that input. For
+ * any generalized linear model like logistic regression, there is a linear part of the
+ * model that allows this.
+ *
+ * What the ModelDissector does is to accept a trace dictionary and a model in an update
+ * method. It figures out the weights for the elements in the trace dictionary and stashes
+ * them. Then in a summary method, the biggest weights are returned. This update/flush
+ * style is used so that the trace dictionary doesn't have to grow to enormous levels,
+ * but instead can be cleared between updates.
+ */
+public class ModelDissector {
+ private final Map<String,Vector> weightMap;
+
+ public ModelDissector() {
+ weightMap = Maps.newHashMap();
+ }
+
+ /**
+ * Probes a model to determine the effect of a particular variable. This is done
+ * with the ade of a trace dictionary which has recorded the locations in the feature
+ * vector that are modified by various variable values. We can set these locations to
+ * 1 and then look at the resulting score. This tells us the weight the model places
+ * on that variable.
+ * @param features A feature vector to use (destructively)
+ * @param traceDictionary A trace dictionary containing variables and what locations
+ * in the feature vector are affected by them
+ * @param learner The model that we are probing to find weights on features
+ */
+
+ public void update(Vector features, Map<String, Set<Integer>> traceDictionary, AbstractVectorClassifier learner) {
+ // zero out feature vector
+ features.assign(0);
+ for (Map.Entry<String, Set<Integer>> entry : traceDictionary.entrySet()) {
+ // get a feature and locations where it is stored in the feature vector
+ String key = entry.getKey();
+ Set<Integer> value = entry.getValue();
+
+ // if we haven't looked at this feature yet
+ if (!weightMap.containsKey(key)) {
+ // put probe values in the feature vector
+ for (Integer where : value) {
+ features.set(where, 1);
+ }
+
+ // see what the model says
+ Vector v = learner.classifyNoLink(features);
+ weightMap.put(key, v);
+
+ // and zero out those locations again
+ for (Integer where : value) {
+ features.set(where, 0);
+ }
+ }
+ }
+ }
+
+ /**
+ * Returns the n most important features with their
+ * weights, most important category and the top few
+ * categories that they affect.
+ * @param n How many results to return.
+ * @return A list of the top variables.
+ */
+ public List<Weight> summary(int n) {
+ Queue<Weight> pq = new PriorityQueue<Weight>();
+ for (Map.Entry<String, Vector> entry : weightMap.entrySet()) {
+ pq.add(new Weight(entry.getKey(), entry.getValue()));
+ while (pq.size() > n) {
+ pq.poll();
+ }
+ }
+ List<Weight> r = Lists.newArrayList(pq);
+ Collections.sort(r, Ordering.natural().reverse());
+ return r;
+ }
+
+ private static final class Category implements Comparable<Category> {
+ private final int index;
+ private final double weight;
+
+ private Category(int index, double weight) {
+ this.index = index;
+ this.weight = weight;
+ }
+
+ @Override
+ public int compareTo(Category o) {
+ int r = Double.compare(Math.abs(weight), Math.abs(o.weight));
+ if (r == 0) {
+ if (o.index < index) {
+ return -1;
+ }
+ if (o.index > index) {
+ return 1;
+ }
+ return 0;
+ }
+ return r;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof Category)) {
+ return false;
+ }
+ Category other = (Category) o;
+ return index == other.index && weight == other.weight;
+ }
+
+ @Override
+ public int hashCode() {
+ return RandomUtils.hashDouble(weight) ^ index;
+ }
+
+ }
+
+ public static class Weight implements Comparable<Weight> {
+ private final String feature;
+ private final double value;
+ private final int maxIndex;
+ private final List<Category> categories;
+
+ public Weight(String feature, Vector weights) {
+ this(feature, weights, 3);
+ }
+
+ public Weight(String feature, Vector weights, int n) {
+ this.feature = feature;
+ // pick out the weight with the largest abs value, but don't forget the sign
+ Queue<Category> biggest = new PriorityQueue<Category>(n + 1, Ordering.natural());
+ for (Vector.Element element : weights.all()) {
+ biggest.add(new Category(element.index(), element.get()));
+ while (biggest.size() > n) {
+ biggest.poll();
+ }
+ }
+ categories = Lists.newArrayList(biggest);
+ Collections.sort(categories, Ordering.natural().reverse());
+ value = categories.get(0).weight;
+ maxIndex = categories.get(0).index;
+ }
+
+ @Override
+ public int compareTo(Weight other) {
+ int r = Double.compare(Math.abs(this.value), Math.abs(other.value));
+ if (r == 0) {
+ return feature.compareTo(other.feature);
+ }
+ return r;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof Weight)) {
+ return false;
+ }
+ Weight other = (Weight) o;
+ return feature.equals(other.feature)
+ && value == other.value
+ && maxIndex == other.maxIndex
+ && categories.equals(other.categories);
+ }
+
+ @Override
+ public int hashCode() {
+ return feature.hashCode() ^ RandomUtils.hashDouble(value) ^ maxIndex ^ categories.hashCode();
+ }
+
+ public String getFeature() {
+ return feature;
+ }
+
+ public double getWeight() {
+ return value;
+ }
+
+ public double getWeight(int n) {
+ return categories.get(n).weight;
+ }
+
+ public double getCategory(int n) {
+ return categories.get(n).index;
+ }
+
+ public int getMaxImpact() {
+ return maxIndex;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
new file mode 100644
index 0000000..f0150e9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ModelSerializer.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Provides the ability to store SGD model-related objects as binary files.
+ */
+public final class ModelSerializer {
+
+ // static class ... don't instantiate
+ private ModelSerializer() {
+ }
+
+ public static void writeBinary(String path, CrossFoldLearner model) throws IOException {
+ DataOutputStream out = new DataOutputStream(new FileOutputStream(path));
+ try {
+ PolymorphicWritable.write(out, model);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ public static void writeBinary(String path, OnlineLogisticRegression model) throws IOException {
+ DataOutputStream out = new DataOutputStream(new FileOutputStream(path));
+ try {
+ PolymorphicWritable.write(out, model);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ public static void writeBinary(String path, AdaptiveLogisticRegression model) throws IOException {
+ DataOutputStream out = new DataOutputStream(new FileOutputStream(path));
+ try {
+ PolymorphicWritable.write(out, model);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ public static <T extends Writable> T readBinary(InputStream in, Class<T> clazz) throws IOException {
+ DataInput dataIn = new DataInputStream(in);
+ try {
+ return PolymorphicWritable.read(dataIn, clazz);
+ } finally {
+ Closeables.close(in, false);
+ }
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
new file mode 100644
index 0000000..7a9ca83
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegression.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Extends the basic on-line logistic regression learner with a specific set of learning
+ * rate annealing schedules.
+ */
+public class OnlineLogisticRegression extends AbstractOnlineLogisticRegression implements Writable {
+ public static final int WRITABLE_VERSION = 1;
+
+ // these next two control decayFactor^steps exponential type of annealing
+ // learning rate and decay factor
+ private double mu0 = 1;
+ private double decayFactor = 1 - 1.0e-3;
+
+ // these next two control 1/steps^forget type annealing
+ private int stepOffset = 10;
+ // -1 equals even weighting of all examples, 0 means only use exponential annealing
+ private double forgettingExponent = -0.5;
+
+ // controls how per term annealing works
+ private int perTermAnnealingOffset = 20;
+
+ public OnlineLogisticRegression() {
+ // private constructor available for serialization, but not normal use
+ }
+
+ public OnlineLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+ this.numCategories = numCategories;
+ this.prior = prior;
+
+ updateSteps = new DenseVector(numFeatures);
+ updateCounts = new DenseVector(numFeatures).assign(perTermAnnealingOffset);
+ beta = new DenseMatrix(numCategories - 1, numFeatures);
+ }
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param alpha New value of decayFactor, the exponential decay rate for the learning rate.
+ * @return This, so other configurations can be chained.
+ */
+ public OnlineLogisticRegression alpha(double alpha) {
+ this.decayFactor = alpha;
+ return this;
+ }
+
+ @Override
+ public OnlineLogisticRegression lambda(double lambda) {
+ // we only over-ride this to provide a more restrictive return type
+ super.lambda(lambda);
+ return this;
+ }
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param learningRate New value of initial learning rate.
+ * @return This, so other configurations can be chained.
+ */
+ public OnlineLogisticRegression learningRate(double learningRate) {
+ this.mu0 = learningRate;
+ return this;
+ }
+
+ public OnlineLogisticRegression stepOffset(int stepOffset) {
+ this.stepOffset = stepOffset;
+ return this;
+ }
+
+ public OnlineLogisticRegression decayExponent(double decayExponent) {
+ if (decayExponent > 0) {
+ decayExponent = -decayExponent;
+ }
+ this.forgettingExponent = decayExponent;
+ return this;
+ }
+
+
+ @Override
+ public double perTermLearningRate(int j) {
+ return Math.sqrt(perTermAnnealingOffset / updateCounts.get(j));
+ }
+
+ @Override
+ public double currentLearningRate() {
+ return mu0 * Math.pow(decayFactor, getStep()) * Math.pow(getStep() + stepOffset, forgettingExponent);
+ }
+
+ public void copyFrom(OnlineLogisticRegression other) {
+ super.copyFrom(other);
+ mu0 = other.mu0;
+ decayFactor = other.decayFactor;
+
+ stepOffset = other.stepOffset;
+ forgettingExponent = other.forgettingExponent;
+
+ perTermAnnealingOffset = other.perTermAnnealingOffset;
+ }
+
+ public OnlineLogisticRegression copy() {
+ close();
+ OnlineLogisticRegression r = new OnlineLogisticRegression(numCategories(), numFeatures(), prior);
+ r.copyFrom(this);
+ return r;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(WRITABLE_VERSION);
+ out.writeDouble(mu0);
+ out.writeDouble(getLambda());
+ out.writeDouble(decayFactor);
+ out.writeInt(stepOffset);
+ out.writeInt(step);
+ out.writeDouble(forgettingExponent);
+ out.writeInt(perTermAnnealingOffset);
+ out.writeInt(numCategories);
+ MatrixWritable.writeMatrix(out, beta);
+ PolymorphicWritable.write(out, prior);
+ VectorWritable.writeVector(out, updateCounts);
+ VectorWritable.writeVector(out, updateSteps);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int version = in.readInt();
+ if (version == WRITABLE_VERSION) {
+ mu0 = in.readDouble();
+ lambda(in.readDouble());
+ decayFactor = in.readDouble();
+ stepOffset = in.readInt();
+ step = in.readInt();
+ forgettingExponent = in.readDouble();
+ perTermAnnealingOffset = in.readInt();
+ numCategories = in.readInt();
+ beta = MatrixWritable.readMatrix(in);
+ prior = PolymorphicWritable.read(in, PriorFunction.class);
+
+ updateCounts = VectorWritable.readVector(in);
+ updateSteps = VectorWritable.readVector(in);
+ } else {
+ throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
new file mode 100644
index 0000000..c51361c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Online passive aggressive learner that tries to minimize the label ranking hinge loss.
+ * Implements a multi-class linear classifier minimizing rank loss.
+ * based on "Online passive aggressive algorithms" by Cramer et al, 2006.
+ * Note: Its better to use classifyNoLink because the loss function is based
+ * on ensuring that the score of the good label is larger than the next
+ * highest label by some margin. The conversion to probability is just done
+ * by exponentiating and dividing by the sum and is empirical at best.
+ * Your features should be pre-normalized in some sensible range, for example,
+ * by subtracting the mean and standard deviation, if they are very
+ * different in magnitude from each other.
+ */
+public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+ private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);
+
+ public static final int WRITABLE_VERSION = 1;
+
+ // the learning rate of the algorithm
+ private double learningRate = 0.1;
+
+ // loss statistics.
+ private int lossCount = 0;
+ private double lossSum = 0;
+
+ // coefficients for the classification. This is a dense matrix
+ // that is (numCategories ) x numFeatures
+ private Matrix weights;
+
+ // number of categories we are classifying.
+ private int numCategories;
+
+ public PassiveAggressive(int numCategories, int numFeatures) {
+ this.numCategories = numCategories;
+ weights = new DenseMatrix(numCategories, numFeatures);
+ weights.assign(0.0);
+ }
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param learningRate New value of initial learning rate.
+ * @return This, so other configurations can be chained.
+ */
+ public PassiveAggressive learningRate(double learningRate) {
+ this.learningRate = learningRate;
+ return this;
+ }
+
+ public void copyFrom(PassiveAggressive other) {
+ learningRate = other.learningRate;
+ numCategories = other.numCategories;
+ weights = other.weights;
+ }
+
+ @Override
+ public int numCategories() {
+ return numCategories;
+ }
+
+ @Override
+ public Vector classify(Vector instance) {
+ Vector result = classifyNoLink(instance);
+ // Convert to probabilities by exponentiation.
+ double max = result.maxValue();
+ result.assign(Functions.minus(max)).assign(Functions.EXP);
+ result = result.divide(result.norm(1));
+
+ return result.viewPart(1, result.size() - 1);
+ }
+
+ @Override
+ public Vector classifyNoLink(Vector instance) {
+ Vector result = new DenseVector(weights.numRows());
+ result.assign(0);
+ for (int i = 0; i < weights.numRows(); i++) {
+ result.setQuick(i, weights.viewRow(i).dot(instance));
+ }
+ return result;
+ }
+
+ @Override
+ public double classifyScalar(Vector instance) {
+ double v1 = weights.viewRow(0).dot(instance);
+ double v2 = weights.viewRow(1).dot(instance);
+ v1 = Math.exp(v1);
+ v2 = Math.exp(v2);
+ return v2 / (v1 + v2);
+ }
+
+ public int numFeatures() {
+ return weights.numCols();
+ }
+
+ public PassiveAggressive copy() {
+ close();
+ PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures());
+ r.copyFrom(this);
+ return r;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(WRITABLE_VERSION);
+ out.writeDouble(learningRate);
+ out.writeInt(numCategories);
+ MatrixWritable.writeMatrix(out, weights);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int version = in.readInt();
+ if (version == WRITABLE_VERSION) {
+ learningRate = in.readDouble();
+ numCategories = in.readInt();
+ weights = MatrixWritable.readMatrix(in);
+ } else {
+ throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+ }
+ }
+
+ @Override
+ public void close() {
+ // This is an online classifier, nothing to do.
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ if (lossCount > 1000) {
+ log.info("Avg. Loss = {}", lossSum / lossCount);
+ lossCount = 0;
+ lossSum = 0;
+ }
+ Vector result = classifyNoLink(instance);
+ double myScore = result.get(actual);
+ // Find the highest score that is not actual.
+ int otherIndex = result.maxValueIndex();
+ double otherValue = result.get(otherIndex);
+ if (otherIndex == actual) {
+ result.setQuick(otherIndex, Double.NEGATIVE_INFINITY);
+ otherIndex = result.maxValueIndex();
+ otherValue = result.get(otherIndex);
+ }
+ double loss = 1.0 - myScore + otherValue;
+ lossCount += 1;
+ if (loss >= 0) {
+ lossSum += loss;
+ double tau = loss / (instance.dot(instance) + 0.5 / learningRate);
+ Vector delta = instance.clone();
+ delta.assign(Functions.mult(tau));
+ weights.viewRow(actual).assign(delta, Functions.PLUS);
+// delta.addTo(weights.viewRow(actual));
+ delta.assign(Functions.mult(-1));
+ weights.viewRow(otherIndex).assign(delta, Functions.PLUS);
+// delta.addTo(weights.viewRow(otherIndex));
+ }
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ train(trackingKey, null, actual, instance);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ train(0, null, actual, instance);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
new file mode 100644
index 0000000..90062a6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PolymorphicWritable.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.ClassUtils;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Utilities that write a class name and then serialize using writables.
+ */
+public final class PolymorphicWritable {
+
+ private PolymorphicWritable() {
+ }
+
+ public static <T extends Writable> void write(DataOutput dataOutput, T value) throws IOException {
+ dataOutput.writeUTF(value.getClass().getName());
+ value.write(dataOutput);
+ }
+
+ public static <T extends Writable> T read(DataInput dataInput, Class<? extends T> clazz) throws IOException {
+ String className = dataInput.readUTF();
+ T r = ClassUtils.instantiateAs(className, clazz);
+ r.readFields(dataInput);
+ return r;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
new file mode 100644
index 0000000..857f061
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/PriorFunction.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * A prior is used to regularize the learning algorithm. This allows a trade-off to
+ * be made between complexity of the model being learned and the accuracy with which
+ * the model fits the training data. There are different definitions of complexity
+ * which can be approximated using different priors. For large sparse systems, such
+ * as text classification, the L1 prior is often used which favors sparse models.
+ */
+public interface PriorFunction extends Writable {
+ /**
+ * Applies the regularization to a coefficient.
+ * @param oldValue The previous value.
+ * @param generations The number of generations.
+ * @param learningRate The learning rate with lambda baked in.
+ * @return The new coefficient value after regularization.
+ */
+ double age(double oldValue, double generations, double learningRate);
+
+ /**
+ * Returns the log of the probability of a particular coefficient value according to the prior.
+ * @param betaIJ The coefficient.
+ * @return The log probability.
+ */
+ double logP(double betaIJ);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
new file mode 100644
index 0000000..b52cb8c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/RankingGradient.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * Uses the difference between this instance and recent history to get a
+ * gradient that optimizes ranking performance. Essentially this is the
+ * same as directly optimizing AUC. It isn't expected that this would
+ * be used alone, but rather that a MixedGradient would use it and a
+ * DefaultGradient together to combine both ranking and log-likelihood
+ * goals.
+ */
+public class RankingGradient implements Gradient {
+
+ private static final Gradient BASIC = new DefaultGradient();
+
+ private int window = 10;
+
+ private final List<Deque<Vector>> history = Lists.newArrayList();
+
+ public RankingGradient(int window) {
+ this.window = window;
+ }
+
+ @Override
+ public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+ addToHistory(actual, instance);
+
+ // now compute average gradient versus saved vectors from the other side
+ Deque<Vector> otherSide = history.get(1 - actual);
+ int n = otherSide.size();
+
+ Vector r = null;
+ for (Vector other : otherSide) {
+ Vector g = BASIC.apply(groupKey, actual, instance.minus(other), classifier);
+
+ if (r == null) {
+ r = g;
+ } else {
+ r.assign(g, Functions.plusMult(1.0 / n));
+ }
+ }
+ return r;
+ }
+
+ public void addToHistory(int actual, Vector instance) {
+ while (history.size() <= actual) {
+ history.add(new ArrayDeque<Vector>(window));
+ }
+ // save this instance
+ Deque<Vector> ourSide = history.get(actual);
+ ourSide.add(instance);
+ while (ourSide.size() >= window) {
+ ourSide.pollFirst();
+ }
+ }
+
+ public Gradient getBaseGradient() {
+ return BASIC;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
new file mode 100644
index 0000000..fbc825d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/RecordFactory.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * A record factor understands how to convert a line of data into fields and then into a vector.
+ */
+public interface RecordFactory {
+ void defineTargetCategories(List<String> values);
+
+ RecordFactory maxTargetValue(int max);
+
+ boolean usesFirstLineAsSchema();
+
+ int processLine(String line, Vector featureVector);
+
+ Iterable<String> getPredictors();
+
+ Map<String, Set<Integer>> getTraceDictionary();
+
+ RecordFactory includeBiasTerm(boolean useBias);
+
+ List<String> getTargetCategories();
+
+ void firstLine(String line);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
new file mode 100644
index 0000000..0a7b6a7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/TPrior.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.commons.math3.special.Gamma;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Provides a t-distribution as a prior.
+ */
+public class TPrior implements PriorFunction {
+ private double df;
+
+ public TPrior(double df) {
+ this.df = df;
+ }
+
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ for (int i = 0; i < generations; i++) {
+ oldValue -= learningRate * oldValue * (df + 1.0) / (df + oldValue * oldValue);
+ }
+ return oldValue;
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return Gamma.logGamma((df + 1.0) / 2.0)
+ - Math.log(df * Math.PI)
+ - Gamma.logGamma(df / 2.0)
+ - (df + 1.0) / 2.0 * Math.log1p(betaIJ * betaIJ);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(df);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ df = in.readDouble();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
new file mode 100644
index 0000000..23c812f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/UniformPrior.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * A uniform prior. This is an improper prior that corresponds to no regularization at all.
+ */
+public class UniformPrior implements PriorFunction {
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ return oldValue;
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return 0;
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) throws IOException {
+ // nothing to write
+ }
+
+ @Override
+ public void readFields(DataInput dataInput) throws IOException {
+ // stateless class is trivial to read
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
new file mode 100644
index 0000000..c2ad966
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/package-info.java
@@ -0,0 +1,23 @@
+/**
+ * <p>Implements a variety of on-line logistric regression classifiers using SGD-based algorithms.
+ * SGD stands for Stochastic Gradient Descent and refers to a class of learning algorithms
+ * that make it relatively easy to build high speed on-line learning algorithms for a variety
+ * of problems, notably including supervised learning for classification.</p>
+ *
+ * <p>The primary class of interest in the this package is
+ * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which contains a
+ * number (typically 5) of sub-learners, each of which is given a different portion of the
+ * training data. Each of these sub-learners can then be evaluated on the data it was not
+ * trained on. This allows fully incremental learning while still getting cross-validated
+ * performance estimates.</p>
+ *
+ * <p>The CrossFoldLearner implements {@link org.apache.mahout.classifier.OnlineLearner}
+ * and thus expects to be fed input in the form
+ * of a target variable and a feature vector. The target variable is simply an integer in the
+ * half-open interval [0..numFeatures) where numFeatures is defined when the CrossFoldLearner
+ * is constructed. The creation of feature vectors is facilitated by the classes that inherit
+ * from {@link org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder}.
+ * These classes currently implement a form of feature hashing with
+ * multiple probes to limit feature ambiguity.</p>
+ */
+package org.apache.mahout.classifier.sgd;
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java b/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
new file mode 100644
index 0000000..cc05beb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
@@ -0,0 +1,391 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+import org.codehaus.jackson.map.ObjectMapper;
+
+public abstract class AbstractCluster implements Cluster {
+
+ // cluster persistent state
+ private int id;
+
+ private long numObservations;
+
+ private long totalObservations;
+
+ private Vector center;
+
+ private Vector radius;
+
+ // the observation statistics
+ private double s0;
+
+ private Vector s1;
+
+ private Vector s2;
+
+ private static final ObjectMapper jxn = new ObjectMapper();
+
+ protected AbstractCluster() {}
+
+ protected AbstractCluster(Vector point, int id2) {
+ this.numObservations = (long) 0;
+ this.totalObservations = (long) 0;
+ this.center = point.clone();
+ this.radius = center.like();
+ this.s0 = (double) 0;
+ this.s1 = center.like();
+ this.s2 = center.like();
+ this.id = id2;
+ }
+
+ protected AbstractCluster(Vector center2, Vector radius2, int id2) {
+ this.numObservations = (long) 0;
+ this.totalObservations = (long) 0;
+ this.center = new RandomAccessSparseVector(center2);
+ this.radius = new RandomAccessSparseVector(radius2);
+ this.s0 = (double) 0;
+ this.s1 = center.like();
+ this.s2 = center.like();
+ this.id = id2;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(id);
+ out.writeLong(getNumObservations());
+ out.writeLong(getTotalObservations());
+ VectorWritable.writeVector(out, getCenter());
+ VectorWritable.writeVector(out, getRadius());
+ out.writeDouble(s0);
+ VectorWritable.writeVector(out, s1);
+ VectorWritable.writeVector(out, s2);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.id = in.readInt();
+ this.setNumObservations(in.readLong());
+ this.setTotalObservations(in.readLong());
+ this.setCenter(VectorWritable.readVector(in));
+ this.setRadius(VectorWritable.readVector(in));
+ this.setS0(in.readDouble());
+ this.setS1(VectorWritable.readVector(in));
+ this.setS2(VectorWritable.readVector(in));
+ }
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public int getId() {
+ return id;
+ }
+
+ /**
+ * @param id
+ * the id to set
+ */
+ protected void setId(int id) {
+ this.id = id;
+ }
+
+ @Override
+ public long getNumObservations() {
+ return numObservations;
+ }
+
+ /**
+ * @param l
+ * the numPoints to set
+ */
+ protected void setNumObservations(long l) {
+ this.numObservations = l;
+ }
+
+ @Override
+ public long getTotalObservations() {
+ return totalObservations;
+ }
+
+ protected void setTotalObservations(long totalPoints) {
+ this.totalObservations = totalPoints;
+ }
+
+ @Override
+ public Vector getCenter() {
+ return center;
+ }
+
+ /**
+ * @param center
+ * the center to set
+ */
+ protected void setCenter(Vector center) {
+ this.center = center;
+ }
+
+ @Override
+ public Vector getRadius() {
+ return radius;
+ }
+
+ /**
+ * @param radius
+ * the radius to set
+ */
+ protected void setRadius(Vector radius) {
+ this.radius = radius;
+ }
+
+ /**
+ * @return the s0
+ */
+ protected double getS0() {
+ return s0;
+ }
+
+ protected void setS0(double s0) {
+ this.s0 = s0;
+ }
+
+ /**
+ * @return the s1
+ */
+ protected Vector getS1() {
+ return s1;
+ }
+
+ protected void setS1(Vector s1) {
+ this.s1 = s1;
+ }
+
+ /**
+ * @return the s2
+ */
+ protected Vector getS2() {
+ return s2;
+ }
+
+ protected void setS2(Vector s2) {
+ this.s2 = s2;
+ }
+
+ @Override
+ public void observe(Model<VectorWritable> x) {
+ AbstractCluster cl = (AbstractCluster) x;
+ setS0(getS0() + cl.getS0());
+ setS1(getS1().plus(cl.getS1()));
+ setS2(getS2().plus(cl.getS2()));
+ }
+
+ @Override
+ public void observe(VectorWritable x) {
+ observe(x.get());
+ }
+
+ @Override
+ public void observe(VectorWritable x, double weight) {
+ observe(x.get(), weight);
+ }
+
+ public void observe(Vector x, double weight) {
+ if (weight == 1.0) {
+ observe(x);
+ } else {
+ setS0(getS0() + weight);
+ Vector weightedX = x.times(weight);
+ if (getS1() == null) {
+ setS1(weightedX);
+ } else {
+ getS1().assign(weightedX, Functions.PLUS);
+ }
+ Vector x2 = x.times(x).times(weight);
+ if (getS2() == null) {
+ setS2(x2);
+ } else {
+ getS2().assign(x2, Functions.PLUS);
+ }
+ }
+ }
+
+ public void observe(Vector x) {
+ setS0(getS0() + 1);
+ if (getS1() == null) {
+ setS1(x.clone());
+ } else {
+ getS1().assign(x, Functions.PLUS);
+ }
+ Vector x2 = x.times(x);
+ if (getS2() == null) {
+ setS2(x2);
+ } else {
+ getS2().assign(x2, Functions.PLUS);
+ }
+ }
+
+
+ @Override
+ public void computeParameters() {
+ if (getS0() == 0) {
+ return;
+ }
+ setNumObservations((long) getS0());
+ setTotalObservations(getTotalObservations() + getNumObservations());
+ setCenter(getS1().divide(getS0()));
+ // compute the component stds
+ if (getS0() > 1) {
+ setRadius(getS2().times(getS0()).minus(getS1().times(getS1())).assign(new SquareRootFunction()).divide(getS0()));
+ }
+ setS0(0);
+ setS1(center.like());
+ setS2(center.like());
+ }
+
+ @Override
+ public String asFormatString(String[] bindings) {
+ String fmtString = "";
+ try {
+ fmtString = jxn.writeValueAsString(asJson(bindings));
+ } catch (IOException e) {
+ log.error("Error writing JSON as String.", e);
+ }
+ return fmtString;
+ }
+
+ public Map<String,Object> asJson(String[] bindings) {
+ Map<String,Object> dict = new HashMap<>();
+ dict.put("identifier", getIdentifier());
+ dict.put("n", getNumObservations());
+ if (getCenter() != null) {
+ try {
+ dict.put("c", formatVectorAsJson(getCenter(), bindings));
+ } catch (IOException e) {
+ log.error("IOException: ", e);
+ }
+ }
+ if (getRadius() != null) {
+ try {
+ dict.put("r", formatVectorAsJson(getRadius(), bindings));
+ } catch (IOException e) {
+ log.error("IOException: ", e);
+ }
+ }
+ return dict;
+ }
+
+ public abstract String getIdentifier();
+
+ /**
+ * Compute the centroid by averaging the pointTotals
+ *
+ * @return the new centroid
+ */
+ public Vector computeCentroid() {
+ return getS0() == 0 ? getCenter() : getS1().divide(getS0());
+ }
+
+ /**
+ * Return a human-readable formatted string representation of the vector, not
+ * intended to be complete nor usable as an input/output representation
+ */
+ public static String formatVector(Vector v, String[] bindings) {
+ String fmtString = "";
+ try {
+ fmtString = jxn.writeValueAsString(formatVectorAsJson(v, bindings));
+ } catch (IOException e) {
+ log.error("Error writing JSON as String.", e);
+ }
+ return fmtString;
+ }
+
+ /**
+ * Create a List of HashMaps containing vector terms and weights
+ *
+ * @return List<Object>
+ */
+ public static List<Object> formatVectorAsJson(Vector v, String[] bindings) throws IOException {
+
+ boolean hasBindings = bindings != null;
+ boolean isSparse = !v.isDense() && v.getNumNondefaultElements() != v.size();
+
+ // we assume sequential access in the output
+ Vector provider = v.isSequentialAccess() ? v : new SequentialAccessSparseVector(v);
+
+ List<Object> terms = Lists.newLinkedList();
+ String term = "";
+
+ for (Element elem : provider.nonZeroes()) {
+
+ if (hasBindings && bindings.length >= elem.index() + 1 && bindings[elem.index()] != null) {
+ term = bindings[elem.index()];
+ } else if (hasBindings || isSparse) {
+ term = String.valueOf(elem.index());
+ }
+
+ Map<String, Object> term_entry = Maps.newHashMap();
+ double roundedWeight = (double) Math.round(elem.get() * 1000) / 1000;
+ if (hasBindings || isSparse) {
+ term_entry.put(term, roundedWeight);
+ terms.add(term_entry);
+ } else {
+ terms.add(roundedWeight);
+ }
+ }
+
+ return terms;
+ }
+
+ @Override
+ public boolean isConverged() {
+ // Convergence has no meaning yet, perhaps in subclasses
+ return false;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/Cluster.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/Cluster.java b/mr/src/main/java/org/apache/mahout/clustering/Cluster.java
new file mode 100644
index 0000000..07d6927
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/Cluster.java
@@ -0,0 +1,90 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.common.parameters.Parametered;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.util.Map;
+
+/**
+ * Implementations of this interface have a printable representation and certain
+ * attributes that are common across all clustering implementations
+ *
+ */
+public interface Cluster extends Model<VectorWritable>, Parametered {
+
+ // default directory for initial clusters to prime iterative clustering
+ // algorithms
+ String INITIAL_CLUSTERS_DIR = "clusters-0";
+
+ // default directory for output of clusters per iteration
+ String CLUSTERS_DIR = "clusters-";
+
+ // default suffix for output of clusters for final iteration
+ String FINAL_ITERATION_SUFFIX = "-final";
+
+ /**
+ * Get the id of the Cluster
+ *
+ * @return a unique integer
+ */
+ int getId();
+
+ /**
+ * Get the "center" of the Cluster as a Vector
+ *
+ * @return a Vector
+ */
+ Vector getCenter();
+
+ /**
+ * Get the "radius" of the Cluster as a Vector. Usually the radius is the
+ * standard deviation expressed as a Vector of size equal to the center. Some
+ * clusters may return zero values if not appropriate.
+ *
+ * @return aVector
+ */
+ Vector getRadius();
+
+ /**
+ * Produce a custom, human-friendly, printable representation of the Cluster.
+ *
+ * @param bindings
+ * an optional String[] containing labels used to format the primary
+ * Vector/s of this implementation.
+ * @return a String
+ */
+ String asFormatString(String[] bindings);
+
+ /**
+ * Produce a JSON representation of the Cluster.
+ *
+ * @param bindings
+ * an optional String[] containing labels used to format the primary
+ * Vector/s of this implementation.
+ * @return a Map
+ */
+ Map<String,Object> asJson(String[] bindings);
+
+ /**
+ * @return if the receiver has converged, or false if that has no meaning for
+ * the implementation
+ */
+ boolean isConverged();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java b/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
new file mode 100644
index 0000000..421ffcf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/ClusteringUtils.java
@@ -0,0 +1,305 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.apache.mahout.math.neighborhood.Searcher;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+public final class ClusteringUtils {
+ private ClusteringUtils() {
+ }
+
+ /**
+ * Computes the summaries for the distances in each cluster.
+ * @param datapoints iterable of datapoints.
+ * @param centroids iterable of Centroids.
+ * @return a list of OnlineSummarizers where the i-th element is the summarizer corresponding to the cluster whose
+ * index is i.
+ */
+ public static List<OnlineSummarizer> summarizeClusterDistances(Iterable<? extends Vector> datapoints,
+ Iterable<? extends Vector> centroids,
+ DistanceMeasure distanceMeasure) {
+ UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
+ searcher.addAll(centroids);
+ List<OnlineSummarizer> summarizers = Lists.newArrayList();
+ if (searcher.size() == 0) {
+ return summarizers;
+ }
+ for (int i = 0; i < searcher.size(); ++i) {
+ summarizers.add(new OnlineSummarizer());
+ }
+ for (Vector v : datapoints) {
+ Centroid closest = (Centroid)searcher.search(v, 1).get(0).getValue();
+ OnlineSummarizer summarizer = summarizers.get(closest.getIndex());
+ summarizer.add(distanceMeasure.distance(v, closest));
+ }
+ return summarizers;
+ }
+
+ /**
+ * Adds up the distances from each point to its closest cluster and returns the sum.
+ * @param datapoints iterable of datapoints.
+ * @param centroids iterable of Centroids.
+ * @return the total cost described above.
+ */
+ public static double totalClusterCost(Iterable<? extends Vector> datapoints, Iterable<? extends Vector> centroids) {
+ DistanceMeasure distanceMeasure = new EuclideanDistanceMeasure();
+ UpdatableSearcher searcher = new ProjectionSearch(distanceMeasure, 3, 1);
+ searcher.addAll(centroids);
+ return totalClusterCost(datapoints, searcher);
+ }
+
+ /**
+ * Adds up the distances from each point to its closest cluster and returns the sum.
+ * @param datapoints iterable of datapoints.
+ * @param centroids searcher of Centroids.
+ * @return the total cost described above.
+ */
+ public static double totalClusterCost(Iterable<? extends Vector> datapoints, Searcher centroids) {
+ double totalCost = 0;
+ for (Vector vector : datapoints) {
+ totalCost += centroids.searchFirst(vector, false).getWeight();
+ }
+ return totalCost;
+ }
+
+ /**
+ * Estimates the distance cutoff. In StreamingKMeans, the distance between two vectors divided
+ * by this value is used as a probability threshold when deciding whether to form a new cluster
+ * or not.
+ * Small values (comparable to the minimum distance between two points) are preferred as they
+ * guarantee with high likelihood that all but very close points are put in separate clusters
+ * initially. The clusters themselves are actually collapsed periodically when their number goes
+ * over the maximum number of clusters and the distanceCutoff is increased.
+ * So, the returned value is only an initial estimate.
+ * @param data the datapoints whose distance is to be estimated.
+ * @param distanceMeasure the distance measure used to compute the distance between two points.
+ * @return the minimum distance between the first sampleLimit points
+ * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans#clusterInternal(Iterable, boolean)
+ */
+ public static double estimateDistanceCutoff(List<? extends Vector> data, DistanceMeasure distanceMeasure) {
+ BruteSearch searcher = new BruteSearch(distanceMeasure);
+ searcher.addAll(data);
+ double minDistance = Double.POSITIVE_INFINITY;
+ for (Vector vector : data) {
+ double closest = searcher.searchFirst(vector, true).getWeight();
+ if (minDistance > 0 && closest < minDistance) {
+ minDistance = closest;
+ }
+ searcher.add(vector);
+ }
+ return minDistance;
+ }
+
+ public static <T extends Vector> double estimateDistanceCutoff(
+ Iterable<T> data, DistanceMeasure distanceMeasure, int sampleLimit) {
+ return estimateDistanceCutoff(Lists.newArrayList(Iterables.limit(data, sampleLimit)), distanceMeasure);
+ }
+
+ /**
+ * Computes the Davies-Bouldin Index for a given clustering.
+ * See http://en.wikipedia.org/wiki/Clustering_algorithm#Internal_evaluation
+ * @param centroids list of centroids
+ * @param distanceMeasure distance measure for inter-cluster distances
+ * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances
+ * @return the Davies-Bouldin Index
+ */
+ public static double daviesBouldinIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
+ List<OnlineSummarizer> clusterDistanceSummaries) {
+ Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
+ "Number of centroids and cluster summaries differ.");
+ int n = centroids.size();
+ double totalDBIndex = 0;
+ // The inner loop shouldn't be reduced for j = i + 1 to n because the computation of the Davies-Bouldin
+ // index is not really symmetric.
+ // For a given cluster i, we look for a cluster j that maximizes the ratio of the sum of average distances
+ // from points in cluster i to its center and and points in cluster j to its center to the distance between
+ // cluster i and cluster j.
+ // The maximization is the key issue, as the cluster that maximizes this ratio might be j for i but is NOT
+ // NECESSARILY i for j.
+ for (int i = 0; i < n; ++i) {
+ double averageDistanceI = clusterDistanceSummaries.get(i).getMean();
+ double maxDBIndex = 0;
+ for (int j = 0; j < n; ++j) {
+ if (i != j) {
+ double dbIndex = (averageDistanceI + clusterDistanceSummaries.get(j).getMean())
+ / distanceMeasure.distance(centroids.get(i), centroids.get(j));
+ if (dbIndex > maxDBIndex) {
+ maxDBIndex = dbIndex;
+ }
+ }
+ }
+ totalDBIndex += maxDBIndex;
+ }
+ return totalDBIndex / n;
+ }
+
+ /**
+ * Computes the Dunn Index of a given clustering. See http://en.wikipedia.org/wiki/Dunn_index
+ * @param centroids list of centroids
+ * @param distanceMeasure distance measure to compute inter-centroid distance with
+ * @param clusterDistanceSummaries summaries of the clusters; See summarizeClusterDistances
+ * @return the Dunn Index
+ */
+ public static double dunnIndex(List<? extends Vector> centroids, DistanceMeasure distanceMeasure,
+ List<OnlineSummarizer> clusterDistanceSummaries) {
+ Preconditions.checkArgument(centroids.size() == clusterDistanceSummaries.size(),
+ "Number of centroids and cluster summaries differ.");
+ int n = centroids.size();
+ // Intra-cluster distances will come from the OnlineSummarizer, and will be the median distance (noting that
+ // the median for just one value is that value).
+ // A variety of metrics can be used for the intra-cluster distance including max distance between two points,
+ // mean distance, etc. Median distance was chosen as this is more robust to outliers and characterizes the
+ // distribution of distances (from a point to the center) better.
+ double maxIntraClusterDistance = 0;
+ for (OnlineSummarizer summarizer : clusterDistanceSummaries) {
+ if (summarizer.getCount() > 0) {
+ double intraClusterDistance;
+ if (summarizer.getCount() == 1) {
+ intraClusterDistance = summarizer.getMean();
+ } else {
+ intraClusterDistance = summarizer.getMedian();
+ }
+ if (maxIntraClusterDistance < intraClusterDistance) {
+ maxIntraClusterDistance = intraClusterDistance;
+ }
+ }
+ }
+ double minDunnIndex = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < n; ++i) {
+ // Distances are symmetric, so d(i, j) = d(j, i).
+ for (int j = i + 1; j < n; ++j) {
+ double dunnIndex = distanceMeasure.distance(centroids.get(i), centroids.get(j));
+ if (minDunnIndex > dunnIndex) {
+ minDunnIndex = dunnIndex;
+ }
+ }
+ }
+ return minDunnIndex / maxIntraClusterDistance;
+ }
+
+ public static double choose2(double n) {
+ return n * (n - 1) / 2;
+ }
+
+ /**
+ * Creates a confusion matrix by searching for the closest cluster of both the row clustering and column clustering
+ * of a point and adding its weight to that cell of the matrix.
+ * It doesn't matter which clustering is the row clustering and which is the column clustering. If they're
+ * interchanged, the resulting matrix is the transpose of the original one.
+ * @param rowCentroids clustering one
+ * @param columnCentroids clustering two
+ * @param datapoints datapoints whose closest cluster we need to find
+ * @param distanceMeasure distance measure to use
+ * @return the confusion matrix
+ */
+ public static Matrix getConfusionMatrix(List<? extends Vector> rowCentroids, List<? extends Vector> columnCentroids,
+ Iterable<? extends Vector> datapoints, DistanceMeasure distanceMeasure) {
+ Searcher rowSearcher = new BruteSearch(distanceMeasure);
+ rowSearcher.addAll(rowCentroids);
+ Searcher columnSearcher = new BruteSearch(distanceMeasure);
+ columnSearcher.addAll(columnCentroids);
+
+ int numRows = rowCentroids.size();
+ int numCols = columnCentroids.size();
+ Matrix confusionMatrix = new DenseMatrix(numRows, numCols);
+
+ for (Vector vector : datapoints) {
+ WeightedThing<Vector> closestRowCentroid = rowSearcher.search(vector, 1).get(0);
+ WeightedThing<Vector> closestColumnCentroid = columnSearcher.search(vector, 1).get(0);
+ int row = ((Centroid) closestRowCentroid.getValue()).getIndex();
+ int column = ((Centroid) closestColumnCentroid.getValue()).getIndex();
+ double vectorWeight;
+ if (vector instanceof WeightedVector) {
+ vectorWeight = ((WeightedVector) vector).getWeight();
+ } else {
+ vectorWeight = 1;
+ }
+ confusionMatrix.set(row, column, confusionMatrix.get(row, column) + vectorWeight);
+ }
+
+ return confusionMatrix;
+ }
+
+ /**
+ * Computes the Adjusted Rand Index for a given confusion matrix.
+ * @param confusionMatrix confusion matrix; not to be confused with the more restrictive ConfusionMatrix class
+ * @return the Adjusted Rand Index
+ */
+ public static double getAdjustedRandIndex(Matrix confusionMatrix) {
+ int numRows = confusionMatrix.numRows();
+ int numCols = confusionMatrix.numCols();
+ double rowChoiceSum = 0;
+ double columnChoiceSum = 0;
+ double totalChoiceSum = 0;
+ double total = 0;
+ for (int i = 0; i < numRows; ++i) {
+ double rowSum = 0;
+ for (int j = 0; j < numCols; ++j) {
+ rowSum += confusionMatrix.get(i, j);
+ totalChoiceSum += choose2(confusionMatrix.get(i, j));
+ }
+ total += rowSum;
+ rowChoiceSum += choose2(rowSum);
+ }
+ for (int j = 0; j < numCols; ++j) {
+ double columnSum = 0;
+ for (int i = 0; i < numRows; ++i) {
+ columnSum += confusionMatrix.get(i, j);
+ }
+ columnChoiceSum += choose2(columnSum);
+ }
+ double rowColumnChoiceSumDivTotal = rowChoiceSum * columnChoiceSum / choose2(total);
+ return (totalChoiceSum - rowColumnChoiceSumDivTotal)
+ / ((rowChoiceSum + columnChoiceSum) / 2 - rowColumnChoiceSumDivTotal);
+ }
+
+ /**
+ * Computes the total weight of the points in the given Vector iterable.
+ * @param data iterable of points
+ * @return total weight
+ */
+ public static double totalWeight(Iterable<? extends Vector> data) {
+ double sum = 0;
+ for (Vector row : data) {
+ Preconditions.checkNotNull(row);
+ if (row instanceof WeightedVector) {
+ sum += ((WeightedVector)row).getWeight();
+ } else {
+ sum++;
+ }
+ }
+ return sum;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
new file mode 100644
index 0000000..c25e039
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+
+public interface GaussianAccumulator {
+
+ /**
+ * @return the number of observations
+ */
+ double getN();
+
+ /**
+ * @return the mean of the observations
+ */
+ Vector getMean();
+
+ /**
+ * @return the std of the observations
+ */
+ Vector getStd();
+
+ /**
+ * @return the average of the vector std elements
+ */
+ double getAverageStd();
+
+ /**
+ * @return the variance of the observations
+ */
+ Vector getVariance();
+
+ /**
+ * Observe the vector
+ *
+ * @param x a Vector
+ * @param weight the double observation weight (usually 1.0)
+ */
+ void observe(Vector x, double weight);
+
+ /**
+ * Compute the mean, variance and standard deviation
+ */
+ void compute();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/Model.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/Model.java b/mr/src/main/java/org/apache/mahout/clustering/Model.java
new file mode 100644
index 0000000..79dab30
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/Model.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * A model is a probability distribution over observed data points and allows
+ * the probability of any data point to be computed. All Models have a
+ * persistent representation and extend
+ * WritablesampleFromPosterior(Model<VectorWritable>[])
+ */
+public interface Model<O> extends Writable {
+
+ /**
+ * Return the probability that the observation is described by this model
+ *
+ * @param x
+ * an Observation from the posterior
+ * @return the probability that x is in the receiver
+ */
+ double pdf(O x);
+
+ /**
+ * Observe the given observation, retaining information about it
+ *
+ * @param x
+ * an Observation from the posterior
+ */
+ void observe(O x);
+
+ /**
+ * Observe the given observation, retaining information about it
+ *
+ * @param x
+ * an Observation from the posterior
+ * @param weight
+ * a double weighting factor
+ */
+ void observe(O x, double weight);
+
+ /**
+ * Observe the given model, retaining information about its observations
+ *
+ * @param x
+ * a Model<0>
+ */
+ void observe(Model<O> x);
+
+ /**
+ * Compute a new set of posterior parameters based upon the Observations that
+ * have been observed since my creation
+ */
+ void computeParameters();
+
+ /**
+ * Return the number of observations that this model has seen since its
+ * parameters were last computed
+ *
+ * @return a long
+ */
+ long getNumObservations();
+
+ /**
+ * Return the number of observations that this model has seen over its
+ * lifetime
+ *
+ * @return a long
+ */
+ long getTotalObservations();
+
+ /**
+ * @return a sample of my posterior model
+ */
+ Model<VectorWritable> sampleFromPosterior();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java b/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
new file mode 100644
index 0000000..d77bf40
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/ModelDistribution.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+/** A model distribution allows us to sample a model from its prior distribution. */
+public interface ModelDistribution<O> {
+
+ /**
+ * Return a list of models sampled from the prior
+ *
+ * @param howMany
+ * the int number of models to return
+ * @return a Model<Observation>[] representing what is known apriori
+ */
+ Model<O>[] sampleFromPrior(int howMany);
+
+ /**
+ * Return a list of models sampled from the posterior
+ *
+ * @param posterior
+ * the Model<Observation>[] after observations
+ * @return a Model<Observation>[] representing what is known apriori
+ */
+ Model<O>[] sampleFromPosterior(Model<O>[] posterior);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
new file mode 100644
index 0000000..b76e00f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.SquareRootFunction;
+
+/**
+ * An online Gaussian statistics accumulator based upon Knuth (who cites Welford) which is declared to be
+ * numerically-stable. See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ */
+public class OnlineGaussianAccumulator implements GaussianAccumulator {
+
+ private double sumWeight;
+ private Vector mean;
+ private Vector s;
+ private Vector variance;
+
+ @Override
+ public double getN() {
+ return sumWeight;
+ }
+
+ @Override
+ public Vector getMean() {
+ return mean;
+ }
+
+ @Override
+ public Vector getStd() {
+ return variance.clone().assign(new SquareRootFunction());
+ }
+
+ /* from Wikipedia: http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ *
+ * Weighted incremental algorithm
+ *
+ * def weighted_incremental_variance(dataWeightPairs):
+ * mean = 0
+ * S = 0
+ * sumweight = 0
+ * for x, weight in dataWeightPairs: # Alternately "for x in zip(data, weight):"
+ * temp = weight + sumweight
+ * Q = x - mean
+ * R = Q * weight / temp
+ * S = S + sumweight * Q * R
+ * mean = mean + R
+ * sumweight = temp
+ * Variance = S / (sumweight-1) # if sample is the population, omit -1
+ * return Variance
+ */
+ @Override
+ public void observe(Vector x, double weight) {
+ double temp = weight + sumWeight;
+ Vector q;
+ if (mean == null) {
+ mean = x.like();
+ q = x.clone();
+ } else {
+ q = x.minus(mean);
+ }
+ Vector r = q.times(weight).divide(temp);
+ if (s == null) {
+ s = q.times(sumWeight).times(r);
+ } else {
+ s = s.plus(q.times(sumWeight).times(r));
+ }
+ mean = mean.plus(r);
+ sumWeight = temp;
+ variance = s.divide(sumWeight - 1); // # if sample is the population, omit -1
+ }
+
+ @Override
+ public void compute() {
+ // nothing to do here!
+ }
+
+ @Override
+ public double getAverageStd() {
+ if (sumWeight == 0.0) {
+ return 0.0;
+ } else {
+ Vector std = getStd();
+ return std.zSum() / std.size();
+ }
+ }
+
+ @Override
+ public Vector getVariance() {
+ return variance;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java b/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
new file mode 100644
index 0000000..138e830
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+
+/**
+ * An online Gaussian accumulator that uses a running power sums approach as reported
+ * on http://en.wikipedia.org/wiki/Standard_deviation
+ * Suffers from overflow, underflow and roundoff error but has minimal observe-time overhead
+ */
+public class RunningSumsGaussianAccumulator implements GaussianAccumulator {
+
+ private double s0;
+ private Vector s1;
+ private Vector s2;
+ private Vector mean;
+ private Vector std;
+
+ @Override
+ public double getN() {
+ return s0;
+ }
+
+ @Override
+ public Vector getMean() {
+ return mean;
+ }
+
+ @Override
+ public Vector getStd() {
+ return std;
+ }
+
+ @Override
+ public double getAverageStd() {
+ if (s0 == 0.0) {
+ return 0.0;
+ } else {
+ return std.zSum() / std.size();
+ }
+ }
+
+ @Override
+ public Vector getVariance() {
+ return std.times(std);
+ }
+
+ @Override
+ public void observe(Vector x, double weight) {
+ s0 += weight;
+ Vector weightedX = x.times(weight);
+ if (s1 == null) {
+ s1 = weightedX;
+ } else {
+ s1.assign(weightedX, Functions.PLUS);
+ }
+ Vector x2 = x.times(x).times(weight);
+ if (s2 == null) {
+ s2 = x2;
+ } else {
+ s2.assign(x2, Functions.PLUS);
+ }
+ }
+
+ @Override
+ public void compute() {
+ if (s0 != 0.0) {
+ mean = s1.divide(s0);
+ std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java b/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java
new file mode 100644
index 0000000..ef43e1b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/UncommonDistributions.java
@@ -0,0 +1,136 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.commons.math3.distribution.NormalDistribution;
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+public final class UncommonDistributions {
+
+ private static final RandomWrapper RANDOM = RandomUtils.getRandom();
+
+ private UncommonDistributions() {}
+
+ // =============== start of BSD licensed code. See LICENSE.txt
+ /**
+ * Returns a double sampled according to this distribution. Uniformly fast for all k > 0. (Reference:
+ * Non-Uniform Random Variate Generation, Devroye http://cgm.cs.mcgill.ca/~luc/rnbookindex.html) Uses
+ * Cheng's rejection algorithm (GB) for k>=1, rejection from Weibull distribution for 0 < k < 1.
+ */
+ public static double rGamma(double k, double lambda) {
+ boolean accept = false;
+ if (k >= 1.0) {
+ // Cheng's algorithm
+ double b = k - Math.log(4.0);
+ double c = k + Math.sqrt(2.0 * k - 1.0);
+ double lam = Math.sqrt(2.0 * k - 1.0);
+ double cheng = 1.0 + Math.log(4.5);
+ double x;
+ do {
+ double u = RANDOM.nextDouble();
+ double v = RANDOM.nextDouble();
+ double y = 1.0 / lam * Math.log(v / (1.0 - v));
+ x = k * Math.exp(y);
+ double z = u * v * v;
+ double r = b + c * y - x;
+ if (r >= 4.5 * z - cheng || r >= Math.log(z)) {
+ accept = true;
+ }
+ } while (!accept);
+ return x / lambda;
+ } else {
+ // Weibull algorithm
+ double c = 1.0 / k;
+ double d = (1.0 - k) * Math.pow(k, k / (1.0 - k));
+ double x;
+ do {
+ double u = RANDOM.nextDouble();
+ double v = RANDOM.nextDouble();
+ double z = -Math.log(u);
+ double e = -Math.log(v);
+ x = Math.pow(z, c);
+ if (z + e >= d + x) {
+ accept = true;
+ }
+ } while (!accept);
+ return x / lambda;
+ }
+ }
+
+ // ============= end of BSD licensed code
+
+ /**
+ * Returns a random sample from a beta distribution with the given shapes
+ *
+ * @param shape1
+ * a double representing shape1
+ * @param shape2
+ * a double representing shape2
+ * @return a Vector of samples
+ */
+ public static double rBeta(double shape1, double shape2) {
+ double gam1 = rGamma(shape1, 1.0);
+ double gam2 = rGamma(shape2, 1.0);
+ return gam1 / (gam1 + gam2);
+
+ }
+
+ /**
+ * Return a random value from a normal distribution with the given mean and standard deviation
+ *
+ * @param mean
+ * a double mean value
+ * @param sd
+ * a double standard deviation
+ * @return a double sample
+ */
+ public static double rNorm(double mean, double sd) {
+ RealDistribution dist = new NormalDistribution(RANDOM.getRandomGenerator(),
+ mean,
+ sd,
+ NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY);
+ return dist.sample();
+ }
+
+ /**
+ * Returns an integer sampled according to this distribution. Takes time proportional to np + 1. (Reference:
+ * Non-Uniform Random Variate Generation, Devroye http://cgm.cs.mcgill.ca/~luc/rnbookindex.html) Second
+ * time-waiting algorithm.
+ */
+ public static int rBinomial(int n, double p) {
+ if (p >= 1.0) {
+ return n; // needed to avoid infinite loops and negative results
+ }
+ double q = -Math.log1p(-p);
+ double sum = 0.0;
+ int x = 0;
+ while (sum <= q) {
+ double u = RANDOM.nextDouble();
+ double e = -Math.log(u);
+ sum += e / (n - x);
+ x++;
+ }
+ if (x == 0) {
+ return 0;
+ }
+ return x - 1;
+ }
+
+}
[03/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java b/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java
new file mode 100644
index 0000000..dbb950a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java
@@ -0,0 +1,205 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.clustering.topdown.PathDirectory;
+import org.apache.mahout.common.DummyOutputCollector;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public final class ClusterOutputPostProcessorTest extends MahoutTestCase {
+
+ private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, {5, 4}, {4, 5}, {5, 5}};
+
+ private FileSystem fs;
+
+ private Path outputPath;
+
+ private Configuration conf;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ }
+
+ private static List<VectorWritable> getPointsWritable(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ /**
+ * Story: User wants to use cluster post processor after canopy clustering and then run clustering on the
+ * output clusters
+ */
+ @Test
+ public void testTopDownClustering() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), fs, conf);
+
+ outputPath = getTestTempDirPath("output");
+
+ topLevelClustering(pointsPath, conf);
+
+ Map<String,Path> postProcessedClusterDirectories = ouputPostProcessing(conf);
+
+ assertPostProcessedOutput(postProcessedClusterDirectories);
+
+ bottomLevelClustering(postProcessedClusterDirectories);
+ }
+
+ private void assertTopLevelCluster(Entry<String,Path> cluster) {
+ String clusterId = cluster.getKey();
+ Path clusterPath = cluster.getValue();
+
+ try {
+ if ("0".equals(clusterId)) {
+ assertPointsInFirstTopLevelCluster(clusterPath);
+ } else if ("1".equals(clusterId)) {
+ assertPointsInSecondTopLevelCluster(clusterPath);
+ }
+ } catch (IOException e) {
+ Assert.fail("Exception occurred while asserting top level cluster.");
+ }
+
+ }
+
+ private void assertPointsInFirstTopLevelCluster(Path clusterPath) throws IOException {
+ List<Vector> vectorsInCluster = getVectorsInCluster(clusterPath);
+ for (Vector vector : vectorsInCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:1.0,1:1.0}", "{0:2.0,1:1.0}", "{0:1.0,1:2.0}"},
+ vector.asFormatString()));
+ }
+ }
+
+ private void assertPointsInSecondTopLevelCluster(Path clusterPath) throws IOException {
+ List<Vector> vectorsInCluster = getVectorsInCluster(clusterPath);
+ for (Vector vector : vectorsInCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:4.0,1:4.0}", "{0:5.0,1:4.0}", "{0:4.0,1:5.0}",
+ "{0:5.0,1:5.0}"}, vector.asFormatString()));
+ }
+ }
+
+ private List<Vector> getVectorsInCluster(Path clusterPath) throws IOException {
+ Path[] partFilePaths = FileUtil.stat2Paths(fs.globStatus(clusterPath));
+ FileStatus[] listStatus = fs.listStatus(partFilePaths);
+ List<Vector> vectors = Lists.newArrayList();
+ for (FileStatus partFile : listStatus) {
+ SequenceFile.Reader topLevelClusterReader = new SequenceFile.Reader(fs, partFile.getPath(), conf);
+ Writable clusterIdAsKey = new LongWritable();
+ VectorWritable point = new VectorWritable();
+ while (topLevelClusterReader.next(clusterIdAsKey, point)) {
+ vectors.add(point.get());
+ }
+ }
+ return vectors;
+ }
+
+ private void bottomLevelClustering(Map<String,Path> postProcessedClusterDirectories) throws IOException,
+ InterruptedException,
+ ClassNotFoundException {
+ for (Entry<String,Path> topLevelCluster : postProcessedClusterDirectories.entrySet()) {
+ String clusterId = topLevelCluster.getKey();
+ Path topLevelclusterPath = topLevelCluster.getValue();
+
+ Path bottomLevelCluster = PathDirectory.getBottomLevelClusterPath(outputPath, clusterId);
+ CanopyDriver.run(conf, topLevelclusterPath, bottomLevelCluster, new ManhattanDistanceMeasure(), 2.1,
+ 2.0, true, 0.0, true);
+ assertBottomLevelCluster(bottomLevelCluster);
+ }
+ }
+
+ private void assertBottomLevelCluster(Path bottomLevelCluster) {
+ Path clusteredPointsPath = new Path(bottomLevelCluster, "clusteredPoints");
+
+ DummyOutputCollector<IntWritable,WeightedVectorWritable> collector =
+ new DummyOutputCollector<IntWritable,WeightedVectorWritable>();
+
+ // The key is the clusterId, the value is the weighted vector
+ for (Pair<IntWritable,WeightedVectorWritable> record :
+ new SequenceFileIterable<IntWritable,WeightedVectorWritable>(new Path(clusteredPointsPath, "part-m-0"),
+ conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ int clusterSize = collector.getKeys().size();
+ // First top level cluster produces two more clusters, second top level cluster is not broken again
+ assertTrue(clusterSize == 1 || clusterSize == 2);
+
+ }
+
+ private void assertPostProcessedOutput(Map<String,Path> postProcessedClusterDirectories) {
+ for (Entry<String,Path> cluster : postProcessedClusterDirectories.entrySet()) {
+ assertTopLevelCluster(cluster);
+ }
+ }
+
+ private Map<String,Path> ouputPostProcessing(Configuration conf) throws IOException {
+ ClusterOutputPostProcessor clusterOutputPostProcessor = new ClusterOutputPostProcessor(outputPath,
+ outputPath, conf);
+ clusterOutputPostProcessor.process();
+ return clusterOutputPostProcessor.getPostProcessedClusterDirectories();
+ }
+
+ private void topLevelClustering(Path pointsPath, Configuration conf) throws IOException,
+ InterruptedException,
+ ClassNotFoundException {
+ CanopyDriver.run(conf, pointsPath, outputPath, new ManhattanDistanceMeasure(), 3.1, 2.1, true, 0.0, true);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java b/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java
new file mode 100644
index 0000000..7683b57
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java
@@ -0,0 +1,240 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.junit.Test;
+
+public final class AbstractJobTest extends MahoutTestCase {
+
+ interface AbstractJobFactory {
+ AbstractJob getJob();
+ }
+
+ @Test
+ public void testFlag() throws Exception {
+ final Map<String,List<String>> testMap = Maps.newHashMap();
+
+ AbstractJobFactory fact = new AbstractJobFactory() {
+ @Override
+ public AbstractJob getJob() {
+ return new AbstractJob() {
+ @Override
+ public int run(String[] args) throws IOException {
+ addFlag("testFlag", "t", "a simple test flag");
+
+ Map<String,List<String>> argMap = parseArguments(args);
+ testMap.clear();
+ testMap.putAll(argMap);
+ return 1;
+ }
+ };
+ }
+ };
+
+ // testFlag will only be present if specified on the command-line
+
+ ToolRunner.run(fact.getJob(), new String[0]);
+ assertFalse("test map for absent flag", testMap.containsKey("--testFlag"));
+
+ String[] withFlag = { "--testFlag" };
+ ToolRunner.run(fact.getJob(), withFlag);
+ assertTrue("test map for present flag", testMap.containsKey("--testFlag"));
+ }
+
+ @Test
+ public void testOptions() throws Exception {
+ final Map<String,List<String>> testMap = Maps.newHashMap();
+
+ AbstractJobFactory fact = new AbstractJobFactory() {
+ @Override
+ public AbstractJob getJob() {
+ return new AbstractJob() {
+ @Override
+ public int run(String[] args) throws IOException {
+ this.addOption(DefaultOptionCreator.overwriteOption().create());
+ this.addOption("option", "o", "option");
+ this.addOption("required", "r", "required", true /* required */);
+ this.addOption("notRequired", "nr", "not required", false /* not required */);
+ this.addOption("hasDefault", "hd", "option w/ default", "defaultValue");
+
+
+ Map<String,List<String>> argMap = parseArguments(args);
+ if (argMap == null) {
+ return -1;
+ }
+
+ testMap.clear();
+ testMap.putAll(argMap);
+
+ return 0;
+ }
+ };
+ }
+ };
+
+ int ret = ToolRunner.run(fact.getJob(), new String[0]);
+ assertEquals("-1 for missing required options", -1, ret);
+
+ ret = ToolRunner.run(fact.getJob(), new String[]{
+ "--required", "requiredArg"
+ });
+ assertEquals("0 for no missing required options", 0, ret);
+ assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required"));
+ assertEquals(Collections.singletonList("defaultValue"), testMap.get("--hasDefault"));
+ assertNull(testMap.get("--option"));
+ assertNull(testMap.get("--notRequired"));
+ assertFalse(testMap.containsKey("--overwrite"));
+
+ ret = ToolRunner.run(fact.getJob(), new String[]{
+ "--required", "requiredArg",
+ "--unknownArg"
+ });
+ assertEquals("-1 for including unknown options", -1, ret);
+
+ ret = ToolRunner.run(fact.getJob(), new String[]{
+ "--required", "requiredArg",
+ "--required", "requiredArg2",
+ });
+ assertEquals("-1 for including duplicate options", -1, ret);
+
+ ret = ToolRunner.run(fact.getJob(), new String[]{
+ "--required", "requiredArg",
+ "--overwrite",
+ "--hasDefault", "nonDefault",
+ "--option", "optionValue",
+ "--notRequired", "notRequired"
+ });
+ assertEquals("0 for no missing required options", 0, ret);
+ assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required"));
+ assertEquals(Collections.singletonList("nonDefault"), testMap.get("--hasDefault"));
+ assertEquals(Collections.singletonList("optionValue"), testMap.get("--option"));
+ assertEquals(Collections.singletonList("notRequired"), testMap.get("--notRequired"));
+ assertTrue(testMap.containsKey("--overwrite"));
+
+ ret = ToolRunner.run(fact.getJob(), new String[]{
+ "-r", "requiredArg",
+ "-ow",
+ "-hd", "nonDefault",
+ "-o", "optionValue",
+ "-nr", "notRequired"
+ });
+ assertEquals("0 for no missing required options", 0, ret);
+ assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required"));
+ assertEquals(Collections.singletonList("nonDefault"), testMap.get("--hasDefault"));
+ assertEquals(Collections.singletonList("optionValue"), testMap.get("--option"));
+ assertEquals(Collections.singletonList("notRequired"), testMap.get("--notRequired"));
+ assertTrue(testMap.containsKey("--overwrite"));
+
+ }
+
+ @Test
+ public void testInputOutputPaths() throws Exception {
+
+ AbstractJobFactory fact = new AbstractJobFactory() {
+ @Override
+ public AbstractJob getJob() {
+ return new AbstractJob() {
+ @Override
+ public int run(String[] args) throws IOException {
+ addInputOption();
+ addOutputOption();
+
+ // arg map should be null if a required option is missing.
+ Map<String, List<String>> argMap = parseArguments(args);
+
+ if (argMap == null) {
+ return -1;
+ }
+
+ Path inputPath = getInputPath();
+ assertNotNull("getInputPath() returns non-null", inputPath);
+
+ Path outputPath = getInputPath();
+ assertNotNull("getOutputPath() returns non-null", outputPath);
+ return 0;
+ }
+ };
+ }
+ };
+
+ int ret = ToolRunner.run(fact.getJob(), new String[0]);
+ assertEquals("-1 for missing input option", -1, ret);
+
+ String testInputPath = "testInputPath";
+
+ AbstractJob job = fact.getJob();
+ ret = ToolRunner.run(job, new String[]{
+ "--input", testInputPath });
+ assertEquals("-1 for missing output option", -1, ret);
+ assertEquals("input path is correct", testInputPath, job.getInputPath().toString());
+
+ job = fact.getJob();
+ String testOutputPath = "testOutputPath";
+ ret = ToolRunner.run(job, new String[]{
+ "--output", testOutputPath });
+ assertEquals("-1 for missing input option", -1, ret);
+ assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString());
+
+ job = fact.getJob();
+ ret = ToolRunner.run(job, new String[]{
+ "--input", testInputPath, "--output", testOutputPath });
+ assertEquals("0 for complete options", 0, ret);
+ assertEquals("input path is correct", testInputPath, job.getInputPath().toString());
+ assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString());
+
+ job = fact.getJob();
+ ret = ToolRunner.run(job, new String[]{
+ "--input", testInputPath, "--output", testOutputPath });
+ assertEquals("0 for complete options", 0, ret);
+ assertEquals("input path is correct", testInputPath, job.getInputPath().toString());
+ assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString());
+
+ job = fact.getJob();
+ String testInputPropertyPath = "testInputPropertyPath";
+ String testOutputPropertyPath = "testOutputPropertyPath";
+ ret = ToolRunner.run(job, new String[]{
+ "-Dmapred.input.dir=" + testInputPropertyPath,
+ "-Dmapred.output.dir=" + testOutputPropertyPath });
+ assertEquals("0 for complete options", 0, ret);
+ assertEquals("input path from property is correct", testInputPropertyPath, job.getInputPath().toString());
+ assertEquals("output path from property is correct", testOutputPropertyPath, job.getOutputPath().toString());
+
+ job = fact.getJob();
+ ret = ToolRunner.run(job, new String[]{
+ "-Dmapred.input.dir=" + testInputPropertyPath,
+ "-Dmapred.output.dir=" + testOutputPropertyPath,
+ "--input", testInputPath,
+ "--output", testOutputPath });
+ assertEquals("0 for complete options", 0, ret);
+ assertEquals("input command-line option precedes property",
+ testInputPath, job.getInputPath().toString());
+ assertEquals("output command-line option precedes property",
+ testOutputPath, job.getOutputPath().toString());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java b/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java
new file mode 100644
index 0000000..5d3532c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.apache.hadoop.fs.Path;
+import org.junit.Test;
+
+import java.io.File;
+import java.net.URI;
+
+
+public class DistributedCacheFileLocationTest extends MahoutTestCase {
+
+ static final File FILE_I_WANT_TO_FIND = new File("file/i_want_to_find.txt");
+ static final URI[] DISTRIBUTED_CACHE_FILES = new URI[] {
+ new File("/first/file").toURI(), new File("/second/file").toURI(), FILE_I_WANT_TO_FIND.toURI() };
+
+ @Test
+ public void nonExistingFile() {
+ Path path = HadoopUtil.findInCacheByPartOfFilename("no such file", DISTRIBUTED_CACHE_FILES);
+ assertNull(path);
+ }
+
+ @Test
+ public void existingFile() {
+ Path path = HadoopUtil.findInCacheByPartOfFilename("want_to_find", DISTRIBUTED_CACHE_FILES);
+ assertNotNull(path);
+ assertEquals(FILE_I_WANT_TO_FIND.getName(), path.getName());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java b/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java
new file mode 100644
index 0000000..6951f5a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.OutputCollector;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.TreeMap;
+
+public final class DummyOutputCollector<K extends WritableComparable, V extends Writable>
+ implements OutputCollector<K,V> {
+
+ private final Map<K, List<V>> data = new TreeMap<K,List<V>>();
+
+ @Override
+ public void collect(K key,V values) {
+ List<V> points = data.get(key);
+ if (points == null) {
+ points = Lists.newArrayList();
+ data.put(key, points);
+ }
+ points.add(values);
+ }
+
+ public Map<K,List<V>> getData() {
+ return data;
+ }
+
+ public List<V> getValue(K key) {
+ return data.get(key);
+ }
+
+ public Set<K> getKeys() {
+ return data.keySet();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java
new file mode 100644
index 0000000..7dea174
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.MapContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.RecordWriter;
+import org.apache.hadoop.mapreduce.ReduceContext;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.hadoop.mapreduce.TaskAttemptID;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Method;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+public final class DummyRecordWriter<K extends Writable, V extends Writable> extends RecordWriter<K, V> {
+
+ private final List<K> keysInInsertionOrder = Lists.newArrayList();
+ private final Map<K, List<V>> data = Maps.newHashMap();
+
+ @Override
+ public void write(K key, V value) {
+
+ // if the user reuses the same writable class, we need to create a new one
+ // otherwise the Map content will be modified after the insert
+ try {
+
+ K keyToUse = key instanceof NullWritable ? key : (K) cloneWritable(key);
+ V valueToUse = (V) cloneWritable(value);
+
+ keysInInsertionOrder.add(keyToUse);
+
+ List<V> points = data.get(key);
+ if (points == null) {
+ points = Lists.newArrayList();
+ data.put(keyToUse, points);
+ }
+ points.add(valueToUse);
+
+ } catch (IOException e) {
+ throw new RuntimeException(e.getMessage(), e);
+ }
+ }
+
+ private Writable cloneWritable(Writable original) throws IOException {
+
+ Writable clone;
+ try {
+ clone = original.getClass().asSubclass(Writable.class).newInstance();
+ } catch (Exception e) {
+ throw new RuntimeException("Unable to instantiate writable!", e);
+ }
+ ByteArrayOutputStream bytes = new ByteArrayOutputStream();
+
+ original.write(new DataOutputStream(bytes));
+ clone.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
+
+ return clone;
+ }
+
+ @Override
+ public void close(TaskAttemptContext context) {
+ }
+
+ public Map<K, List<V>> getData() {
+ return data;
+ }
+
+ public List<V> getValue(K key) {
+ return data.get(key);
+ }
+
+ public Set<K> getKeys() {
+ return data.keySet();
+ }
+
+ public Iterable<K> getKeysInInsertionOrder() {
+ return keysInInsertionOrder;
+ }
+
+ public static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context build(Mapper<K1, V1, K2, V2> mapper,
+ Configuration configuration,
+ RecordWriter<K2, V2> output) {
+
+ // Use reflection since the context types changed incompatibly between 0.20
+ // and 0.23.
+ try {
+ return buildNewMapperContext(configuration, output);
+ } catch (Exception e) {
+ try {
+ return buildOldMapperContext(mapper, configuration, output);
+ } catch (Exception ex) {
+ throw new IllegalStateException(ex);
+ }
+ }
+ }
+
+ public static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context build(Reducer<K1, V1, K2, V2> reducer,
+ Configuration configuration,
+ RecordWriter<K2, V2> output,
+ Class<K1> keyClass,
+ Class<V1> valueClass) {
+
+ // Use reflection since the context types changed incompatibly between 0.20
+ // and 0.23.
+ try {
+ return buildNewReducerContext(configuration, output, keyClass, valueClass);
+ } catch (Exception e) {
+ try {
+ return buildOldReducerContext(reducer, configuration, output, keyClass, valueClass);
+ } catch (Exception ex) {
+ throw new IllegalStateException(ex);
+ }
+ }
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context buildNewMapperContext(
+ Configuration configuration, RecordWriter<K2, V2> output) throws Exception {
+ Class<?> mapContextImplClass = Class.forName("org.apache.hadoop.mapreduce.task.MapContextImpl");
+ Constructor<?> cons = mapContextImplClass.getConstructors()[0];
+ Object mapContextImpl = cons.newInstance(configuration,
+ new TaskAttemptID(), null, output, null, new DummyStatusReporter(), null);
+
+ Class<?> wrappedMapperClass = Class.forName("org.apache.hadoop.mapreduce.lib.map.WrappedMapper");
+ Object wrappedMapper = wrappedMapperClass.getConstructor().newInstance();
+ Method getMapContext = wrappedMapperClass.getMethod("getMapContext", MapContext.class);
+ return (Mapper.Context) getMapContext.invoke(wrappedMapper, mapContextImpl);
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context buildOldMapperContext(
+ Mapper<K1, V1, K2, V2> mapper, Configuration configuration,
+ RecordWriter<K2, V2> output) throws Exception {
+ Constructor<?> cons = getNestedContextConstructor(mapper.getClass());
+ // first argument to the constructor is the enclosing instance
+ return (Mapper.Context) cons.newInstance(mapper, configuration,
+ new TaskAttemptID(), null, output, null, new DummyStatusReporter(), null);
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context buildNewReducerContext(
+ Configuration configuration, RecordWriter<K2, V2> output, Class<K1> keyClass,
+ Class<V1> valueClass) throws Exception {
+ Class<?> reduceContextImplClass = Class.forName("org.apache.hadoop.mapreduce.task.ReduceContextImpl");
+ Constructor<?> cons = reduceContextImplClass.getConstructors()[0];
+ Object reduceContextImpl = cons.newInstance(configuration,
+ new TaskAttemptID(),
+ new MockIterator(),
+ null,
+ null,
+ output,
+ null,
+ new DummyStatusReporter(),
+ null,
+ keyClass,
+ valueClass);
+
+ Class<?> wrappedReducerClass = Class.forName("org.apache.hadoop.mapreduce.lib.reduce.WrappedReducer");
+ Object wrappedReducer = wrappedReducerClass.getConstructor().newInstance();
+ Method getReducerContext = wrappedReducerClass.getMethod("getReducerContext", ReduceContext.class);
+ return (Reducer.Context) getReducerContext.invoke(wrappedReducer, reduceContextImpl);
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ private static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context buildOldReducerContext(
+ Reducer<K1, V1, K2, V2> reducer, Configuration configuration,
+ RecordWriter<K2, V2> output, Class<K1> keyClass,
+ Class<V1> valueClass) throws Exception {
+ Constructor<?> cons = getNestedContextConstructor(reducer.getClass());
+ // first argument to the constructor is the enclosing instance
+ return (Reducer.Context) cons.newInstance(reducer,
+ configuration,
+ new TaskAttemptID(),
+ new MockIterator(),
+ null,
+ null,
+ output,
+ null,
+ new DummyStatusReporter(),
+ null,
+ keyClass,
+ valueClass);
+ }
+
+ private static Constructor<?> getNestedContextConstructor(Class<?> outerClass) {
+ for (Class<?> nestedClass : outerClass.getClasses()) {
+ if ("Context".equals(nestedClass.getSimpleName())) {
+ return nestedClass.getConstructors()[0];
+ }
+ }
+ throw new IllegalStateException("Cannot find context class for " + outerClass);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java
new file mode 100644
index 0000000..6b25448
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class DummyRecordWriterTest {
+
+ @Test
+ public void testWrite() {
+ DummyRecordWriter<IntWritable, VectorWritable> writer =
+ new DummyRecordWriter<IntWritable, VectorWritable>();
+ IntWritable reusableIntWritable = new IntWritable();
+ VectorWritable reusableVectorWritable = new VectorWritable();
+ reusableIntWritable.set(0);
+ reusableVectorWritable.set(new DenseVector(new double[] { 1, 2, 3 }));
+ writer.write(reusableIntWritable, reusableVectorWritable);
+ reusableIntWritable.set(1);
+ reusableVectorWritable.set(new DenseVector(new double[] { 4, 5, 6 }));
+ writer.write(reusableIntWritable, reusableVectorWritable);
+
+ Assert.assertEquals(
+ "The writer must remember the two keys that is written to it", 2,
+ writer.getKeys().size());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java b/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java
new file mode 100644
index 0000000..c6bc34b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.easymock.EasyMock;
+
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.mapreduce.Counter;
+import org.apache.hadoop.mapreduce.StatusReporter;
+
+public final class DummyStatusReporter extends StatusReporter {
+
+ private final Map<Enum<?>, Counter> counters = Maps.newHashMap();
+ private final Map<String, Counter> counterGroups = Maps.newHashMap();
+
+ private static Counter newCounter() {
+ try {
+ // 0.23 case
+ String c = "org.apache.hadoop.mapreduce.counters.GenericCounter";
+ return (Counter) EasyMock.createMockBuilder(Class.forName(c)).createMock();
+ } catch (ClassNotFoundException e) {
+ // 0.20 case
+ return EasyMock.createMockBuilder(Counter.class).createMock();
+ }
+ }
+
+ @Override
+ public Counter getCounter(Enum<?> name) {
+ if (!counters.containsKey(name)) {
+ counters.put(name, newCounter());
+ }
+ return counters.get(name);
+ }
+
+
+ @Override
+ public Counter getCounter(String group, String name) {
+ if (!counterGroups.containsKey(group + name)) {
+ counterGroups.put(group + name, newCounter());
+ }
+ return counterGroups.get(group+name);
+ }
+
+ @Override
+ public void progress() {
+ }
+
+ @Override
+ public void setStatus(String status) {
+ }
+
+ @Override
+ public float getProgress() {
+ return 0.0f;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java b/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
new file mode 100644
index 0000000..ceffe3e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/IntPairWritableTest.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.util.Arrays;
+
+import org.junit.Test;
+
+public final class IntPairWritableTest extends MahoutTestCase {
+
+ @Test
+ public void testGetSet() {
+ IntPairWritable n = new IntPairWritable();
+
+ assertEquals(0, n.getFirst());
+ assertEquals(0, n.getSecond());
+
+ n.setFirst(5);
+ n.setSecond(10);
+
+ assertEquals(5, n.getFirst());
+ assertEquals(10, n.getSecond());
+
+ n = new IntPairWritable(2,4);
+
+ assertEquals(2, n.getFirst());
+ assertEquals(4, n.getSecond());
+ }
+
+ @Test
+ public void testWritable() throws Exception {
+ IntPairWritable one = new IntPairWritable(1,2);
+ IntPairWritable two = new IntPairWritable(3,4);
+
+ assertEquals(1, one.getFirst());
+ assertEquals(2, one.getSecond());
+
+ assertEquals(3, two.getFirst());
+ assertEquals(4, two.getSecond());
+
+
+ ByteArrayOutputStream bout = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(bout);
+
+ two.write(out);
+
+ byte[] b = bout.toByteArray();
+
+ ByteArrayInputStream bin = new ByteArrayInputStream(b);
+ DataInput din = new DataInputStream(bin);
+
+ one.readFields(din);
+
+ assertEquals(two.getFirst(), one.getFirst());
+ assertEquals(two.getSecond(), one.getSecond());
+ }
+
+ @Test
+ public void testComparable() {
+ IntPairWritable[] input = {
+ new IntPairWritable(2,3),
+ new IntPairWritable(2,2),
+ new IntPairWritable(1,3),
+ new IntPairWritable(1,2),
+ new IntPairWritable(2,1),
+ new IntPairWritable(2,2),
+ new IntPairWritable(1,-2),
+ new IntPairWritable(1,-1),
+ new IntPairWritable(-2,-2),
+ new IntPairWritable(-2,-1),
+ new IntPairWritable(-1,-1),
+ new IntPairWritable(-1,-2),
+ new IntPairWritable(Integer.MAX_VALUE,1),
+ new IntPairWritable(Integer.MAX_VALUE/2,1),
+ new IntPairWritable(Integer.MIN_VALUE,1),
+ new IntPairWritable(Integer.MIN_VALUE/2,1)
+
+ };
+
+ IntPairWritable[] sorted = new IntPairWritable[input.length];
+ System.arraycopy(input, 0, sorted, 0, input.length);
+ Arrays.sort(sorted);
+
+ int[] expected = {
+ 14, 15, 8, 9, 11, 10, 6, 7, 3, 2, 4, 1, 5, 0, 13, 12
+ };
+
+ for (int i=0; i < input.length; i++) {
+ assertSame(input[expected[i]], sorted[i]);
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java b/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java
new file mode 100644
index 0000000..775c8d8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/MahoutTestCase.java
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.lang.reflect.Field;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.junit.After;
+import org.junit.Before;
+
+public class MahoutTestCase extends org.apache.mahout.math.MahoutTestCase {
+
+ /** "Close enough" value for floating-point comparisons. */
+ public static final double EPSILON = 0.000001;
+
+ private Path testTempDirPath;
+ private FileSystem fs;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ RandomUtils.useTestSeed();
+ testTempDirPath = null;
+ fs = null;
+ }
+
+ @Override
+ @After
+ public void tearDown() throws Exception {
+ if (testTempDirPath != null) {
+ try {
+ fs.delete(testTempDirPath, true);
+ } catch (IOException e) {
+ throw new IllegalStateException("Test file not found");
+ }
+ testTempDirPath = null;
+ fs = null;
+ }
+ super.tearDown();
+ }
+
+ public final Configuration getConfiguration() throws IOException {
+ Configuration conf = new Configuration();
+ conf.set("hadoop.tmp.dir", getTestTempDir("hadoop" + Math.random()).getAbsolutePath());
+ return conf;
+ }
+
+ protected final Path getTestTempDirPath() throws IOException {
+ if (testTempDirPath == null) {
+ fs = FileSystem.get(getConfiguration());
+ long simpleRandomLong = (long) (Long.MAX_VALUE * Math.random());
+ testTempDirPath = fs.makeQualified(
+ new Path("/tmp/mahout-" + getClass().getSimpleName() + '-' + simpleRandomLong));
+ if (!fs.mkdirs(testTempDirPath)) {
+ throw new IOException("Could not create " + testTempDirPath);
+ }
+ fs.deleteOnExit(testTempDirPath);
+ }
+ return testTempDirPath;
+ }
+
+ protected final Path getTestTempFilePath(String name) throws IOException {
+ return getTestTempFileOrDirPath(name, false);
+ }
+
+ protected final Path getTestTempDirPath(String name) throws IOException {
+ return getTestTempFileOrDirPath(name, true);
+ }
+
+ private Path getTestTempFileOrDirPath(String name, boolean dir) throws IOException {
+ Path testTempDirPath = getTestTempDirPath();
+ Path tempFileOrDir = fs.makeQualified(new Path(testTempDirPath, name));
+ fs.deleteOnExit(tempFileOrDir);
+ if (dir && !fs.mkdirs(tempFileOrDir)) {
+ throw new IOException("Could not create " + tempFileOrDir);
+ }
+ return tempFileOrDir;
+ }
+
+ /**
+ * Try to directly set a (possibly private) field on an Object
+ */
+ protected static void setField(Object target, String fieldname, Object value)
+ throws NoSuchFieldException, IllegalAccessException {
+ Field field = findDeclaredField(target.getClass(), fieldname);
+ field.setAccessible(true);
+ field.set(target, value);
+ }
+
+ /**
+ * Find a declared field in a class or one of it's super classes
+ */
+ private static Field findDeclaredField(Class<?> inClass, String fieldname) throws NoSuchFieldException {
+ while (!Object.class.equals(inClass)) {
+ for (Field field : inClass.getDeclaredFields()) {
+ if (field.getName().equalsIgnoreCase(fieldname)) {
+ return field;
+ }
+ }
+ inClass = inClass.getSuperclass();
+ }
+ throw new NoSuchFieldException();
+ }
+
+ /**
+ * @return a job option key string (--name) from the given option name
+ */
+ protected static String optKey(String optionName) {
+ return AbstractJob.keyFor(optionName);
+ }
+
+ protected static void writeLines(File file, String... lines) throws IOException {
+ Writer writer = new OutputStreamWriter(new FileOutputStream(file), Charsets.UTF_8);
+ try {
+ for (String line : lines) {
+ writer.write(line);
+ writer.write('\n');
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/MockIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/MockIterator.java b/mr/src/test/java/org/apache/mahout/common/MockIterator.java
new file mode 100644
index 0000000..ce48fdc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/MockIterator.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import org.apache.hadoop.io.DataInputBuffer;
+import org.apache.hadoop.mapred.RawKeyValueIterator;
+import org.apache.hadoop.util.Progress;
+
+public final class MockIterator implements RawKeyValueIterator {
+
+ @Override
+ public void close() {
+ }
+
+ @Override
+ public DataInputBuffer getKey() {
+ return null;
+ }
+
+ @Override
+ public Progress getProgress() {
+ return null;
+ }
+
+ @Override
+ public DataInputBuffer getValue() {
+
+ return null;
+ }
+
+ @Override
+ public boolean next() {
+ return true;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java b/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java
new file mode 100644
index 0000000..0633685
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/StringUtilsTest.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import com.google.common.collect.Lists;
+import org.junit.Test;
+
+import java.util.List;
+
+public final class StringUtilsTest extends MahoutTestCase {
+
+ private static class DummyTest {
+ private int field;
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof DummyTest)) {
+ return false;
+ }
+
+ DummyTest dt = (DummyTest) obj;
+ return field == dt.field;
+ }
+
+ @Override
+ public int hashCode() {
+ return field;
+ }
+
+ public int getField() {
+ return field;
+ }
+ }
+
+ @Test
+ public void testStringConversion() throws Exception {
+
+ List<String> expected = Lists.newArrayList("A", "B", "C");
+ assertEquals(expected, StringUtils.fromString(StringUtils
+ .toString(expected)));
+
+ // test a non serializable object
+ DummyTest test = new DummyTest();
+ assertEquals(test, StringUtils.fromString(StringUtils.toString(test)));
+ }
+
+ @Test
+ public void testEscape() throws Exception {
+ String res = StringUtils.escapeXML("\",\',&,>,<");
+ assertEquals("_,_,_,_,_", res);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java
new file mode 100644
index 0000000..6db7c9b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/CosineDistanceMeasureTest.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public final class CosineDistanceMeasureTest extends MahoutTestCase {
+
+ @Test
+ public void testMeasure() {
+
+ DistanceMeasure distanceMeasure = new CosineDistanceMeasure();
+
+ Vector[] vectors = {
+ new DenseVector(new double[]{1, 0, 0, 0, 0, 0}),
+ new DenseVector(new double[]{1, 1, 1, 0, 0, 0}),
+ new DenseVector(new double[]{1, 1, 1, 1, 1, 1})
+ };
+
+ double[][] distanceMatrix = new double[3][3];
+
+ for (int a = 0; a < 3; a++) {
+ for (int b = 0; b < 3; b++) {
+ distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]);
+ }
+ }
+
+ assertEquals(0.0, distanceMatrix[0][0], EPSILON);
+ assertTrue(distanceMatrix[0][0] < distanceMatrix[0][1]);
+ assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]);
+
+ assertEquals(0.0, distanceMatrix[1][1], EPSILON);
+ assertTrue(distanceMatrix[1][0] > distanceMatrix[1][1]);
+ assertTrue(distanceMatrix[1][2] < distanceMatrix[1][0]);
+
+ assertEquals(0.0, distanceMatrix[2][2], EPSILON);
+ assertTrue(distanceMatrix[2][0] > distanceMatrix[2][1]);
+ assertTrue(distanceMatrix[2][1] > distanceMatrix[2][2]);
+
+ // Two equal vectors (despite them being zero) should have 0 distance.
+ assertEquals(0,
+ distanceMeasure.distance(new SequentialAccessSparseVector(1),
+ new SequentialAccessSparseVector(1)),
+ EPSILON);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java
new file mode 100644
index 0000000..ad1608c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/DefaultDistanceMeasureTest.java
@@ -0,0 +1,103 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public abstract class DefaultDistanceMeasureTest extends MahoutTestCase {
+
+ protected abstract DistanceMeasure distanceMeasureFactory();
+
+ @Test
+ public void testMeasure() {
+
+ DistanceMeasure distanceMeasure = distanceMeasureFactory();
+
+ Vector[] vectors = {
+ new DenseVector(new double[]{1, 1, 1, 1, 1, 1}),
+ new DenseVector(new double[]{2, 2, 2, 2, 2, 2}),
+ new DenseVector(new double[]{6, 6, 6, 6, 6, 6}),
+ new DenseVector(new double[]{-1,-1,-1,-1,-1,-1})
+ };
+
+ compare(distanceMeasure, vectors);
+
+ vectors = new Vector[4];
+
+ vectors[0] = new RandomAccessSparseVector(5);
+ vectors[0].setQuick(0, 1);
+ vectors[0].setQuick(3, 1);
+ vectors[0].setQuick(4, 1);
+
+ vectors[1] = new RandomAccessSparseVector(5);
+ vectors[1].setQuick(0, 2);
+ vectors[1].setQuick(3, 2);
+ vectors[1].setQuick(4, 2);
+
+ vectors[2] = new RandomAccessSparseVector(5);
+ vectors[2].setQuick(0, 6);
+ vectors[2].setQuick(3, 6);
+ vectors[2].setQuick(4, 6);
+
+ vectors[3] = new RandomAccessSparseVector(5);
+
+ compare(distanceMeasure, vectors);
+ }
+
+ private static void compare(DistanceMeasure distanceMeasure, Vector[] vectors) {
+ double[][] distanceMatrix = new double[4][4];
+
+ for (int a = 0; a < 4; a++) {
+ for (int b = 0; b < 4; b++) {
+ distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]);
+ }
+ }
+
+ assertEquals("Distance from first vector to itself is not zero", 0.0, distanceMatrix[0][0], EPSILON);
+ assertTrue(distanceMatrix[0][0] < distanceMatrix[0][1]);
+ assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]);
+
+ assertEquals("Distance from second vector to itself is not zero", 0.0, distanceMatrix[1][1], EPSILON);
+ assertTrue(distanceMatrix[1][0] > distanceMatrix[1][1]);
+ assertTrue(distanceMatrix[1][2] > distanceMatrix[1][0]);
+
+ assertEquals("Distance from third vector to itself is not zero", 0.0, distanceMatrix[2][2], EPSILON);
+ assertTrue(distanceMatrix[2][0] > distanceMatrix[2][1]);
+ assertTrue(distanceMatrix[2][1] > distanceMatrix[2][2]);
+
+ for (int a = 0; a < 4; a++) {
+ for (int b = 0; b < 4; b++) {
+ assertTrue("Distance between vectors less than zero: "
+ + distanceMatrix[a][b] + " = " + distanceMeasure
+ + ".distance("+ vectors[a].asFormatString() + ", "
+ + vectors[b].asFormatString() + ')',
+ distanceMatrix[a][b] >= 0);
+ if (vectors[a].plus(vectors[b]).norm(2) == 0 && vectors[a].norm(2) > 0) {
+ assertTrue("Distance from v to -v is equal to zero"
+ + vectors[a].asFormatString() + " = -" + vectors[b].asFormatString(),
+ distanceMatrix[a][b] > 0);
+ }
+ }
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java
new file mode 100644
index 0000000..a8f1d0b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/DefaultWeightedDistanceMeasureTest.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public abstract class DefaultWeightedDistanceMeasureTest extends DefaultDistanceMeasureTest {
+
+ @Override
+ public abstract WeightedDistanceMeasure distanceMeasureFactory();
+
+ @Test
+ public void testMeasureWeighted() {
+
+ WeightedDistanceMeasure distanceMeasure = distanceMeasureFactory();
+
+ Vector[] vectors = {
+ new DenseVector(new double[]{9, 9, 1}),
+ new DenseVector(new double[]{1, 9, 9}),
+ new DenseVector(new double[]{9, 1, 9}),
+ };
+ distanceMeasure.setWeights(new DenseVector(new double[]{1, 1000, 1}));
+
+ double[][] distanceMatrix = new double[3][3];
+
+ for (int a = 0; a < 3; a++) {
+ for (int b = 0; b < 3; b++) {
+ distanceMatrix[a][b] = distanceMeasure.distance(vectors[a], vectors[b]);
+ }
+ }
+
+ assertEquals(0.0, distanceMatrix[0][0], EPSILON);
+ assertTrue(distanceMatrix[0][1] < distanceMatrix[0][2]);
+
+
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java
new file mode 100644
index 0000000..185adf3
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestChebyshevMeasure.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public final class TestChebyshevMeasure extends MahoutTestCase {
+
+ @Test
+ public void testMeasure() {
+
+ DistanceMeasure chebyshevDistanceMeasure = new ChebyshevDistanceMeasure();
+
+ Vector[] vectors = {
+ new DenseVector(new double[]{1, 0, 0, 0, 0, 0}),
+ new DenseVector(new double[]{1, 1, 1, 0, 0, 0}),
+ new DenseVector(new double[]{1, 1, 1, 1, 1, 1})
+ };
+ double[][] distances = {{0.0, 1.0, 1.0}, {1.0, 0.0, 1.0}, {1.0, 1.0, 0.0}};
+
+ double[][] chebyshevDistanceMatrix = new double[3][3];
+
+ for (int a = 0; a < 3; a++) {
+ for (int b = 0; b < 3; b++) {
+ chebyshevDistanceMatrix[a][b] = chebyshevDistanceMeasure.distance(vectors[a], vectors[b]);
+ }
+ }
+ for (int a = 0; a < 3; a++) {
+ for (int b = 0; b < 3; b++) {
+ assertEquals(distances[a][b], chebyshevDistanceMatrix[a][b], EPSILON);
+ }
+ }
+
+ assertEquals(0.0, chebyshevDistanceMatrix[0][0], EPSILON);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java
new file mode 100644
index 0000000..cc9e9e7
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestEuclideanDistanceMeasure.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+public final class TestEuclideanDistanceMeasure extends DefaultDistanceMeasureTest {
+
+ @Override
+ public DistanceMeasure distanceMeasureFactory() {
+ return new EuclideanDistanceMeasure();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java
new file mode 100644
index 0000000..8e3d205
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestMahalanobisDistanceMeasure.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+
+/**
+ * To launch this test only : mvn test -Dtest=org.apache.mahout.common.distance.TestMahalanobisDistanceMeasure
+ */
+public final class TestMahalanobisDistanceMeasure extends MahoutTestCase {
+
+ @Test
+ public void testMeasure() {
+ double[][] invCovValues = { { 2.2, 0.4 }, { 0.4, 2.8 } };
+ double[] meanValues = { -2.3, -0.9 };
+ Matrix invCov = new DenseMatrix(invCovValues);
+ Vector meanVector = new DenseVector(meanValues);
+ MahalanobisDistanceMeasure distanceMeasure = new MahalanobisDistanceMeasure();
+ distanceMeasure.setInverseCovarianceMatrix(invCov);
+ distanceMeasure.setMeanVector(meanVector);
+ double[] v1 = { -1.9, -2.3 };
+ double[] v2 = { -2.9, -1.3 };
+ double dist = distanceMeasure.distance(new DenseVector(v1),new DenseVector(v2));
+ assertEquals(2.0493901531919194, dist, EPSILON);
+ //now set the covariance Matrix
+ distanceMeasure.setCovarianceMatrix(invCov);
+ //check the inverse covariance times covariance equals identity
+ Matrix identity = distanceMeasure.getInverseCovarianceMatrix().times(invCov);
+ assertEquals(1, identity.get(0,0), EPSILON);
+ assertEquals(1, identity.get(1,1), EPSILON);
+ assertEquals(0, identity.get(1,0), EPSILON);
+ assertEquals(0, identity.get(0,1), EPSILON);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java
new file mode 100644
index 0000000..97a5612
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestManhattanDistanceMeasure.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+public final class TestManhattanDistanceMeasure extends DefaultDistanceMeasureTest {
+
+ @Override
+ public DistanceMeasure distanceMeasureFactory() {
+ return new ManhattanDistanceMeasure();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java
new file mode 100644
index 0000000..d2cd85e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestMinkowskiMeasure.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public final class TestMinkowskiMeasure extends MahoutTestCase {
+
+ @Test
+ public void testMeasure() {
+
+ DistanceMeasure minkowskiDistanceMeasure = new MinkowskiDistanceMeasure(1.5);
+ DistanceMeasure manhattanDistanceMeasure = new ManhattanDistanceMeasure();
+ DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure();
+
+ Vector[] vectors = {
+ new DenseVector(new double[]{1, 0, 0, 0, 0, 0}),
+ new DenseVector(new double[]{1, 1, 1, 0, 0, 0}),
+ new DenseVector(new double[]{1, 1, 1, 1, 1, 1})
+ };
+
+ double[][] minkowskiDistanceMatrix = new double[3][3];
+ double[][] manhattanDistanceMatrix = new double[3][3];
+ double[][] euclideanDistanceMatrix = new double[3][3];
+
+ for (int a = 0; a < 3; a++) {
+ for (int b = 0; b < 3; b++) {
+ minkowskiDistanceMatrix[a][b] = minkowskiDistanceMeasure.distance(vectors[a], vectors[b]);
+ manhattanDistanceMatrix[a][b] = manhattanDistanceMeasure.distance(vectors[a], vectors[b]);
+ euclideanDistanceMatrix[a][b] = euclideanDistanceMeasure.distance(vectors[a], vectors[b]);
+ }
+ }
+
+ for (int a = 0; a < 3; a++) {
+ for (int b = 0; b < 3; b++) {
+ assertTrue(minkowskiDistanceMatrix[a][b] <= manhattanDistanceMatrix[a][b]);
+ assertTrue(minkowskiDistanceMatrix[a][b] >= euclideanDistanceMatrix[a][b]);
+ }
+ }
+
+ assertEquals(0.0, minkowskiDistanceMatrix[0][0], EPSILON);
+ assertTrue(minkowskiDistanceMatrix[0][0] < minkowskiDistanceMatrix[0][1]);
+ assertTrue(minkowskiDistanceMatrix[0][1] < minkowskiDistanceMatrix[0][2]);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java
new file mode 100644
index 0000000..01f9134
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestTanimotoDistanceMeasure.java
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+public final class TestTanimotoDistanceMeasure extends DefaultWeightedDistanceMeasureTest {
+ @Override
+ public TanimotoDistanceMeasure distanceMeasureFactory() {
+ return new TanimotoDistanceMeasure();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java
new file mode 100644
index 0000000..b99d165
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedEuclideanDistanceMeasureTest.java
@@ -0,0 +1,25 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+public final class TestWeightedEuclideanDistanceMeasureTest extends DefaultWeightedDistanceMeasureTest {
+ @Override
+ public WeightedDistanceMeasure distanceMeasureFactory() {
+ return new WeightedEuclideanDistanceMeasure();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java
new file mode 100644
index 0000000..77d4a01
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/distance/TestWeightedManhattanDistanceMeasure.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+public final class TestWeightedManhattanDistanceMeasure extends DefaultWeightedDistanceMeasureTest {
+
+ @Override
+ public WeightedManhattanDistanceMeasure distanceMeasureFactory() {
+ return new WeightedManhattanDistanceMeasure();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java b/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java
new file mode 100644
index 0000000..d38178c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/iterator/CountingIteratorTest.java
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public final class CountingIteratorTest extends MahoutTestCase {
+
+ @Test
+ public void testEmptyCase() {
+ assertFalse(new CountingIterator(0).hasNext());
+ }
+
+ @Test
+ public void testCount() {
+ Iterator<Integer> it = new CountingIterator(3);
+ assertTrue(it.hasNext());
+ assertEquals(0, (int) it.next());
+ assertTrue(it.hasNext());
+ assertEquals(1, (int) it.next());
+ assertTrue(it.hasNext());
+ assertEquals(2, (int) it.next());
+ assertFalse(it.hasNext());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java b/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java
new file mode 100644
index 0000000..b67d34b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/iterator/SamplerCase.java
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public abstract class SamplerCase extends MahoutTestCase {
+ // these provide access to the underlying implementation
+ protected abstract Iterator<Integer> createSampler(int n, Iterator<Integer> source);
+
+ protected abstract boolean isSorted();
+
+ @Test
+ public void testEmptyCase() {
+ assertFalse(createSampler(100, new CountingIterator(0)).hasNext());
+ }
+
+ @Test
+ public void testSmallInput() {
+ Iterator<Integer> t = createSampler(10, new CountingIterator(1));
+ assertTrue(t.hasNext());
+ assertEquals(0, t.next().intValue());
+ assertFalse(t.hasNext());
+
+ t = createSampler(10, new CountingIterator(1));
+ assertTrue(t.hasNext());
+ assertEquals(0, t.next().intValue());
+ assertFalse(t.hasNext());
+ }
+
+ @Test
+ public void testAbsurdSize() {
+ Iterator<Integer> t = createSampler(0, new CountingIterator(2));
+ assertFalse(t.hasNext());
+ }
+
+ @Test
+ public void testExactSizeMatch() {
+ Iterator<Integer> t = createSampler(10, new CountingIterator(10));
+ for (int i = 0; i < 10; i++) {
+ assertTrue(t.hasNext());
+ assertEquals(i, t.next().intValue());
+ }
+ assertFalse(t.hasNext());
+ }
+
+ @Test
+ public void testSample() {
+ Iterator<Integer> source = new CountingIterator(100);
+ Iterator<Integer> t = createSampler(15, source);
+
+ // this is just a regression test, not a real test
+ List<Integer> expectedValues = Arrays.asList(52,28,2,60,50,32,65,79,78,9,40,33,96,25,48);
+ if (isSorted()) {
+ Collections.sort(expectedValues);
+ }
+ Iterator<Integer> expected = expectedValues.iterator();
+ int last = Integer.MIN_VALUE;
+ for (int i = 0; i < 15; i++) {
+ assertTrue(t.hasNext());
+ int actual = t.next();
+ if (isSorted()) {
+ assertTrue(actual >= last);
+ last = actual;
+ } else {
+ // any of the first few values should be in the original places
+ if (actual < 15) {
+ assertEquals(i, actual);
+ }
+ }
+
+ assertTrue(actual >= 0 && actual < 100);
+
+ // this is just a regression test, but still of some value
+ assertEquals(expected.next().intValue(), actual);
+ assertFalse(source.hasNext());
+ }
+ assertFalse(t.hasNext());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java b/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java
new file mode 100644
index 0000000..470e6d8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/iterator/TestFixedSizeSampler.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+public final class TestFixedSizeSampler extends SamplerCase {
+
+ @Override
+ protected Iterator<Integer> createSampler(int n, Iterator<Integer> source) {
+ return new FixedSizeSamplingIterator<Integer>(n, source);
+ }
+
+ @Override
+ protected boolean isSorted() {
+ return false;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java b/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java
new file mode 100644
index 0000000..970ea79
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/iterator/TestSamplingIterator.java
@@ -0,0 +1,77 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public final class TestSamplingIterator extends MahoutTestCase {
+
+ @Test
+ public void testEmptyCase() {
+ assertFalse(new SamplingIterator<Integer>(new CountingIterator(0), 0.9999).hasNext());
+ assertFalse(new SamplingIterator<Integer>(new CountingIterator(0), 1).hasNext());
+ }
+
+ @Test
+ public void testSmallInput() {
+ Iterator<Integer> t = new SamplingIterator<Integer>(new CountingIterator(1), 0.9999);
+ assertTrue(t.hasNext());
+ assertEquals(0, t.next().intValue());
+ assertFalse(t.hasNext());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testBadRate1() {
+ new SamplingIterator<Integer>(new CountingIterator(1), 0.0);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testBadRate2() {
+ new SamplingIterator<Integer>(new CountingIterator(1), 1.1);
+ }
+
+ @Test
+ public void testExactSizeMatch() {
+ Iterator<Integer> t = new SamplingIterator<Integer>(new CountingIterator(10), 1);
+ for (int i = 0; i < 10; i++) {
+ assertTrue(t.hasNext());
+ assertEquals(i, t.next().intValue());
+ }
+ assertFalse(t.hasNext());
+ }
+
+ @Test
+ public void testSample() {
+ for (int i = 0; i < 1000; i++) {
+ Iterator<Integer> t = new SamplingIterator<Integer>(new CountingIterator(1000), 0.1);
+ int k = 0;
+ while (t.hasNext()) {
+ int v = t.next();
+ k++;
+ assertTrue(v >= 0);
+ assertTrue(v < 1000);
+ }
+ double sd = Math.sqrt(0.9 * 0.1 * 1000);
+ assertTrue(k >= 100 - 4 * sd);
+ assertTrue(k <= 100 + 4 * sd);
+ }
+ }
+}
[41/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
new file mode 100644
index 0000000..33be59d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizer.java
@@ -0,0 +1,313 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * factorizes the rating matrix using "Alternating-Least-Squares with Weighted-λ-Regularization" as described in
+ * <a href="http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf">
+ * "Large-scale Collaborative Filtering for the Netflix Prize"</a>
+ *
+ * also supports the implicit feedback variant of this approach as described in "Collaborative Filtering for Implicit
+ * Feedback Datasets" available at http://research.yahoo.com/pub/2433
+ */
+public class ALSWRFactorizer extends AbstractFactorizer {
+
+ private final DataModel dataModel;
+
+ /** number of features used to compute this factorization */
+ private final int numFeatures;
+ /** parameter to control the regularization */
+ private final double lambda;
+ /** number of iterations */
+ private final int numIterations;
+
+ private final boolean usesImplicitFeedback;
+ /** confidence weighting parameter, only necessary when working with implicit feedback */
+ private final double alpha;
+
+ private final int numTrainingThreads;
+
+ private static final double DEFAULT_ALPHA = 40;
+
+ private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizer.class);
+
+ public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+ boolean usesImplicitFeedback, double alpha, int numTrainingThreads) throws TasteException {
+ super(dataModel);
+ this.dataModel = dataModel;
+ this.numFeatures = numFeatures;
+ this.lambda = lambda;
+ this.numIterations = numIterations;
+ this.usesImplicitFeedback = usesImplicitFeedback;
+ this.alpha = alpha;
+ this.numTrainingThreads = numTrainingThreads;
+ }
+
+ public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+ boolean usesImplicitFeedback, double alpha) throws TasteException {
+ this(dataModel, numFeatures, lambda, numIterations, usesImplicitFeedback, alpha,
+ Runtime.getRuntime().availableProcessors());
+ }
+
+ public ALSWRFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations) throws TasteException {
+ this(dataModel, numFeatures, lambda, numIterations, false, DEFAULT_ALPHA);
+ }
+
+ static class Features {
+
+ private final DataModel dataModel;
+ private final int numFeatures;
+
+ private final double[][] M;
+ private final double[][] U;
+
+ Features(ALSWRFactorizer factorizer) throws TasteException {
+ dataModel = factorizer.dataModel;
+ numFeatures = factorizer.numFeatures;
+ Random random = RandomUtils.getRandom();
+ M = new double[dataModel.getNumItems()][numFeatures];
+ LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+ while (itemIDsIterator.hasNext()) {
+ long itemID = itemIDsIterator.nextLong();
+ int itemIDIndex = factorizer.itemIndex(itemID);
+ M[itemIDIndex][0] = averateRating(itemID);
+ for (int feature = 1; feature < numFeatures; feature++) {
+ M[itemIDIndex][feature] = random.nextDouble() * 0.1;
+ }
+ }
+ U = new double[dataModel.getNumUsers()][numFeatures];
+ }
+
+ double[][] getM() {
+ return M;
+ }
+
+ double[][] getU() {
+ return U;
+ }
+
+ Vector getUserFeatureColumn(int index) {
+ return new DenseVector(U[index]);
+ }
+
+ Vector getItemFeatureColumn(int index) {
+ return new DenseVector(M[index]);
+ }
+
+ void setFeatureColumnInU(int idIndex, Vector vector) {
+ setFeatureColumn(U, idIndex, vector);
+ }
+
+ void setFeatureColumnInM(int idIndex, Vector vector) {
+ setFeatureColumn(M, idIndex, vector);
+ }
+
+ protected void setFeatureColumn(double[][] matrix, int idIndex, Vector vector) {
+ for (int feature = 0; feature < numFeatures; feature++) {
+ matrix[idIndex][feature] = vector.get(feature);
+ }
+ }
+
+ protected double averateRating(long itemID) throws TasteException {
+ PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+ RunningAverage avg = new FullRunningAverage();
+ for (Preference pref : prefs) {
+ avg.addDatum(pref.getValue());
+ }
+ return avg.getAverage();
+ }
+ }
+
+ @Override
+ public Factorization factorize() throws TasteException {
+ log.info("starting to compute the factorization...");
+ final Features features = new Features(this);
+
+ /* feature maps necessary for solving for implicit feedback */
+ OpenIntObjectHashMap<Vector> userY = null;
+ OpenIntObjectHashMap<Vector> itemY = null;
+
+ if (usesImplicitFeedback) {
+ userY = userFeaturesMapping(dataModel.getUserIDs(), dataModel.getNumUsers(), features.getU());
+ itemY = itemFeaturesMapping(dataModel.getItemIDs(), dataModel.getNumItems(), features.getM());
+ }
+
+ for (int iteration = 0; iteration < numIterations; iteration++) {
+ log.info("iteration {}", iteration);
+
+ /* fix M - compute U */
+ ExecutorService queue = createQueue();
+ LongPrimitiveIterator userIDsIterator = dataModel.getUserIDs();
+ try {
+
+ final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
+ ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, itemY, numTrainingThreads)
+ : null;
+
+ while (userIDsIterator.hasNext()) {
+ final long userID = userIDsIterator.nextLong();
+ final LongPrimitiveIterator itemIDsFromUser = dataModel.getItemIDsFromUser(userID).iterator();
+ final PreferenceArray userPrefs = dataModel.getPreferencesFromUser(userID);
+ queue.execute(new Runnable() {
+ @Override
+ public void run() {
+ List<Vector> featureVectors = Lists.newArrayList();
+ while (itemIDsFromUser.hasNext()) {
+ long itemID = itemIDsFromUser.nextLong();
+ featureVectors.add(features.getItemFeatureColumn(itemIndex(itemID)));
+ }
+
+ Vector userFeatures = usesImplicitFeedback
+ ? implicitFeedbackSolver.solve(sparseUserRatingVector(userPrefs))
+ : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(userPrefs), lambda, numFeatures);
+
+ features.setFeatureColumnInU(userIndex(userID), userFeatures);
+ }
+ });
+ }
+ } finally {
+ queue.shutdown();
+ try {
+ queue.awaitTermination(dataModel.getNumUsers(), TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ log.warn("Error when computing user features", e);
+ }
+ }
+
+ /* fix U - compute M */
+ queue = createQueue();
+ LongPrimitiveIterator itemIDsIterator = dataModel.getItemIDs();
+ try {
+
+ final ImplicitFeedbackAlternatingLeastSquaresSolver implicitFeedbackSolver = usesImplicitFeedback
+ ? new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha, userY, numTrainingThreads)
+ : null;
+
+ while (itemIDsIterator.hasNext()) {
+ final long itemID = itemIDsIterator.nextLong();
+ final PreferenceArray itemPrefs = dataModel.getPreferencesForItem(itemID);
+ queue.execute(new Runnable() {
+ @Override
+ public void run() {
+ List<Vector> featureVectors = Lists.newArrayList();
+ for (Preference pref : itemPrefs) {
+ long userID = pref.getUserID();
+ featureVectors.add(features.getUserFeatureColumn(userIndex(userID)));
+ }
+
+ Vector itemFeatures = usesImplicitFeedback
+ ? implicitFeedbackSolver.solve(sparseItemRatingVector(itemPrefs))
+ : AlternatingLeastSquaresSolver.solve(featureVectors, ratingVector(itemPrefs), lambda, numFeatures);
+
+ features.setFeatureColumnInM(itemIndex(itemID), itemFeatures);
+ }
+ });
+ }
+ } finally {
+ queue.shutdown();
+ try {
+ queue.awaitTermination(dataModel.getNumItems(), TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ log.warn("Error when computing item features", e);
+ }
+ }
+ }
+
+ log.info("finished computation of the factorization...");
+ return createFactorization(features.getU(), features.getM());
+ }
+
+ protected ExecutorService createQueue() {
+ return Executors.newFixedThreadPool(numTrainingThreads);
+ }
+
+ protected static Vector ratingVector(PreferenceArray prefs) {
+ double[] ratings = new double[prefs.length()];
+ for (int n = 0; n < prefs.length(); n++) {
+ ratings[n] = prefs.get(n).getValue();
+ }
+ return new DenseVector(ratings, true);
+ }
+
+ //TODO find a way to get rid of the object overhead here
+ protected OpenIntObjectHashMap<Vector> itemFeaturesMapping(LongPrimitiveIterator itemIDs, int numItems,
+ double[][] featureMatrix) {
+ OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<Vector>(numItems);
+ while (itemIDs.hasNext()) {
+ long itemID = itemIDs.next();
+ int itemIndex = itemIndex(itemID);
+ mapping.put(itemIndex, new DenseVector(featureMatrix[itemIndex(itemID)], true));
+ }
+
+ return mapping;
+ }
+
+ protected OpenIntObjectHashMap<Vector> userFeaturesMapping(LongPrimitiveIterator userIDs, int numUsers,
+ double[][] featureMatrix) {
+ OpenIntObjectHashMap<Vector> mapping = new OpenIntObjectHashMap<Vector>(numUsers);
+
+ while (userIDs.hasNext()) {
+ long userID = userIDs.next();
+ int userIndex = userIndex(userID);
+ mapping.put(userIndex, new DenseVector(featureMatrix[userIndex(userID)], true));
+ }
+
+ return mapping;
+ }
+
+ protected Vector sparseItemRatingVector(PreferenceArray prefs) {
+ SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
+ for (Preference preference : prefs) {
+ ratings.set(userIndex(preference.getUserID()), preference.getValue());
+ }
+ return ratings;
+ }
+
+ protected Vector sparseUserRatingVector(PreferenceArray prefs) {
+ SequentialAccessSparseVector ratings = new SequentialAccessSparseVector(Integer.MAX_VALUE, prefs.length());
+ for (Preference preference : prefs) {
+ ratings.set(itemIndex(preference.getItemID()), preference.getValue());
+ }
+ return ratings;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
new file mode 100644
index 0000000..5225222
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/AbstractFactorizer.java
@@ -0,0 +1,94 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * base class for {@link Factorizer}s, provides ID to index mapping
+ */
+public abstract class AbstractFactorizer implements Factorizer {
+
+ private final DataModel dataModel;
+ private FastByIDMap<Integer> userIDMapping;
+ private FastByIDMap<Integer> itemIDMapping;
+ private final RefreshHelper refreshHelper;
+
+ protected AbstractFactorizer(DataModel dataModel) throws TasteException {
+ this.dataModel = dataModel;
+ buildMappings();
+ refreshHelper = new RefreshHelper(new Callable<Object>() {
+ @Override
+ public Object call() throws TasteException {
+ buildMappings();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(dataModel);
+ }
+
+ private void buildMappings() throws TasteException {
+ userIDMapping = createIDMapping(dataModel.getNumUsers(), dataModel.getUserIDs());
+ itemIDMapping = createIDMapping(dataModel.getNumItems(), dataModel.getItemIDs());
+ }
+
+ protected Factorization createFactorization(double[][] userFeatures, double[][] itemFeatures) {
+ return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+ }
+
+ protected Integer userIndex(long userID) {
+ Integer userIndex = userIDMapping.get(userID);
+ if (userIndex == null) {
+ userIndex = userIDMapping.size();
+ userIDMapping.put(userID, userIndex);
+ }
+ return userIndex;
+ }
+
+ protected Integer itemIndex(long itemID) {
+ Integer itemIndex = itemIDMapping.get(itemID);
+ if (itemIndex == null) {
+ itemIndex = itemIDMapping.size();
+ itemIDMapping.put(itemID, itemIndex);
+ }
+ return itemIndex;
+ }
+
+ private static FastByIDMap<Integer> createIDMapping(int size, LongPrimitiveIterator idIterator) {
+ FastByIDMap<Integer> mapping = new FastByIDMap<Integer>(size);
+ int index = 0;
+ while (idIterator.hasNext()) {
+ mapping.put(idIterator.nextLong(), index++);
+ }
+ return mapping;
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
new file mode 100644
index 0000000..f169a60
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorization.java
@@ -0,0 +1,137 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.Arrays;
+import java.util.Map;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+
+/**
+ * a factorization of the rating matrix
+ */
+public class Factorization {
+
+ /** used to find the rows in the user features matrix by userID */
+ private final FastByIDMap<Integer> userIDMapping;
+ /** used to find the rows in the item features matrix by itemID */
+ private final FastByIDMap<Integer> itemIDMapping;
+
+ /** user features matrix */
+ private final double[][] userFeatures;
+ /** item features matrix */
+ private final double[][] itemFeatures;
+
+ public Factorization(FastByIDMap<Integer> userIDMapping, FastByIDMap<Integer> itemIDMapping, double[][] userFeatures,
+ double[][] itemFeatures) {
+ this.userIDMapping = Preconditions.checkNotNull(userIDMapping);
+ this.itemIDMapping = Preconditions.checkNotNull(itemIDMapping);
+ this.userFeatures = userFeatures;
+ this.itemFeatures = itemFeatures;
+ }
+
+ public double[][] allUserFeatures() {
+ return userFeatures;
+ }
+
+ public double[] getUserFeatures(long userID) throws NoSuchUserException {
+ Integer index = userIDMapping.get(userID);
+ if (index == null) {
+ throw new NoSuchUserException(userID);
+ }
+ return userFeatures[index];
+ }
+
+ public double[][] allItemFeatures() {
+ return itemFeatures;
+ }
+
+ public double[] getItemFeatures(long itemID) throws NoSuchItemException {
+ Integer index = itemIDMapping.get(itemID);
+ if (index == null) {
+ throw new NoSuchItemException(itemID);
+ }
+ return itemFeatures[index];
+ }
+
+ public int userIndex(long userID) throws NoSuchUserException {
+ Integer index = userIDMapping.get(userID);
+ if (index == null) {
+ throw new NoSuchUserException(userID);
+ }
+ return index;
+ }
+
+ public Iterable<Map.Entry<Long,Integer>> getUserIDMappings() {
+ return userIDMapping.entrySet();
+ }
+
+ public LongPrimitiveIterator getUserIDMappingKeys() {
+ return userIDMapping.keySetIterator();
+ }
+
+ public int itemIndex(long itemID) throws NoSuchItemException {
+ Integer index = itemIDMapping.get(itemID);
+ if (index == null) {
+ throw new NoSuchItemException(itemID);
+ }
+ return index;
+ }
+
+ public Iterable<Map.Entry<Long,Integer>> getItemIDMappings() {
+ return itemIDMapping.entrySet();
+ }
+
+ public LongPrimitiveIterator getItemIDMappingKeys() {
+ return itemIDMapping.keySetIterator();
+ }
+
+ public int numFeatures() {
+ return userFeatures.length > 0 ? userFeatures[0].length : 0;
+ }
+
+ public int numUsers() {
+ return userIDMapping.size();
+ }
+
+ public int numItems() {
+ return itemIDMapping.size();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof Factorization) {
+ Factorization other = (Factorization) o;
+ return userIDMapping.equals(other.userIDMapping) && itemIDMapping.equals(other.itemIDMapping)
+ && Arrays.deepEquals(userFeatures, other.userFeatures) && Arrays.deepEquals(itemFeatures, other.itemFeatures);
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ int hashCode = 31 * userIDMapping.hashCode() + itemIDMapping.hashCode();
+ hashCode = 31 * hashCode + Arrays.deepHashCode(userFeatures);
+ hashCode = 31 * hashCode + Arrays.deepHashCode(itemFeatures);
+ return hashCode;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
new file mode 100644
index 0000000..2cabe73
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/Factorizer.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * Implementation must be able to create a factorization of a rating matrix
+ */
+public interface Factorizer extends Refreshable {
+
+ Factorization factorize() throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java
new file mode 100644
index 0000000..d1d23a5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategy.java
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Provides a file-based persistent store. */
+public class FilePersistenceStrategy implements PersistenceStrategy {
+
+ private final File file;
+
+ private static final Logger log = LoggerFactory.getLogger(FilePersistenceStrategy.class);
+
+ /**
+ * @param file the file to use for storage. If the file does not exist it will be created when required.
+ */
+ public FilePersistenceStrategy(File file) {
+ this.file = Preconditions.checkNotNull(file);
+ }
+
+ @Override
+ public Factorization load() throws IOException {
+ if (!file.exists()) {
+ log.info("{} does not yet exist, no factorization found", file.getAbsolutePath());
+ return null;
+ }
+ DataInputStream in = null;
+ try {
+ log.info("Reading factorization from {}...", file.getAbsolutePath());
+ in = new DataInputStream(new BufferedInputStream(new FileInputStream(file)));
+ return readBinary(in);
+ } finally {
+ Closeables.close(in, true);
+ }
+ }
+
+ @Override
+ public void maybePersist(Factorization factorization) throws IOException {
+ DataOutputStream out = null;
+ try {
+ log.info("Writing factorization to {}...", file.getAbsolutePath());
+ out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(file)));
+ writeBinary(factorization, out);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ protected static void writeBinary(Factorization factorization, DataOutput out) throws IOException {
+ out.writeInt(factorization.numFeatures());
+ out.writeInt(factorization.numUsers());
+ out.writeInt(factorization.numItems());
+
+ for (Map.Entry<Long,Integer> mappingEntry : factorization.getUserIDMappings()) {
+ long userID = mappingEntry.getKey();
+ out.writeInt(mappingEntry.getValue());
+ out.writeLong(userID);
+ try {
+ double[] userFeatures = factorization.getUserFeatures(userID);
+ for (int feature = 0; feature < factorization.numFeatures(); feature++) {
+ out.writeDouble(userFeatures[feature]);
+ }
+ } catch (NoSuchUserException e) {
+ throw new IOException("Unable to persist factorization", e);
+ }
+ }
+
+ for (Map.Entry<Long,Integer> entry : factorization.getItemIDMappings()) {
+ long itemID = entry.getKey();
+ out.writeInt(entry.getValue());
+ out.writeLong(itemID);
+ try {
+ double[] itemFeatures = factorization.getItemFeatures(itemID);
+ for (int feature = 0; feature < factorization.numFeatures(); feature++) {
+ out.writeDouble(itemFeatures[feature]);
+ }
+ } catch (NoSuchItemException e) {
+ throw new IOException("Unable to persist factorization", e);
+ }
+ }
+ }
+
+ public static Factorization readBinary(DataInput in) throws IOException {
+ int numFeatures = in.readInt();
+ int numUsers = in.readInt();
+ int numItems = in.readInt();
+
+ FastByIDMap<Integer> userIDMapping = new FastByIDMap<Integer>(numUsers);
+ double[][] userFeatures = new double[numUsers][numFeatures];
+
+ for (int n = 0; n < numUsers; n++) {
+ int userIndex = in.readInt();
+ long userID = in.readLong();
+ userIDMapping.put(userID, userIndex);
+ for (int feature = 0; feature < numFeatures; feature++) {
+ userFeatures[userIndex][feature] = in.readDouble();
+ }
+ }
+
+ FastByIDMap<Integer> itemIDMapping = new FastByIDMap<Integer>(numItems);
+ double[][] itemFeatures = new double[numItems][numFeatures];
+
+ for (int n = 0; n < numItems; n++) {
+ int itemIndex = in.readInt();
+ long itemID = in.readLong();
+ itemIDMapping.put(itemID, itemIndex);
+ for (int feature = 0; feature < numFeatures; feature++) {
+ itemFeatures[itemIndex][feature] = in.readDouble();
+ }
+ }
+
+ return new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java
new file mode 100644
index 0000000..0d1aab0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/NoPersistenceStrategy.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.io.IOException;
+
+/**
+ * A {@link PersistenceStrategy} which does nothing.
+ */
+public class NoPersistenceStrategy implements PersistenceStrategy {
+
+ @Override
+ public Factorization load() throws IOException {
+ return null;
+ }
+
+ @Override
+ public void maybePersist(Factorization factorization) throws IOException {
+ // do nothing.
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java
new file mode 100644
index 0000000..8a6a702
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizer.java
@@ -0,0 +1,340 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Minimalistic implementation of Parallel SGD factorizer based on
+ * <a href="http://www.sze.hu/~gtakacs/download/jmlr_2009.pdf">
+ * "Scalable Collaborative Filtering Approaches for Large Recommender Systems"</a>
+ * and
+ * <a href="hwww.cs.wisc.edu/~brecht/papers/hogwildTR.pdf">
+ * "Hogwild!: A Lock-Free Approach to Parallelizing Stochastic Gradient Descent"</a> */
+public class ParallelSGDFactorizer extends AbstractFactorizer {
+
+ private final DataModel dataModel;
+ /** Parameter used to prevent overfitting. */
+ private final double lambda;
+ /** Number of features used to compute this factorization */
+ private final int rank;
+ /** Number of iterations */
+ private final int numEpochs;
+
+ private int numThreads;
+
+ // these next two control decayFactor^steps exponential type of annealing learning rate and decay factor
+ private double mu0 = 0.01;
+ private double decayFactor = 1;
+ // these next two control 1/steps^forget type annealing
+ private int stepOffset = 0;
+ // -1 equals even weighting of all examples, 0 means only use exponential annealing
+ private double forgettingExponent = 0;
+
+ // The following two should be inversely proportional :)
+ private double biasMuRatio = 0.5;
+ private double biasLambdaRatio = 0.1;
+
+ /** TODO: this is not safe as += is not atomic on many processors, can be replaced with AtomicDoubleArray
+ * but it works just fine right now */
+ /** user features */
+ protected volatile double[][] userVectors;
+ /** item features */
+ protected volatile double[][] itemVectors;
+
+ private final PreferenceShuffler shuffler;
+
+ private int epoch = 1;
+ /** place in user vector where the bias is stored */
+ private static final int USER_BIAS_INDEX = 1;
+ /** place in item vector where the bias is stored */
+ private static final int ITEM_BIAS_INDEX = 2;
+ private static final int FEATURE_OFFSET = 3;
+ /** Standard deviation for random initialization of features */
+ private static final double NOISE = 0.02;
+
+ private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizer.class);
+
+ protected static class PreferenceShuffler {
+
+ private Preference[] preferences;
+ private Preference[] unstagedPreferences;
+
+ protected final RandomWrapper random = RandomUtils.getRandom();
+
+ public PreferenceShuffler(DataModel dataModel) throws TasteException {
+ cachePreferences(dataModel);
+ shuffle();
+ stage();
+ }
+
+ private int countPreferences(DataModel dataModel) throws TasteException {
+ int numPreferences = 0;
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong());
+ numPreferences += preferencesFromUser.length();
+ }
+ return numPreferences;
+ }
+
+ private void cachePreferences(DataModel dataModel) throws TasteException {
+ int numPreferences = countPreferences(dataModel);
+ preferences = new Preference[numPreferences];
+
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ int index = 0;
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
+ for (Preference preference : preferencesFromUser) {
+ preferences[index++] = preference;
+ }
+ }
+ }
+
+ public final void shuffle() {
+ unstagedPreferences = preferences.clone();
+ /* Durstenfeld shuffle */
+ for (int i = unstagedPreferences.length - 1; i > 0; i--) {
+ int rand = random.nextInt(i + 1);
+ swapCachedPreferences(i, rand);
+ }
+ }
+
+ //merge this part into shuffle() will make compiler-optimizer do some real absurd stuff, test on OpenJDK7
+ private void swapCachedPreferences(int x, int y) {
+ Preference p = unstagedPreferences[x];
+
+ unstagedPreferences[x] = unstagedPreferences[y];
+ unstagedPreferences[y] = p;
+ }
+
+ public final void stage() {
+ preferences = unstagedPreferences;
+ }
+
+ public Preference get(int i) {
+ return preferences[i];
+ }
+
+ public int size() {
+ return preferences.length;
+ }
+
+ }
+
+ public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numEpochs)
+ throws TasteException {
+ super(dataModel);
+ this.dataModel = dataModel;
+ this.rank = numFeatures + FEATURE_OFFSET;
+ this.lambda = lambda;
+ this.numEpochs = numEpochs;
+
+ shuffler = new PreferenceShuffler(dataModel);
+
+ //max thread num set to n^0.25 as suggested by hogwild! paper
+ numThreads = Math.min(Runtime.getRuntime().availableProcessors(), (int) Math.pow((double) shuffler.size(), 0.25));
+ }
+
+ public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+ double mu0, double decayFactor, int stepOffset, double forgettingExponent) throws TasteException {
+ this(dataModel, numFeatures, lambda, numIterations);
+
+ this.mu0 = mu0;
+ this.decayFactor = decayFactor;
+ this.stepOffset = stepOffset;
+ this.forgettingExponent = forgettingExponent;
+ }
+
+ public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+ double mu0, double decayFactor, int stepOffset, double forgettingExponent, int numThreads) throws TasteException {
+ this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent);
+
+ this.numThreads = numThreads;
+ }
+
+ public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+ double mu0, double decayFactor, int stepOffset, double forgettingExponent,
+ double biasMuRatio, double biasLambdaRatio) throws TasteException {
+ this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent);
+
+ this.biasMuRatio = biasMuRatio;
+ this.biasLambdaRatio = biasLambdaRatio;
+ }
+
+ public ParallelSGDFactorizer(DataModel dataModel, int numFeatures, double lambda, int numIterations,
+ double mu0, double decayFactor, int stepOffset, double forgettingExponent,
+ double biasMuRatio, double biasLambdaRatio, int numThreads) throws TasteException {
+ this(dataModel, numFeatures, lambda, numIterations, mu0, decayFactor, stepOffset, forgettingExponent, biasMuRatio,
+ biasLambdaRatio);
+
+ this.numThreads = numThreads;
+ }
+
+ protected void initialize() throws TasteException {
+ RandomWrapper random = RandomUtils.getRandom();
+ userVectors = new double[dataModel.getNumUsers()][rank];
+ itemVectors = new double[dataModel.getNumItems()][rank];
+
+ double globalAverage = getAveragePreference();
+ for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
+ userVectors[userIndex][0] = globalAverage;
+ userVectors[userIndex][USER_BIAS_INDEX] = 0; // will store user bias
+ userVectors[userIndex][ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias
+ for (int feature = FEATURE_OFFSET; feature < rank; feature++) {
+ userVectors[userIndex][feature] = random.nextGaussian() * NOISE;
+ }
+ }
+ for (int itemIndex = 0; itemIndex < itemVectors.length; itemIndex++) {
+ itemVectors[itemIndex][0] = 1; // corresponding user feature contains global average
+ itemVectors[itemIndex][USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias
+ itemVectors[itemIndex][ITEM_BIAS_INDEX] = 0; // will store item bias
+ for (int feature = FEATURE_OFFSET; feature < rank; feature++) {
+ itemVectors[itemIndex][feature] = random.nextGaussian() * NOISE;
+ }
+ }
+ }
+
+ //TODO: needs optimization
+ private double getMu(int i) {
+ return mu0 * Math.pow(decayFactor, i - 1) * Math.pow(i + stepOffset, forgettingExponent);
+ }
+
+ @Override
+ public Factorization factorize() throws TasteException {
+ initialize();
+
+ if (logger.isInfoEnabled()) {
+ logger.info("starting to compute the factorization...");
+ }
+
+ for (epoch = 1; epoch <= numEpochs; epoch++) {
+ shuffler.stage();
+
+ final double mu = getMu(epoch);
+ int subSize = shuffler.size() / numThreads + 1;
+
+ ExecutorService executor=Executors.newFixedThreadPool(numThreads);
+
+ try {
+ for (int t = 0; t < numThreads; t++) {
+ final int iStart = t * subSize;
+ final int iEnd = Math.min((t + 1) * subSize, shuffler.size());
+
+ executor.execute(new Runnable() {
+ @Override
+ public void run() {
+ for (int i = iStart; i < iEnd; i++) {
+ update(shuffler.get(i), mu);
+ }
+ }
+ });
+ }
+ } finally {
+ executor.shutdown();
+ shuffler.shuffle();
+
+ try {
+ boolean terminated = executor.awaitTermination(numEpochs * shuffler.size(), TimeUnit.MICROSECONDS);
+ if (!terminated) {
+ logger.error("subtasks takes forever, return anyway");
+ }
+ } catch (InterruptedException e) {
+ throw new TasteException("waiting fof termination interrupted", e);
+ }
+ }
+
+ }
+
+ return createFactorization(userVectors, itemVectors);
+ }
+
+ double getAveragePreference() throws TasteException {
+ RunningAverage average = new FullRunningAverage();
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+ average.addDatum(pref.getValue());
+ }
+ }
+ return average.getAverage();
+ }
+
+ /** TODO: this is the vanilla sgd by Tacaks 2009, I speculate that using scaling technique proposed in:
+ * Towards Optimal One Pass Large Scale Learning with Averaged Stochastic Gradient Descent section 5, page 6
+ * can be beneficial in term s of both speed and accuracy.
+ *
+ * Tacaks' method doesn't calculate gradient of regularization correctly, which has non-zero elements everywhere of
+ * the matrix. While Tacaks' method can only updates a single row/column, if one user has a lot of recommendation,
+ * her vector will be more affected by regularization using an isolated scaling factor for both user vectors and
+ * item vectors can remove this issue without inducing more update cost it even reduces it a bit by only performing
+ * one addition and one multiplication.
+ *
+ * BAD SIDE1: the scaling factor decreases fast, it has to be scaled up from time to time before dropped to zero or
+ * caused roundoff error
+ * BAD SIDE2: no body experiment on it before, and people generally use very small lambda
+ * so it's impact on accuracy may still be unknown.
+ * BAD SIDE3: don't know how to make it work for L1-regularization or
+ * "pseudorank?" (sum of singular values)-regularization */
+ protected void update(Preference preference, double mu) {
+ int userIndex = userIndex(preference.getUserID());
+ int itemIndex = itemIndex(preference.getItemID());
+
+ double[] userVector = userVectors[userIndex];
+ double[] itemVector = itemVectors[itemIndex];
+
+ double prediction = dot(userVector, itemVector);
+ double err = preference.getValue() - prediction;
+
+ // adjust features
+ for (int k = FEATURE_OFFSET; k < rank; k++) {
+ double userFeature = userVector[k];
+ double itemFeature = itemVector[k];
+
+ userVector[k] += mu * (err * itemFeature - lambda * userFeature);
+ itemVector[k] += mu * (err * userFeature - lambda * itemFeature);
+ }
+
+ // adjust user and item bias
+ userVector[USER_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * userVector[USER_BIAS_INDEX]);
+ itemVector[ITEM_BIAS_INDEX] += biasMuRatio * mu * (err - biasLambdaRatio * lambda * itemVector[ITEM_BIAS_INDEX]);
+ }
+
+ private double dot(double[] userVector, double[] itemVector) {
+ double sum = 0;
+ for (int k = 0; k < rank; k++) {
+ sum += userVector[k] * itemVector[k];
+ }
+ return sum;
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java
new file mode 100644
index 0000000..abf3eca
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/PersistenceStrategy.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.io.IOException;
+
+/**
+ * Provides storage for {@link Factorization}s
+ */
+public interface PersistenceStrategy {
+
+ /**
+ * Load a factorization from a persistent store.
+ *
+ * @return a Factorization or null if the persistent store is empty.
+ *
+ * @throws IOException
+ */
+ Factorization load() throws IOException;
+
+ /**
+ * Write a factorization to a persistent store unless it already
+ * contains an identical factorization.
+ *
+ * @param factorization
+ *
+ * @throws IOException
+ */
+ void maybePersist(Factorization factorization) throws IOException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
new file mode 100644
index 0000000..2c9f0ae
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/RatingSGDFactorizer.java
@@ -0,0 +1,221 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+/** Matrix factorization with user and item biases for rating prediction, trained with plain vanilla SGD */
+public class RatingSGDFactorizer extends AbstractFactorizer {
+
+ protected static final int FEATURE_OFFSET = 3;
+
+ /** Multiplicative decay factor for learning_rate */
+ protected final double learningRateDecay;
+ /** Learning rate (step size) */
+ protected final double learningRate;
+ /** Parameter used to prevent overfitting. */
+ protected final double preventOverfitting;
+ /** Number of features used to compute this factorization */
+ protected final int numFeatures;
+ /** Number of iterations */
+ private final int numIterations;
+ /** Standard deviation for random initialization of features */
+ protected final double randomNoise;
+ /** User features */
+ protected double[][] userVectors;
+ /** Item features */
+ protected double[][] itemVectors;
+ protected final DataModel dataModel;
+ private long[] cachedUserIDs;
+ private long[] cachedItemIDs;
+
+ protected double biasLearningRate = 0.5;
+ protected double biasReg = 0.1;
+
+ /** place in user vector where the bias is stored */
+ protected static final int USER_BIAS_INDEX = 1;
+ /** place in item vector where the bias is stored */
+ protected static final int ITEM_BIAS_INDEX = 2;
+
+ public RatingSGDFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
+ this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
+ }
+
+ public RatingSGDFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
+ double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
+ super(dataModel);
+ this.dataModel = dataModel;
+ this.numFeatures = numFeatures + FEATURE_OFFSET;
+ this.numIterations = numIterations;
+
+ this.learningRate = learningRate;
+ this.learningRateDecay = learningRateDecay;
+ this.preventOverfitting = preventOverfitting;
+ this.randomNoise = randomNoise;
+ }
+
+ protected void prepareTraining() throws TasteException {
+ RandomWrapper random = RandomUtils.getRandom();
+ userVectors = new double[dataModel.getNumUsers()][numFeatures];
+ itemVectors = new double[dataModel.getNumItems()][numFeatures];
+
+ double globalAverage = getAveragePreference();
+ for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
+ userVectors[userIndex][0] = globalAverage;
+ userVectors[userIndex][USER_BIAS_INDEX] = 0; // will store user bias
+ userVectors[userIndex][ITEM_BIAS_INDEX] = 1; // corresponding item feature contains item bias
+ for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+ userVectors[userIndex][feature] = random.nextGaussian() * randomNoise;
+ }
+ }
+ for (int itemIndex = 0; itemIndex < itemVectors.length; itemIndex++) {
+ itemVectors[itemIndex][0] = 1; // corresponding user feature contains global average
+ itemVectors[itemIndex][USER_BIAS_INDEX] = 1; // corresponding user feature contains user bias
+ itemVectors[itemIndex][ITEM_BIAS_INDEX] = 0; // will store item bias
+ for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+ itemVectors[itemIndex][feature] = random.nextGaussian() * randomNoise;
+ }
+ }
+
+ cachePreferences();
+ shufflePreferences();
+ }
+
+ private int countPreferences() throws TasteException {
+ int numPreferences = 0;
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userIDs.nextLong());
+ numPreferences += preferencesFromUser.length();
+ }
+ return numPreferences;
+ }
+
+ private void cachePreferences() throws TasteException {
+ int numPreferences = countPreferences();
+ cachedUserIDs = new long[numPreferences];
+ cachedItemIDs = new long[numPreferences];
+
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ int index = 0;
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
+ for (Preference preference : preferencesFromUser) {
+ cachedUserIDs[index] = userID;
+ cachedItemIDs[index] = preference.getItemID();
+ index++;
+ }
+ }
+ }
+
+ protected void shufflePreferences() {
+ RandomWrapper random = RandomUtils.getRandom();
+ /* Durstenfeld shuffle */
+ for (int currentPos = cachedUserIDs.length - 1; currentPos > 0; currentPos--) {
+ int swapPos = random.nextInt(currentPos + 1);
+ swapCachedPreferences(currentPos, swapPos);
+ }
+ }
+
+ private void swapCachedPreferences(int posA, int posB) {
+ long tmpUserIndex = cachedUserIDs[posA];
+ long tmpItemIndex = cachedItemIDs[posA];
+
+ cachedUserIDs[posA] = cachedUserIDs[posB];
+ cachedItemIDs[posA] = cachedItemIDs[posB];
+
+ cachedUserIDs[posB] = tmpUserIndex;
+ cachedItemIDs[posB] = tmpItemIndex;
+ }
+
+ @Override
+ public Factorization factorize() throws TasteException {
+ prepareTraining();
+ double currentLearningRate = learningRate;
+
+
+ for (int it = 0; it < numIterations; it++) {
+ for (int index = 0; index < cachedUserIDs.length; index++) {
+ long userId = cachedUserIDs[index];
+ long itemId = cachedItemIDs[index];
+ float rating = dataModel.getPreferenceValue(userId, itemId);
+ updateParameters(userId, itemId, rating, currentLearningRate);
+ }
+ currentLearningRate *= learningRateDecay;
+ }
+ return createFactorization(userVectors, itemVectors);
+ }
+
+ double getAveragePreference() throws TasteException {
+ RunningAverage average = new FullRunningAverage();
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ for (Preference pref : dataModel.getPreferencesFromUser(it.nextLong())) {
+ average.addDatum(pref.getValue());
+ }
+ }
+ return average.getAverage();
+ }
+
+ protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
+ int userIndex = userIndex(userID);
+ int itemIndex = itemIndex(itemID);
+
+ double[] userVector = userVectors[userIndex];
+ double[] itemVector = itemVectors[itemIndex];
+ double prediction = predictRating(userIndex, itemIndex);
+ double err = rating - prediction;
+
+ // adjust user bias
+ userVector[USER_BIAS_INDEX] +=
+ biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);
+
+ // adjust item bias
+ itemVector[ITEM_BIAS_INDEX] +=
+ biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);
+
+ // adjust features
+ for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+ double userFeature = userVector[feature];
+ double itemFeature = itemVector[feature];
+
+ double deltaUserFeature = err * itemFeature - preventOverfitting * userFeature;
+ userVector[feature] += currentLearningRate * deltaUserFeature;
+
+ double deltaItemFeature = err * userFeature - preventOverfitting * itemFeature;
+ itemVector[feature] += currentLearningRate * deltaItemFeature;
+ }
+ }
+
+ private double predictRating(int userID, int itemID) {
+ double sum = 0;
+ for (int feature = 0; feature < numFeatures; feature++) {
+ sum += userVectors[userID][feature] * itemVectors[itemID][feature];
+ }
+ return sum;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java
new file mode 100644
index 0000000..8967134
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPlusPlusFactorizer.java
@@ -0,0 +1,178 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.common.RandomUtils;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * SVD++, an enhancement of classical matrix factorization for rating prediction.
+ * Additionally to using ratings (how did people rate?) for learning, this model also takes into account
+ * who rated what.
+ *
+ * Yehuda Koren: Factorization Meets the Neighborhood: a Multifaceted Collaborative Filtering Model, KDD 2008.
+ * http://research.yahoo.com/files/kdd08koren.pdf
+ */
+public final class SVDPlusPlusFactorizer extends RatingSGDFactorizer {
+
+ private double[][] p;
+ private double[][] y;
+ private Map<Integer, List<Integer>> itemsByUser;
+
+ public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, int numIterations) throws TasteException {
+ this(dataModel, numFeatures, 0.01, 0.1, 0.01, numIterations, 1.0);
+ biasLearningRate = 0.7;
+ biasReg = 0.33;
+ }
+
+ public SVDPlusPlusFactorizer(DataModel dataModel, int numFeatures, double learningRate, double preventOverfitting,
+ double randomNoise, int numIterations, double learningRateDecay) throws TasteException {
+ super(dataModel, numFeatures, learningRate, preventOverfitting, randomNoise, numIterations, learningRateDecay);
+ }
+
+ @Override
+ protected void prepareTraining() throws TasteException {
+ super.prepareTraining();
+ Random random = RandomUtils.getRandom();
+
+ p = new double[dataModel.getNumUsers()][numFeatures];
+ for (int i = 0; i < p.length; i++) {
+ for (int feature = 0; feature < FEATURE_OFFSET; feature++) {
+ p[i][feature] = 0;
+ }
+ for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+ p[i][feature] = random.nextGaussian() * randomNoise;
+ }
+ }
+
+ y = new double[dataModel.getNumItems()][numFeatures];
+ for (int i = 0; i < y.length; i++) {
+ for (int feature = 0; feature < FEATURE_OFFSET; feature++) {
+ y[i][feature] = 0;
+ }
+ for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+ y[i][feature] = random.nextGaussian() * randomNoise;
+ }
+ }
+
+ /* get internal item IDs which we will need several times */
+ itemsByUser = Maps.newHashMap();
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ long userId = userIDs.nextLong();
+ int userIndex = userIndex(userId);
+ FastIDSet itemIDsFromUser = dataModel.getItemIDsFromUser(userId);
+ List<Integer> itemIndexes = Lists.newArrayListWithCapacity(itemIDsFromUser.size());
+ itemsByUser.put(userIndex, itemIndexes);
+ for (long itemID2 : itemIDsFromUser) {
+ int i2 = itemIndex(itemID2);
+ itemIndexes.add(i2);
+ }
+ }
+ }
+
+ @Override
+ public Factorization factorize() throws TasteException {
+ prepareTraining();
+
+ super.factorize();
+
+ for (int userIndex = 0; userIndex < userVectors.length; userIndex++) {
+ for (int itemIndex : itemsByUser.get(userIndex)) {
+ for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+ userVectors[userIndex][feature] += y[itemIndex][feature];
+ }
+ }
+ double denominator = Math.sqrt(itemsByUser.get(userIndex).size());
+ for (int feature = 0; feature < userVectors[userIndex].length; feature++) {
+ userVectors[userIndex][feature] =
+ (float) (userVectors[userIndex][feature] / denominator + p[userIndex][feature]);
+ }
+ }
+
+ return createFactorization(userVectors, itemVectors);
+ }
+
+
+ @Override
+ protected void updateParameters(long userID, long itemID, float rating, double currentLearningRate) {
+ int userIndex = userIndex(userID);
+ int itemIndex = itemIndex(itemID);
+
+ double[] userVector = p[userIndex];
+ double[] itemVector = itemVectors[itemIndex];
+
+ double[] pPlusY = new double[numFeatures];
+ for (int i2 : itemsByUser.get(userIndex)) {
+ for (int f = FEATURE_OFFSET; f < numFeatures; f++) {
+ pPlusY[f] += y[i2][f];
+ }
+ }
+ double denominator = Math.sqrt(itemsByUser.get(userIndex).size());
+ for (int feature = 0; feature < pPlusY.length; feature++) {
+ pPlusY[feature] = (float) (pPlusY[feature] / denominator + p[userIndex][feature]);
+ }
+
+ double prediction = predictRating(pPlusY, itemIndex);
+ double err = rating - prediction;
+ double normalized_error = err / denominator;
+
+ // adjust user bias
+ userVector[USER_BIAS_INDEX] +=
+ biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * userVector[USER_BIAS_INDEX]);
+
+ // adjust item bias
+ itemVector[ITEM_BIAS_INDEX] +=
+ biasLearningRate * currentLearningRate * (err - biasReg * preventOverfitting * itemVector[ITEM_BIAS_INDEX]);
+
+ // adjust features
+ for (int feature = FEATURE_OFFSET; feature < numFeatures; feature++) {
+ double pF = userVector[feature];
+ double iF = itemVector[feature];
+
+ double deltaU = err * iF - preventOverfitting * pF;
+ userVector[feature] += currentLearningRate * deltaU;
+
+ double deltaI = err * pPlusY[feature] - preventOverfitting * iF;
+ itemVector[feature] += currentLearningRate * deltaI;
+
+ double commonUpdate = normalized_error * iF;
+ for (int itemIndex2 : itemsByUser.get(userIndex)) {
+ double deltaI2 = commonUpdate - preventOverfitting * y[itemIndex2][feature];
+ y[itemIndex2][feature] += learningRate * deltaI2;
+ }
+ }
+ }
+
+ private double predictRating(double[] userVector, int itemID) {
+ double sum = 0;
+ for (int feature = 0; feature < numFeatures; feature++) {
+ sum += userVector[feature] * itemVectors[itemID][feature];
+ }
+ return sum;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java
new file mode 100644
index 0000000..45c54da
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDPreference.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+
+final class SVDPreference extends GenericPreference {
+
+ private double cache;
+
+ SVDPreference(long userID, long itemID, float value, double cache) {
+ super(userID, itemID, value);
+ setCache(cache);
+ }
+
+ public double getCache() {
+ return cache;
+ }
+
+ public void setCache(double value) {
+ Preconditions.checkArgument(!Double.isNaN(value), "NaN cache value");
+ this.cache = value;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
new file mode 100644
index 0000000..45d4af7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommender.java
@@ -0,0 +1,185 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.recommender.AbstractRecommender;
+import org.apache.mahout.cf.taste.impl.recommender.AllUnknownItemsCandidateItemsStrategy;
+import org.apache.mahout.cf.taste.impl.recommender.TopItems;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A {@link org.apache.mahout.cf.taste.recommender.Recommender} that uses matrix factorization (a projection of users
+ * and items onto a feature space)
+ */
+public final class SVDRecommender extends AbstractRecommender {
+
+ private Factorization factorization;
+ private final Factorizer factorizer;
+ private final PersistenceStrategy persistenceStrategy;
+ private final RefreshHelper refreshHelper;
+
+ private static final Logger log = LoggerFactory.getLogger(SVDRecommender.class);
+
+ public SVDRecommender(DataModel dataModel, Factorizer factorizer) throws TasteException {
+ this(dataModel, factorizer, new AllUnknownItemsCandidateItemsStrategy(), getDefaultPersistenceStrategy());
+ }
+
+ public SVDRecommender(DataModel dataModel, Factorizer factorizer, CandidateItemsStrategy candidateItemsStrategy)
+ throws TasteException {
+ this(dataModel, factorizer, candidateItemsStrategy, getDefaultPersistenceStrategy());
+ }
+
+ /**
+ * Create an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
+ * store if present, otherwise a new factorization is computed and saved in the store.
+ *
+ * The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and overwrites the store.
+ *
+ * @param dataModel
+ * @param factorizer
+ * @param persistenceStrategy
+ * @throws TasteException
+ * @throws IOException
+ */
+ public SVDRecommender(DataModel dataModel, Factorizer factorizer, PersistenceStrategy persistenceStrategy)
+ throws TasteException {
+ this(dataModel, factorizer, getDefaultCandidateItemsStrategy(), persistenceStrategy);
+ }
+
+ /**
+ * Create an SVDRecommender using a persistent store to cache factorizations. A factorization is loaded from the
+ * store if present, otherwise a new factorization is computed and saved in the store.
+ *
+ * The {@link #refresh(java.util.Collection) refresh} method recomputes the factorization and overwrites the store.
+ *
+ * @param dataModel
+ * @param factorizer
+ * @param candidateItemsStrategy
+ * @param persistenceStrategy
+ *
+ * @throws TasteException
+ */
+ public SVDRecommender(DataModel dataModel, Factorizer factorizer, CandidateItemsStrategy candidateItemsStrategy,
+ PersistenceStrategy persistenceStrategy) throws TasteException {
+ super(dataModel, candidateItemsStrategy);
+ this.factorizer = Preconditions.checkNotNull(factorizer);
+ this.persistenceStrategy = Preconditions.checkNotNull(persistenceStrategy);
+ try {
+ factorization = persistenceStrategy.load();
+ } catch (IOException e) {
+ throw new TasteException("Error loading factorization", e);
+ }
+
+ if (factorization == null) {
+ train();
+ }
+
+ refreshHelper = new RefreshHelper(new Callable<Object>() {
+ @Override
+ public Object call() throws TasteException {
+ train();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(getDataModel());
+ refreshHelper.addDependency(factorizer);
+ refreshHelper.addDependency(candidateItemsStrategy);
+ }
+
+ static PersistenceStrategy getDefaultPersistenceStrategy() {
+ return new NoPersistenceStrategy();
+ }
+
+ private void train() throws TasteException {
+ factorization = factorizer.factorize();
+ try {
+ persistenceStrategy.maybePersist(factorization);
+ } catch (IOException e) {
+ throw new TasteException("Error persisting factorization", e);
+ }
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException {
+ Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
+ log.debug("Recommending items for user ID '{}'", userID);
+
+ PreferenceArray preferencesFromUser = getDataModel().getPreferencesFromUser(userID);
+ FastIDSet possibleItemIDs = getAllOtherItems(userID, preferencesFromUser, includeKnownItems);
+
+ List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
+ new Estimator(userID));
+ log.debug("Recommendations are: {}", topItems);
+
+ return topItems;
+ }
+
+ /**
+ * a preference is estimated by computing the dot-product of the user and item feature vectors
+ */
+ @Override
+ public float estimatePreference(long userID, long itemID) throws TasteException {
+ double[] userFeatures = factorization.getUserFeatures(userID);
+ double[] itemFeatures = factorization.getItemFeatures(itemID);
+ double estimate = 0;
+ for (int feature = 0; feature < userFeatures.length; feature++) {
+ estimate += userFeatures[feature] * itemFeatures[feature];
+ }
+ return (float) estimate;
+ }
+
+ private final class Estimator implements TopItems.Estimator<Long> {
+
+ private final long theUserID;
+
+ private Estimator(long theUserID) {
+ this.theUserID = theUserID;
+ }
+
+ @Override
+ public double estimate(Long itemID) throws TasteException {
+ return estimatePreference(theUserID, itemID);
+ }
+ }
+
+ /**
+ * Refresh the data model and factorization.
+ */
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java
new file mode 100644
index 0000000..e0d6f59
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractItemSimilarity.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+
+import java.util.Collection;
+
+public abstract class AbstractItemSimilarity implements ItemSimilarity {
+
+ private final DataModel dataModel;
+ private final RefreshHelper refreshHelper;
+
+ protected AbstractItemSimilarity(DataModel dataModel) {
+ Preconditions.checkArgument(dataModel != null, "dataModel is null");
+ this.dataModel = dataModel;
+ this.refreshHelper = new RefreshHelper(null);
+ refreshHelper.addDependency(this.dataModel);
+ }
+
+ protected DataModel getDataModel() {
+ return dataModel;
+ }
+
+ @Override
+ public long[] allSimilarItemIDs(long itemID) throws TasteException {
+ FastIDSet allSimilarItemIDs = new FastIDSet();
+ LongPrimitiveIterator allItemIDs = dataModel.getItemIDs();
+ while (allItemIDs.hasNext()) {
+ long possiblySimilarItemID = allItemIDs.nextLong();
+ if (!Double.isNaN(itemSimilarity(itemID, possiblySimilarItemID))) {
+ allSimilarItemIDs.add(possiblySimilarItemID);
+ }
+ }
+ return allSimilarItemIDs.toArray();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+}
[35/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
new file mode 100644
index 0000000..d02d974
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java
@@ -0,0 +1,296 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.net.URI;
+import java.util.Arrays;
+
+/**
+ * Temporary class used to compute the frequency distribution of the "class attribute".<br>
+ * This class can be used when the criterion variable is the categorical attribute.
+ */
+public class FrequenciesJob {
+
+ private static final Logger log = LoggerFactory.getLogger(FrequenciesJob.class);
+
+ /** directory that will hold this job's output */
+ private final Path outputPath;
+
+ /** file that contains the serialized dataset */
+ private final Path datasetPath;
+
+ /** directory that contains the data used in the first step */
+ private final Path dataPath;
+
+ /**
+ * @param base
+ * base directory
+ * @param dataPath
+ * data used in the first step
+ */
+ public FrequenciesJob(Path base, Path dataPath, Path datasetPath) {
+ this.outputPath = new Path(base, "frequencies.output");
+ this.dataPath = dataPath;
+ this.datasetPath = datasetPath;
+ }
+
+ /**
+ * @return counts[partition][label] = num tuples from 'partition' with class == label
+ */
+ public int[][] run(Configuration conf) throws IOException, ClassNotFoundException, InterruptedException {
+
+ // check the output
+ FileSystem fs = outputPath.getFileSystem(conf);
+ if (fs.exists(outputPath)) {
+ throw new IOException("Output path already exists : " + outputPath);
+ }
+
+ // put the dataset into the DistributedCache
+ URI[] files = {datasetPath.toUri()};
+ DistributedCache.setCacheFiles(files, conf);
+
+ Job job = new Job(conf);
+ job.setJarByClass(FrequenciesJob.class);
+
+ FileInputFormat.setInputPaths(job, dataPath);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ job.setMapOutputKeyClass(LongWritable.class);
+ job.setMapOutputValueClass(IntWritable.class);
+ job.setOutputKeyClass(LongWritable.class);
+ job.setOutputValueClass(Frequencies.class);
+
+ job.setMapperClass(FrequenciesMapper.class);
+ job.setReducerClass(FrequenciesReducer.class);
+
+ job.setInputFormatClass(TextInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ // run the job
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ int[][] counts = parseOutput(job);
+
+ HadoopUtil.delete(conf, outputPath);
+
+ return counts;
+ }
+
+ /**
+ * Extracts the output and processes it
+ *
+ * @return counts[partition][label] = num tuples from 'partition' with class == label
+ */
+ int[][] parseOutput(JobContext job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ int numMaps = conf.getInt("mapred.map.tasks", -1);
+ log.info("mapred.map.tasks = {}", numMaps);
+
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+
+ Frequencies[] values = new Frequencies[numMaps];
+
+ // read all the outputs
+ int index = 0;
+ for (Path path : outfiles) {
+ for (Frequencies value : new SequenceFileValueIterable<Frequencies>(path, conf)) {
+ values[index++] = value;
+ }
+ }
+
+ if (index < numMaps) {
+ throw new IllegalStateException("number of output Frequencies (" + index
+ + ") is lesser than the number of mappers!");
+ }
+
+ // sort the frequencies using the firstIds
+ Arrays.sort(values);
+ return Frequencies.extractCounts(values);
+ }
+
+ /**
+ * Outputs the first key and the label of each tuple
+ *
+ */
+ private static class FrequenciesMapper extends Mapper<LongWritable,Text,LongWritable,IntWritable> {
+
+ private LongWritable firstId;
+
+ private DataConverter converter;
+ private Dataset dataset;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+
+ dataset = Builder.loadDataset(conf);
+ setup(dataset);
+ }
+
+ /**
+ * Useful when testing
+ */
+ void setup(Dataset dataset) {
+ converter = new DataConverter(dataset);
+ }
+
+ @Override
+ protected void map(LongWritable key, Text value, Context context) throws IOException,
+ InterruptedException {
+ if (firstId == null) {
+ firstId = new LongWritable(key.get());
+ }
+
+ Instance instance = converter.convert(value.toString());
+
+ context.write(firstId, new IntWritable((int) dataset.getLabel(instance)));
+ }
+
+ }
+
+ private static class FrequenciesReducer extends Reducer<LongWritable,IntWritable,LongWritable,Frequencies> {
+
+ private int nblabels;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ Dataset dataset = Builder.loadDataset(conf);
+ setup(dataset.nblabels());
+ }
+
+ /**
+ * Useful when testing
+ */
+ void setup(int nblabels) {
+ this.nblabels = nblabels;
+ }
+
+ @Override
+ protected void reduce(LongWritable key, Iterable<IntWritable> values, Context context)
+ throws IOException, InterruptedException {
+ int[] counts = new int[nblabels];
+ for (IntWritable value : values) {
+ counts[value.get()]++;
+ }
+
+ context.write(key, new Frequencies(key.get(), counts));
+ }
+ }
+
+ /**
+ * Output of the job
+ *
+ */
+ private static class Frequencies implements Writable, Comparable<Frequencies>, Cloneable {
+
+ /** first key of the partition used to sort the partitions */
+ private long firstId;
+
+ /** counts[c] = num tuples from the partition with label == c */
+ private int[] counts;
+
+ Frequencies() { }
+
+ Frequencies(long firstId, int[] counts) {
+ this.firstId = firstId;
+ this.counts = Arrays.copyOf(counts, counts.length);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ firstId = in.readLong();
+ counts = DFUtils.readIntArray(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeLong(firstId);
+ DFUtils.writeArray(out, counts);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ return other instanceof Frequencies && firstId == ((Frequencies) other).firstId;
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) firstId;
+ }
+
+ @Override
+ protected Frequencies clone() {
+ return new Frequencies(firstId, counts);
+ }
+
+ @Override
+ public int compareTo(Frequencies obj) {
+ if (firstId < obj.firstId) {
+ return -1;
+ } else if (firstId > obj.firstId) {
+ return 1;
+ } else {
+ return 0;
+ }
+ }
+
+ public static int[][] extractCounts(Frequencies[] partitions) {
+ int[][] counts = new int[partitions.length][];
+ for (int p = 0; p < partitions.length; p++) {
+ counts[p] = partitions[p].counts;
+ }
+ return counts;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
new file mode 100644
index 0000000..d82b383
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/TreeVisualizer.java
@@ -0,0 +1,263 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.lang.reflect.Field;
+import java.text.DecimalFormat;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+
+/**
+ * This tool is to visualize the Decision tree
+ */
+public final class TreeVisualizer {
+
+ private TreeVisualizer() {}
+
+ private static String doubleToString(double value) {
+ DecimalFormat df = new DecimalFormat("0.##");
+ return df.format(value);
+ }
+
+ private static String toStringNode(Node node, Dataset dataset,
+ String[] attrNames, Map<String,Field> fields, int layer) {
+
+ StringBuilder buff = new StringBuilder();
+
+ try {
+ if (node instanceof CategoricalNode) {
+ CategoricalNode cnode = (CategoricalNode) node;
+ int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+ double[] values = (double[]) fields.get("CategoricalNode.values").get(cnode);
+ Node[] childs = (Node[]) fields.get("CategoricalNode.childs").get(cnode);
+ String[][] attrValues = (String[][]) fields.get("Dataset.values").get(dataset);
+ for (int i = 0; i < attrValues[attr].length; i++) {
+ int index = ArrayUtils.indexOf(values, i);
+ if (index < 0) {
+ continue;
+ }
+ buff.append('\n');
+ for (int j = 0; j < layer; j++) {
+ buff.append("| ");
+ }
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+ .append(attrValues[attr][i]);
+ buff.append(toStringNode(childs[index], dataset, attrNames, fields, layer + 1));
+ }
+ } else if (node instanceof NumericalNode) {
+ NumericalNode nnode = (NumericalNode) node;
+ int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+ double split = (Double) fields.get("NumericalNode.split").get(nnode);
+ Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+ Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+ buff.append('\n');
+ for (int j = 0; j < layer; j++) {
+ buff.append("| ");
+ }
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" < ")
+ .append(doubleToString(split));
+ buff.append(toStringNode(loChild, dataset, attrNames, fields, layer + 1));
+ buff.append('\n');
+ for (int j = 0; j < layer; j++) {
+ buff.append("| ");
+ }
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" >= ")
+ .append(doubleToString(split));
+ buff.append(toStringNode(hiChild, dataset, attrNames, fields, layer + 1));
+ } else if (node instanceof Leaf) {
+ Leaf leaf = (Leaf) node;
+ double label = (Double) fields.get("Leaf.label").get(leaf);
+ if (dataset.isNumerical(dataset.getLabelId())) {
+ buff.append(" : ").append(doubleToString(label));
+ } else {
+ buff.append(" : ").append(dataset.getLabelString(label));
+ }
+ }
+ } catch (IllegalAccessException iae) {
+ throw new IllegalStateException(iae);
+ }
+
+ return buff.toString();
+ }
+
+ private static Map<String,Field> getReflectMap() {
+ Map<String,Field> fields = new HashMap<String,Field>();
+
+ try {
+ Field m = CategoricalNode.class.getDeclaredField("attr");
+ m.setAccessible(true);
+ fields.put("CategoricalNode.attr", m);
+ m = CategoricalNode.class.getDeclaredField("values");
+ m.setAccessible(true);
+ fields.put("CategoricalNode.values", m);
+ m = CategoricalNode.class.getDeclaredField("childs");
+ m.setAccessible(true);
+ fields.put("CategoricalNode.childs", m);
+ m = NumericalNode.class.getDeclaredField("attr");
+ m.setAccessible(true);
+ fields.put("NumericalNode.attr", m);
+ m = NumericalNode.class.getDeclaredField("split");
+ m.setAccessible(true);
+ fields.put("NumericalNode.split", m);
+ m = NumericalNode.class.getDeclaredField("loChild");
+ m.setAccessible(true);
+ fields.put("NumericalNode.loChild", m);
+ m = NumericalNode.class.getDeclaredField("hiChild");
+ m.setAccessible(true);
+ fields.put("NumericalNode.hiChild", m);
+ m = Leaf.class.getDeclaredField("label");
+ m.setAccessible(true);
+ fields.put("Leaf.label", m);
+ m = Dataset.class.getDeclaredField("values");
+ m.setAccessible(true);
+ fields.put("Dataset.values", m);
+ } catch (NoSuchFieldException nsfe) {
+ throw new IllegalStateException(nsfe);
+ }
+
+ return fields;
+ }
+
+ /**
+ * Decision tree to String
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static String toString(Node tree, Dataset dataset, String[] attrNames) {
+ return toStringNode(tree, dataset, attrNames, getReflectMap(), 0);
+ }
+
+ /**
+ * Print Decision tree
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static void print(Node tree, Dataset dataset, String[] attrNames) {
+ System.out.println(toString(tree, dataset, attrNames));
+ }
+
+ private static String toStringPredict(Node node, Instance instance,
+ Dataset dataset, String[] attrNames, Map<String,Field> fields) {
+ StringBuilder buff = new StringBuilder();
+
+ try {
+ if (node instanceof CategoricalNode) {
+ CategoricalNode cnode = (CategoricalNode) node;
+ int attr = (Integer) fields.get("CategoricalNode.attr").get(cnode);
+ double[] values = (double[]) fields.get("CategoricalNode.values").get(
+ cnode);
+ Node[] childs = (Node[]) fields.get("CategoricalNode.childs")
+ .get(cnode);
+ String[][] attrValues = (String[][]) fields.get("Dataset.values").get(
+ dataset);
+
+ int index = ArrayUtils.indexOf(values, instance.get(attr));
+ if (index >= 0) {
+ buff.append(attrNames == null ? attr : attrNames[attr]).append(" = ")
+ .append(attrValues[attr][(int) instance.get(attr)]);
+ buff.append(" -> ");
+ buff.append(toStringPredict(childs[index], instance, dataset,
+ attrNames, fields));
+ }
+ } else if (node instanceof NumericalNode) {
+ NumericalNode nnode = (NumericalNode) node;
+ int attr = (Integer) fields.get("NumericalNode.attr").get(nnode);
+ double split = (Double) fields.get("NumericalNode.split").get(nnode);
+ Node loChild = (Node) fields.get("NumericalNode.loChild").get(nnode);
+ Node hiChild = (Node) fields.get("NumericalNode.hiChild").get(nnode);
+
+ if (instance.get(attr) < split) {
+ buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+ .append(" = ").append(doubleToString(instance.get(attr)))
+ .append(") < ").append(doubleToString(split));
+ buff.append(" -> ");
+ buff.append(toStringPredict(loChild, instance, dataset, attrNames,
+ fields));
+ } else {
+ buff.append('(').append(attrNames == null ? attr : attrNames[attr])
+ .append(" = ").append(doubleToString(instance.get(attr)))
+ .append(") >= ").append(doubleToString(split));
+ buff.append(" -> ");
+ buff.append(toStringPredict(hiChild, instance, dataset, attrNames,
+ fields));
+ }
+ } else if (node instanceof Leaf) {
+ Leaf leaf = (Leaf) node;
+ double label = (Double) fields.get("Leaf.label").get(leaf);
+ if (dataset.isNumerical(dataset.getLabelId())) {
+ buff.append(doubleToString(label));
+ } else {
+ buff.append(dataset.getLabelString(label));
+ }
+ }
+ } catch (IllegalAccessException iae) {
+ throw new IllegalStateException(iae);
+ }
+
+ return buff.toString();
+ }
+
+ /**
+ * Predict trace to String
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static String[] predictTrace(Node tree, Data data, String[] attrNames) {
+ Map<String,Field> reflectMap = getReflectMap();
+ String[] prediction = new String[data.size()];
+ for (int i = 0; i < data.size(); i++) {
+ prediction[i] = toStringPredict(tree, data.get(i), data.getDataset(),
+ attrNames, reflectMap);
+ }
+ return prediction;
+ }
+
+ /**
+ * Print predict trace
+ *
+ * @param tree
+ * Node of tree
+ * @param attrNames
+ * attribute names
+ */
+ public static void predictTracePrint(Node tree, Data data, String[] attrNames) {
+ Map<String,Field> reflectMap = getReflectMap();
+ for (int i = 0; i < data.size(); i++) {
+ System.out.println(toStringPredict(tree, data.get(i), data.getDataset(),
+ attrNames, reflectMap));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
new file mode 100644
index 0000000..06876e1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java
@@ -0,0 +1,211 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Locale;
+import java.util.Random;
+import java.util.Scanner;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool is used to uniformly distribute the class of all the tuples of the dataset over a given number of
+ * partitions.<br>
+ * This class can be used when the criterion variable is the categorical attribute.
+ */
+public final class UDistrib {
+
+ private static final Logger log = LoggerFactory.getLogger(UDistrib.class);
+
+ private UDistrib() {}
+
+ /**
+ * Launch the uniform distribution tool. Requires the following command line arguments:<br>
+ *
+ * data : data path dataset : dataset path numpartitions : num partitions output : output path
+ *
+ * @throws java.io.IOException
+ */
+ public static void main(String[] args) throws IOException {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+ abuilder.withName("data").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+ abuilder.withName("dataset").withMinimum(1).create()).withDescription("Dataset path").create();
+
+ Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
+ abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Path to generated files").create();
+
+ Option partitionsOpt = obuilder.withLongName("numpartitions").withShortName("p").withRequired(true)
+ .withArgument(abuilder.withName("numparts").withMinimum(1).withMinimum(1).create()).withDescription(
+ "Number of partitions to create").create();
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(outputOpt).withOption(
+ datasetOpt).withOption(partitionsOpt).withOption(helpOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ String data = cmdLine.getValue(dataOpt).toString();
+ String dataset = cmdLine.getValue(datasetOpt).toString();
+ int numPartitions = Integer.parseInt(cmdLine.getValue(partitionsOpt).toString());
+ String output = cmdLine.getValue(outputOpt).toString();
+
+ runTool(data, dataset, output, numPartitions);
+ } catch (OptionException e) {
+ log.warn(e.toString(), e);
+ CommandLineUtil.printHelp(group);
+ }
+
+ }
+
+ private static void runTool(String dataStr, String datasetStr, String output, int numPartitions) throws IOException {
+
+ Preconditions.checkArgument(numPartitions > 0, "numPartitions <= 0");
+
+ // make sure the output file does not exist
+ Path outputPath = new Path(output);
+ Configuration conf = new Configuration();
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ Preconditions.checkArgument(!fs.exists(outputPath), "Output path already exists");
+
+ // create a new file corresponding to each partition
+ // Path workingDir = fs.getWorkingDirectory();
+ // FileSystem wfs = workingDir.getFileSystem(conf);
+ // File parentFile = new File(workingDir.toString());
+ // File tempFile = FileUtil.createLocalTempFile(parentFile, "Parts", true);
+ // File tempFile = File.createTempFile("df.tools.UDistrib","");
+ // tempFile.deleteOnExit();
+ File tempFile = FileUtil.createLocalTempFile(new File(""), "df.tools.UDistrib", true);
+ Path partsPath = new Path(tempFile.toString());
+ FileSystem pfs = partsPath.getFileSystem(conf);
+
+ Path[] partPaths = new Path[numPartitions];
+ FSDataOutputStream[] files = new FSDataOutputStream[numPartitions];
+ for (int p = 0; p < numPartitions; p++) {
+ partPaths[p] = new Path(partsPath, String.format(Locale.ENGLISH, "part.%03d", p));
+ files[p] = pfs.create(partPaths[p]);
+ }
+
+ Path datasetPath = new Path(datasetStr);
+ Dataset dataset = Dataset.load(conf, datasetPath);
+
+ // currents[label] = next partition file where to place the tuple
+ int[] currents = new int[dataset.nblabels()];
+
+ // currents is initialized randomly in the range [0, numpartitions[
+ Random random = RandomUtils.getRandom();
+ for (int c = 0; c < currents.length; c++) {
+ currents[c] = random.nextInt(numPartitions);
+ }
+
+ // foreach tuple of the data
+ Path dataPath = new Path(dataStr);
+ FileSystem ifs = dataPath.getFileSystem(conf);
+ FSDataInputStream input = ifs.open(dataPath);
+ Scanner scanner = new Scanner(input, "UTF-8");
+ DataConverter converter = new DataConverter(dataset);
+
+ int id = 0;
+ while (scanner.hasNextLine()) {
+ if (id % 1000 == 0) {
+ log.info("progress : {}", id);
+ }
+
+ String line = scanner.nextLine();
+ if (line.isEmpty()) {
+ continue; // skip empty lines
+ }
+
+ // write the tuple in files[tuple.label]
+ Instance instance = converter.convert(line);
+ int label = (int) dataset.getLabel(instance);
+ files[currents[label]].writeBytes(line);
+ files[currents[label]].writeChar('\n');
+
+ // update currents
+ currents[label]++;
+ if (currents[label] == numPartitions) {
+ currents[label] = 0;
+ }
+ }
+
+ // close all the files.
+ scanner.close();
+ for (FSDataOutputStream file : files) {
+ Closeables.close(file, false);
+ }
+
+ // merge all output files
+ FileUtil.copyMerge(pfs, partsPath, fs, outputPath, true, conf, null);
+ /*
+ * FSDataOutputStream joined = fs.create(new Path(outputPath, "uniform.data")); for (int p = 0; p <
+ * numPartitions; p++) {log.info("Joining part : {}", p); FSDataInputStream partStream =
+ * fs.open(partPaths[p]);
+ *
+ * IOUtils.copyBytes(partStream, joined, conf, false);
+ *
+ * partStream.close(); }
+ *
+ * joined.close();
+ *
+ * fs.delete(partsPath, true);
+ */
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java b/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
new file mode 100644
index 0000000..049f9bf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/evaluation/Auc.java
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.evaluation;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.list.DoubleArrayList;
+
+import com.google.common.base.Preconditions;
+
+import java.util.Random;
+
+/**
+ * Computes AUC and a few other accuracy statistics without storing huge amounts of data. This is
+ * done by keeping uniform samples of the positive and negative scores. Then, when AUC is to be
+ * computed, the remaining scores are sorted and a rank-sum statistic is used to compute the AUC.
+ * Since AUC is invariant with respect to down-sampling of either positives or negatives, this is
+ * close to correct and is exactly correct if maxBufferSize or fewer positive and negative scores
+ * are examined.
+ */
+public class Auc {
+
+ private int maxBufferSize = 10000;
+ private final DoubleArrayList[] scores = {new DoubleArrayList(), new DoubleArrayList()};
+ private final Random rand;
+ private int samples;
+ private final double threshold;
+ private final Matrix confusion;
+ private final DenseMatrix entropy;
+
+ private boolean probabilityScore = true;
+
+ private boolean hasScore;
+
+ /**
+ * Allocates a new data-structure for accumulating information about AUC and a few other accuracy
+ * measures.
+ * @param threshold The threshold to use in computing the confusion matrix.
+ */
+ public Auc(double threshold) {
+ confusion = new DenseMatrix(2, 2);
+ entropy = new DenseMatrix(2, 2);
+ this.rand = RandomUtils.getRandom();
+ this.threshold = threshold;
+ }
+
+ public Auc() {
+ this(0.5);
+ }
+
+ /**
+ * Adds a score to the AUC buffers.
+ *
+ * @param trueValue Whether this score is for a true-positive or a true-negative example.
+ * @param score The score for this example.
+ */
+ public void add(int trueValue, double score) {
+ Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+ hasScore = true;
+
+ int predictedClass = score > threshold ? 1 : 0;
+ confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+
+ samples++;
+ if (isProbabilityScore()) {
+ double limited = Math.max(1.0e-20, Math.min(score, 1 - 1.0e-20));
+ double v0 = entropy.get(trueValue, 0);
+ entropy.set(trueValue, 0, (Math.log1p(-limited) - v0) / samples + v0);
+
+ double v1 = entropy.get(trueValue, 1);
+ entropy.set(trueValue, 1, (Math.log(limited) - v1) / samples + v1);
+ }
+
+ // add to buffers
+ DoubleArrayList buf = scores[trueValue];
+ if (buf.size() >= maxBufferSize) {
+ // but if too many points are seen, we insert into a random
+ // place and discard the predecessor. The random place could
+ // be anywhere, possibly not even in the buffer.
+ // this is a special case of Knuth's permutation algorithm
+ // but since we don't ever shuffle the first maxBufferSize
+ // samples, the result isn't just a fair sample of the prefixes
+ // of all permutations. The CONTENTs of the result, however,
+ // will be a fair and uniform sample of maxBufferSize elements
+ // chosen from all elements without replacement
+ int index = rand.nextInt(samples);
+ if (index < buf.size()) {
+ buf.set(index, score);
+ }
+ } else {
+ // for small buffers, we collect all points without permuting
+ // since we sort the data later, permuting now would just be
+ // pedantic
+ buf.add(score);
+ }
+ }
+
+ public void add(int trueValue, int predictedClass) {
+ hasScore = false;
+ Preconditions.checkArgument(trueValue == 0 || trueValue == 1, "True value must be 0 or 1");
+ confusion.set(trueValue, predictedClass, confusion.get(trueValue, predictedClass) + 1);
+ }
+
+ /**
+ * Computes the AUC of points seen so far. This can be moderately expensive since it requires
+ * that all points that have been retained be sorted.
+ *
+ * @return The value of the Area Under the receiver operating Curve.
+ */
+ public double auc() {
+ Preconditions.checkArgument(hasScore, "Can't compute AUC for classifier without a score");
+ scores[0].sort();
+ scores[1].sort();
+
+ double n0 = scores[0].size();
+ double n1 = scores[1].size();
+
+ if (n0 == 0 || n1 == 0) {
+ return 0.5;
+ }
+
+ // scan the data
+ int i0 = 0;
+ int i1 = 0;
+ int rank = 1;
+ double rankSum = 0;
+ while (i0 < n0 && i1 < n1) {
+
+ double v0 = scores[0].get(i0);
+ double v1 = scores[1].get(i1);
+
+ if (v0 < v1) {
+ i0++;
+ rank++;
+ } else if (v1 < v0) {
+ i1++;
+ rankSum += rank;
+ rank++;
+ } else {
+ // ties have to be handled delicately
+ double tieScore = v0;
+
+ // how many negatives are tied?
+ int k0 = 0;
+ while (i0 < n0 && scores[0].get(i0) == tieScore) {
+ k0++;
+ i0++;
+ }
+
+ // and how many positives
+ int k1 = 0;
+ while (i1 < n1 && scores[1].get(i1) == tieScore) {
+ k1++;
+ i1++;
+ }
+
+ // we found k0 + k1 tied values which have
+ // ranks in the half open interval [rank, rank + k0 + k1)
+ // the average rank is assigned to all
+ rankSum += (rank + (k0 + k1 - 1) / 2.0) * k1;
+ rank += k0 + k1;
+ }
+ }
+
+ if (i1 < n1) {
+ rankSum += (rank + (n1 - i1 - 1) / 2.0) * (n1 - i1);
+ rank += (int) (n1 - i1);
+ }
+
+ return (rankSum / n1 - (n1 + 1) / 2) / n0;
+ }
+
+ /**
+ * Returns the confusion matrix for the classifier supposing that we were to use a particular
+ * threshold.
+ * @return The confusion matrix.
+ */
+ public Matrix confusion() {
+ return confusion;
+ }
+
+ /**
+ * Returns a matrix related to the confusion matrix and to the log-likelihood. For a
+ * pretty accurate classifier, N + entropy is nearly the same as the confusion matrix
+ * because log(1-eps) \approx -eps if eps is small.
+ *
+ * For lower accuracy classifiers, this measure will give us a better picture of how
+ * things work our.
+ *
+ * Also, by definition, log-likelihood = sum(diag(entropy))
+ * @return Returns a cell by cell break-down of the log-likelihood
+ */
+ public Matrix entropy() {
+ if (!hasScore) {
+ // find a constant score that would optimize log-likelihood, but use a dash of Bayesian
+ // conservatism to avoid dividing by zero or taking log(0)
+ double p = (0.5 + confusion.get(1, 1)) / (1 + confusion.get(0, 0) + confusion.get(1, 1));
+ entropy.set(0, 0, confusion.get(0, 0) * Math.log1p(-p));
+ entropy.set(0, 1, confusion.get(0, 1) * Math.log(p));
+ entropy.set(1, 0, confusion.get(1, 0) * Math.log1p(-p));
+ entropy.set(1, 1, confusion.get(1, 1) * Math.log(p));
+ }
+ return entropy;
+ }
+
+ public void setMaxBufferSize(int maxBufferSize) {
+ this.maxBufferSize = maxBufferSize;
+ }
+
+ public boolean isProbabilityScore() {
+ return probabilityScore;
+ }
+
+ public void setProbabilityScore(boolean probabilityScore) {
+ this.probabilityScore = probabilityScore;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
new file mode 100644
index 0000000..d3e9ff3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/MultilayerPerceptron.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+import java.io.IOException;
+
+/**
+ * A Multilayer Perceptron (MLP) is a kind of feed-forward artificial neural
+ * network, which is a mathematical model inspired by the biological neural
+ * network. The Multilayer Perceptron can be used for various machine learning
+ * tasks such as classification and regression.
+ *
+ * A detailed introduction about MLP can be found at
+ * http://ufldl.stanford.edu/wiki/index.php/Neural_Networks.
+ *
+ * For this particular implementation, the users can freely control the topology
+ * of the MLP, including: 1. The size of the input layer; 2. The number of
+ * hidden layers; 3. The size of each hidden layer; 4. The size of the output
+ * layer. 5. The cost function. 6. The squashing function.
+ *
+ * The model is trained in an online learning approach, where the weights of
+ * neurons in the MLP is updated incremented using backPropagation algorithm
+ * proposed by (Rumelhart, D. E., Hinton, G. E., and Williams, R. J. (1986)
+ * Learning representations by back-propagating errors. Nature, 323, 533--536.)
+ */
+public class MultilayerPerceptron extends NeuralNetwork implements OnlineLearner {
+
+ /**
+ * The default constructor.
+ */
+ public MultilayerPerceptron() {
+ super();
+ }
+
+ /**
+ * Initialize the MLP by specifying the location of the model.
+ *
+ * @param modelPath The path of the model.
+ */
+ public MultilayerPerceptron(String modelPath) throws IOException {
+ super(modelPath);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ // construct the training instance, where append the actual to instance
+ Vector trainingInstance = new DenseVector(instance.size() + 1);
+ for (int i = 0; i < instance.size(); ++i) {
+ trainingInstance.setQuick(i, instance.getQuick(i));
+ }
+ trainingInstance.setQuick(instance.size(), actual);
+ this.trainOnline(trainingInstance);
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual,
+ Vector instance) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void close() {
+ // DO NOTHING
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
new file mode 100644
index 0000000..cfbe5c4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetwork.java
@@ -0,0 +1,743 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.WritableUtils;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * AbstractNeuralNetwork defines the general operations for a neural network
+ * based model. Typically, all derivative models such as Multilayer Perceptron
+ * and Autoencoder consist of neurons and the weights between neurons.
+ */
+public abstract class NeuralNetwork {
+
+ private static final Logger log = LoggerFactory.getLogger(NeuralNetwork.class);
+
+ /* The default learning rate */
+ public static final double DEFAULT_LEARNING_RATE = 0.5;
+ /* The default regularization weight */
+ public static final double DEFAULT_REGULARIZATION_WEIGHT = 0;
+ /* The default momentum weight */
+ public static final double DEFAULT_MOMENTUM_WEIGHT = 0.1;
+
+ public static enum TrainingMethod { GRADIENT_DESCENT }
+
+ /* The name of the model */
+ protected String modelType;
+
+ /* The path to store the model */
+ protected String modelPath;
+
+ protected double learningRate;
+
+ /* The weight of regularization */
+ protected double regularizationWeight;
+
+ /* The momentum weight */
+ protected double momentumWeight;
+
+ /* The cost function of the model */
+ protected String costFunctionName;
+
+ /* Record the size of each layer */
+ protected List<Integer> layerSizeList;
+
+ /* Training method used for training the model */
+ protected TrainingMethod trainingMethod;
+
+ /* Weights between neurons at adjacent layers */
+ protected List<Matrix> weightMatrixList;
+
+ /* Previous weight updates between neurons at adjacent layers */
+ protected List<Matrix> prevWeightUpdatesList;
+
+ /* Different layers can have different squashing function */
+ protected List<String> squashingFunctionList;
+
+ /* The index of the final layer */
+ protected int finalLayerIndex;
+
+ /**
+ * The default constructor that initializes the learning rate, regularization
+ * weight, and momentum weight by default.
+ */
+ public NeuralNetwork() {
+ log.info("Initialize model...");
+ learningRate = DEFAULT_LEARNING_RATE;
+ regularizationWeight = DEFAULT_REGULARIZATION_WEIGHT;
+ momentumWeight = DEFAULT_MOMENTUM_WEIGHT;
+ trainingMethod = TrainingMethod.GRADIENT_DESCENT;
+ costFunctionName = "Minus_Squared";
+ modelType = getClass().getSimpleName();
+
+ layerSizeList = Lists.newArrayList();
+ layerSizeList = Lists.newArrayList();
+ weightMatrixList = Lists.newArrayList();
+ prevWeightUpdatesList = Lists.newArrayList();
+ squashingFunctionList = Lists.newArrayList();
+ }
+
+ /**
+ * Initialize the NeuralNetwork by specifying learning rate, momentum weight
+ * and regularization weight.
+ *
+ * @param learningRate The learning rate.
+ * @param momentumWeight The momentum weight.
+ * @param regularizationWeight The regularization weight.
+ */
+ public NeuralNetwork(double learningRate, double momentumWeight, double regularizationWeight) {
+ this();
+ setLearningRate(learningRate);
+ setMomentumWeight(momentumWeight);
+ setRegularizationWeight(regularizationWeight);
+ }
+
+ /**
+ * Initialize the NeuralNetwork by specifying the location of the model.
+ *
+ * @param modelPath The location that the model is stored.
+ */
+ public NeuralNetwork(String modelPath) throws IOException {
+ this.modelPath = modelPath;
+ readFromModel();
+ }
+
+ /**
+ * Get the type of the model.
+ *
+ * @return The name of the model.
+ */
+ public String getModelType() {
+ return this.modelType;
+ }
+
+ /**
+ * Set the degree of aggression during model training, a large learning rate
+ * can increase the training speed, but it also decreases the chance of model
+ * converge.
+ *
+ * @param learningRate Learning rate must be a non-negative value. Recommend in range (0, 0.5).
+ * @return The model instance.
+ */
+ public final NeuralNetwork setLearningRate(double learningRate) {
+ Preconditions.checkArgument(learningRate > 0, "Learning rate must be larger than 0.");
+ this.learningRate = learningRate;
+ return this;
+ }
+
+ /**
+ * Get the value of learning rate.
+ *
+ * @return The value of learning rate.
+ */
+ public double getLearningRate() {
+ return learningRate;
+ }
+
+ /**
+ * Set the regularization weight. More complex the model is, less weight the
+ * regularization is.
+ *
+ * @param regularizationWeight regularization must be in the range [0, 0.1).
+ * @return The model instance.
+ */
+ public final NeuralNetwork setRegularizationWeight(double regularizationWeight) {
+ Preconditions.checkArgument(regularizationWeight >= 0
+ && regularizationWeight < 0.1, "Regularization weight must be in range [0, 0.1)");
+ this.regularizationWeight = regularizationWeight;
+ return this;
+ }
+
+ /**
+ * Get the weight of the regularization.
+ *
+ * @return The weight of regularization.
+ */
+ public double getRegularizationWeight() {
+ return regularizationWeight;
+ }
+
+ /**
+ * Set the momentum weight for the model.
+ *
+ * @param momentumWeight momentumWeight must be in range [0, 0.5].
+ * @return The model instance.
+ */
+ public final NeuralNetwork setMomentumWeight(double momentumWeight) {
+ Preconditions.checkArgument(momentumWeight >= 0 && momentumWeight <= 1.0,
+ "Momentum weight must be in range [0, 1.0]");
+ this.momentumWeight = momentumWeight;
+ return this;
+ }
+
+ /**
+ * Get the momentum weight.
+ *
+ * @return The value of momentum.
+ */
+ public double getMomentumWeight() {
+ return momentumWeight;
+ }
+
+ /**
+ * Set the training method.
+ *
+ * @param method The training method, currently supports GRADIENT_DESCENT.
+ * @return The instance of the model.
+ */
+ public NeuralNetwork setTrainingMethod(TrainingMethod method) {
+ this.trainingMethod = method;
+ return this;
+ }
+
+ /**
+ * Get the training method.
+ *
+ * @return The training method enumeration.
+ */
+ public TrainingMethod getTrainingMethod() {
+ return trainingMethod;
+ }
+
+ /**
+ * Set the cost function for the model.
+ *
+ * @param costFunction the name of the cost function. Currently supports
+ * "Minus_Squared", "Cross_Entropy".
+ */
+ public NeuralNetwork setCostFunction(String costFunction) {
+ this.costFunctionName = costFunction;
+ return this;
+ }
+
+ /**
+ * Add a layer of neurons with specified size. If the added layer is not the
+ * first layer, it will automatically connect the neurons between with the
+ * previous layer.
+ *
+ * @param size The size of the layer. (bias neuron excluded)
+ * @param isFinalLayer If false, add a bias neuron.
+ * @param squashingFunctionName The squashing function for this layer, input
+ * layer is f(x) = x by default.
+ * @return The layer index, starts with 0.
+ */
+ public int addLayer(int size, boolean isFinalLayer, String squashingFunctionName) {
+ Preconditions.checkArgument(size > 0, "Size of layer must be larger than 0.");
+ log.info("Add layer with size {} and squashing function {}", size, squashingFunctionName);
+ int actualSize = size;
+ if (!isFinalLayer) {
+ actualSize += 1;
+ }
+
+ layerSizeList.add(actualSize);
+ int layerIndex = layerSizeList.size() - 1;
+ if (isFinalLayer) {
+ finalLayerIndex = layerIndex;
+ }
+
+ // Add weights between current layer and previous layer, and input layer has no squashing function
+ if (layerIndex > 0) {
+ int sizePrevLayer = layerSizeList.get(layerIndex - 1);
+ // Row count equals to size of current size and column count equal to size of previous layer
+ int row = isFinalLayer ? actualSize : actualSize - 1;
+ Matrix weightMatrix = new DenseMatrix(row, sizePrevLayer);
+ // Initialize weights
+ final RandomWrapper rnd = RandomUtils.getRandom();
+ weightMatrix.assign(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return rnd.nextDouble() - 0.5;
+ }
+ });
+ weightMatrixList.add(weightMatrix);
+ prevWeightUpdatesList.add(new DenseMatrix(row, sizePrevLayer));
+ squashingFunctionList.add(squashingFunctionName);
+ }
+ return layerIndex;
+ }
+
+ /**
+ * Get the size of a particular layer.
+ *
+ * @param layer The index of the layer, starting from 0.
+ * @return The size of the corresponding layer.
+ */
+ public int getLayerSize(int layer) {
+ Preconditions.checkArgument(layer >= 0 && layer < this.layerSizeList.size(),
+ String.format("Input must be in range [0, %d]\n", this.layerSizeList.size() - 1));
+ return layerSizeList.get(layer);
+ }
+
+ /**
+ * Get the layer size list.
+ *
+ * @return The sizes of the layers.
+ */
+ protected List<Integer> getLayerSizeList() {
+ return layerSizeList;
+ }
+
+ /**
+ * Get the weights between layer layerIndex and layerIndex + 1
+ *
+ * @param layerIndex The index of the layer.
+ * @return The weights in form of {@link Matrix}.
+ */
+ public Matrix getWeightsByLayer(int layerIndex) {
+ return weightMatrixList.get(layerIndex);
+ }
+
+ /**
+ * Update the weight matrices with given matrices.
+ *
+ * @param matrices The weight matrices, must be the same dimension as the
+ * existing matrices.
+ */
+ public void updateWeightMatrices(Matrix[] matrices) {
+ for (int i = 0; i < matrices.length; ++i) {
+ Matrix matrix = weightMatrixList.get(i);
+ weightMatrixList.set(i, matrix.plus(matrices[i]));
+ }
+ }
+
+ /**
+ * Set the weight matrices.
+ *
+ * @param matrices The weight matrices, must be the same dimension of the
+ * existing matrices.
+ */
+ public void setWeightMatrices(Matrix[] matrices) {
+ weightMatrixList = Lists.newArrayList();
+ Collections.addAll(weightMatrixList, matrices);
+ }
+
+ /**
+ * Set the weight matrix for a specified layer.
+ *
+ * @param index The index of the matrix, starting from 0 (between layer 0 and 1).
+ * @param matrix The instance of {@link Matrix}.
+ */
+ public void setWeightMatrix(int index, Matrix matrix) {
+ Preconditions.checkArgument(0 <= index && index < weightMatrixList.size(),
+ String.format("index [%s] should be in range [%s, %s).", index, 0, weightMatrixList.size()));
+ weightMatrixList.set(index, matrix);
+ }
+
+ /**
+ * Get all the weight matrices.
+ *
+ * @return The weight matrices.
+ */
+ public Matrix[] getWeightMatrices() {
+ Matrix[] matrices = new Matrix[weightMatrixList.size()];
+ weightMatrixList.toArray(matrices);
+ return matrices;
+ }
+
+ /**
+ * Get the output calculated by the model.
+ *
+ * @param instance The feature instance in form of {@link Vector}, each dimension contains the value of the corresponding feature.
+ * @return The output vector.
+ */
+ public Vector getOutput(Vector instance) {
+ Preconditions.checkArgument(layerSizeList.get(0) == instance.size() + 1,
+ String.format("The dimension of input instance should be %d, but the input has dimension %d.",
+ layerSizeList.get(0) - 1, instance.size()));
+
+ // add bias feature
+ Vector instanceWithBias = new DenseVector(instance.size() + 1);
+ // set bias to be a little bit less than 1.0
+ instanceWithBias.set(0, 0.99999);
+ for (int i = 1; i < instanceWithBias.size(); ++i) {
+ instanceWithBias.set(i, instance.get(i - 1));
+ }
+
+ List<Vector> outputCache = getOutputInternal(instanceWithBias);
+ // return the output of the last layer
+ Vector result = outputCache.get(outputCache.size() - 1);
+ // remove bias
+ return result.viewPart(1, result.size() - 1);
+ }
+
+ /**
+ * Calculate output internally, the intermediate output of each layer will be
+ * stored.
+ *
+ * @param instance The feature instance in form of {@link Vector}, each dimension contains the value of the corresponding feature.
+ * @return Cached output of each layer.
+ */
+ protected List<Vector> getOutputInternal(Vector instance) {
+ List<Vector> outputCache = Lists.newArrayList();
+ // fill with instance
+ Vector intermediateOutput = instance;
+ outputCache.add(intermediateOutput);
+
+ for (int i = 0; i < layerSizeList.size() - 1; ++i) {
+ intermediateOutput = forward(i, intermediateOutput);
+ outputCache.add(intermediateOutput);
+ }
+ return outputCache;
+ }
+
+ /**
+ * Forward the calculation for one layer.
+ *
+ * @param fromLayer The index of the previous layer.
+ * @param intermediateOutput The intermediate output of previous layer.
+ * @return The intermediate results of the current layer.
+ */
+ protected Vector forward(int fromLayer, Vector intermediateOutput) {
+ Matrix weightMatrix = weightMatrixList.get(fromLayer);
+
+ Vector vec = weightMatrix.times(intermediateOutput);
+ vec = vec.assign(NeuralNetworkFunctions.getDoubleFunction(squashingFunctionList.get(fromLayer)));
+
+ // add bias
+ Vector vecWithBias = new DenseVector(vec.size() + 1);
+ vecWithBias.set(0, 1);
+ for (int i = 0; i < vec.size(); ++i) {
+ vecWithBias.set(i + 1, vec.get(i));
+ }
+ return vecWithBias;
+ }
+
+ /**
+ * Train the neural network incrementally with given training instance.
+ *
+ * @param trainingInstance An training instance, including the features and the label(s). Its dimension must equals
+ * to the size of the input layer (bias neuron excluded) + the size
+ * of the output layer (a.k.a. the dimension of the labels).
+ */
+ public void trainOnline(Vector trainingInstance) {
+ Matrix[] matrices = trainByInstance(trainingInstance);
+ updateWeightMatrices(matrices);
+ }
+
+ /**
+ * Get the updated weights using one training instance.
+ *
+ * @param trainingInstance An training instance, including the features and the label(s). Its dimension must equals
+ * to the size of the input layer (bias neuron excluded) + the size
+ * of the output layer (a.k.a. the dimension of the labels).
+ * @return The update of each weight, in form of {@link Matrix} list.
+ */
+ public Matrix[] trainByInstance(Vector trainingInstance) {
+ // validate training instance
+ int inputDimension = layerSizeList.get(0) - 1;
+ int outputDimension = layerSizeList.get(this.layerSizeList.size() - 1);
+ Preconditions.checkArgument(inputDimension + outputDimension == trainingInstance.size(),
+ String.format("The dimension of training instance is %d, but requires %d.", trainingInstance.size(),
+ inputDimension + outputDimension));
+
+ if (trainingMethod.equals(TrainingMethod.GRADIENT_DESCENT)) {
+ return trainByInstanceGradientDescent(trainingInstance);
+ }
+ throw new IllegalArgumentException("Training method is not supported.");
+ }
+
+ /**
+ * Train by gradient descent. Get the updated weights using one training
+ * instance.
+ *
+ * @param trainingInstance An training instance, including the features and the label(s). Its dimension must equals
+ * to the size of the input layer (bias neuron excluded) + the size
+ * of the output layer (a.k.a. the dimension of the labels).
+ * @return The weight update matrices.
+ */
+ private Matrix[] trainByInstanceGradientDescent(Vector trainingInstance) {
+ int inputDimension = layerSizeList.get(0) - 1;
+
+ Vector inputInstance = new DenseVector(layerSizeList.get(0));
+ inputInstance.set(0, 1); // add bias
+ for (int i = 0; i < inputDimension; ++i) {
+ inputInstance.set(i + 1, trainingInstance.get(i));
+ }
+
+ Vector labels =
+ trainingInstance.viewPart(inputInstance.size() - 1, trainingInstance.size() - inputInstance.size() + 1);
+
+ // initialize weight update matrices
+ Matrix[] weightUpdateMatrices = new Matrix[weightMatrixList.size()];
+ for (int m = 0; m < weightUpdateMatrices.length; ++m) {
+ weightUpdateMatrices[m] =
+ new DenseMatrix(weightMatrixList.get(m).rowSize(), weightMatrixList.get(m).columnSize());
+ }
+
+ List<Vector> internalResults = getOutputInternal(inputInstance);
+
+ Vector deltaVec = new DenseVector(layerSizeList.get(layerSizeList.size() - 1));
+ Vector output = internalResults.get(internalResults.size() - 1);
+
+ final DoubleFunction derivativeSquashingFunction =
+ NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(squashingFunctionList.size() - 1));
+
+ final DoubleDoubleFunction costFunction =
+ NeuralNetworkFunctions.getDerivativeDoubleDoubleFunction(costFunctionName);
+
+ Matrix lastWeightMatrix = weightMatrixList.get(weightMatrixList.size() - 1);
+
+ for (int i = 0; i < deltaVec.size(); ++i) {
+ double costFuncDerivative = costFunction.apply(labels.get(i), output.get(i + 1));
+ // Add regularization
+ costFuncDerivative += regularizationWeight * lastWeightMatrix.viewRow(i).zSum();
+ deltaVec.set(i, costFuncDerivative);
+ deltaVec.set(i, deltaVec.get(i) * derivativeSquashingFunction.apply(output.get(i + 1)));
+ }
+
+ // Start from previous layer of output layer
+ for (int layer = layerSizeList.size() - 2; layer >= 0; --layer) {
+ deltaVec = backPropagate(layer, deltaVec, internalResults, weightUpdateMatrices[layer]);
+ }
+
+ prevWeightUpdatesList = Arrays.asList(weightUpdateMatrices);
+
+ return weightUpdateMatrices;
+ }
+
+ /**
+ * Back-propagate the errors to from next layer to current layer. The weight
+ * updated information will be stored in the weightUpdateMatrices, and the
+ * delta of the prevLayer will be returned.
+ *
+ * @param currentLayerIndex Index of current layer.
+ * @param nextLayerDelta Delta of next layer.
+ * @param outputCache The output cache to store intermediate results.
+ * @param weightUpdateMatrix The weight update, in form of {@link Matrix}.
+ * @return The weight updates.
+ */
+ private Vector backPropagate(int currentLayerIndex, Vector nextLayerDelta,
+ List<Vector> outputCache, Matrix weightUpdateMatrix) {
+
+ // Get layer related information
+ final DoubleFunction derivativeSquashingFunction =
+ NeuralNetworkFunctions.getDerivativeDoubleFunction(squashingFunctionList.get(currentLayerIndex));
+ Vector curLayerOutput = outputCache.get(currentLayerIndex);
+ Matrix weightMatrix = weightMatrixList.get(currentLayerIndex);
+ Matrix prevWeightMatrix = prevWeightUpdatesList.get(currentLayerIndex);
+
+ // Next layer is not output layer, remove the delta of bias neuron
+ if (currentLayerIndex != layerSizeList.size() - 2) {
+ nextLayerDelta = nextLayerDelta.viewPart(1, nextLayerDelta.size() - 1);
+ }
+
+ Vector delta = weightMatrix.transpose().times(nextLayerDelta);
+
+ delta = delta.assign(curLayerOutput, new DoubleDoubleFunction() {
+ @Override
+ public double apply(double deltaElem, double curLayerOutputElem) {
+ return deltaElem * derivativeSquashingFunction.apply(curLayerOutputElem);
+ }
+ });
+
+ // Update weights
+ for (int i = 0; i < weightUpdateMatrix.rowSize(); ++i) {
+ for (int j = 0; j < weightUpdateMatrix.columnSize(); ++j) {
+ weightUpdateMatrix.set(i, j, -learningRate * nextLayerDelta.get(i) *
+ curLayerOutput.get(j) + this.momentumWeight * prevWeightMatrix.get(i, j));
+ }
+ }
+
+ return delta;
+ }
+
+ /**
+ * Read the model meta-data from the specified location.
+ *
+ * @throws IOException
+ */
+ protected void readFromModel() throws IOException {
+ log.info("Load model from {}", modelPath);
+ Preconditions.checkArgument(modelPath != null, "Model path has not been set.");
+ FSDataInputStream is = null;
+ try {
+ Path path = new Path(modelPath);
+ FileSystem fs = path.getFileSystem(new Configuration());
+ is = new FSDataInputStream(fs.open(path));
+ readFields(is);
+ } finally {
+ Closeables.close(is, true);
+ }
+ }
+
+ /**
+ * Write the model data to specified location.
+ *
+ * @throws IOException
+ */
+ public void writeModelToFile() throws IOException {
+ log.info("Write model to {}.", modelPath);
+ Preconditions.checkArgument(modelPath != null, "Model path has not been set.");
+ FSDataOutputStream stream = null;
+ try {
+ Path path = new Path(modelPath);
+ FileSystem fs = path.getFileSystem(new Configuration());
+ stream = fs.create(path, true);
+ write(stream);
+ } finally {
+ Closeables.close(stream, false);
+ }
+ }
+
+ /**
+ * Set the model path.
+ *
+ * @param modelPath The path of the model.
+ */
+ public void setModelPath(String modelPath) {
+ this.modelPath = modelPath;
+ }
+
+ /**
+ * Get the model path.
+ *
+ * @return The path of the model.
+ */
+ public String getModelPath() {
+ return modelPath;
+ }
+
+ /**
+ * Write the fields of the model to output.
+ *
+ * @param output The output instance.
+ * @throws IOException
+ */
+ public void write(DataOutput output) throws IOException {
+ // Write model type
+ WritableUtils.writeString(output, modelType);
+ // Write learning rate
+ output.writeDouble(learningRate);
+ // Write model path
+ if (modelPath != null) {
+ WritableUtils.writeString(output, modelPath);
+ } else {
+ WritableUtils.writeString(output, "null");
+ }
+
+ // Write regularization weight
+ output.writeDouble(regularizationWeight);
+ // Write momentum weight
+ output.writeDouble(momentumWeight);
+
+ // Write cost function
+ WritableUtils.writeString(output, costFunctionName);
+
+ // Write layer size list
+ output.writeInt(layerSizeList.size());
+ for (Integer aLayerSizeList : layerSizeList) {
+ output.writeInt(aLayerSizeList);
+ }
+
+ WritableUtils.writeEnum(output, trainingMethod);
+
+ // Write squashing functions
+ output.writeInt(squashingFunctionList.size());
+ for (String aSquashingFunctionList : squashingFunctionList) {
+ WritableUtils.writeString(output, aSquashingFunctionList);
+ }
+
+ // Write weight matrices
+ output.writeInt(this.weightMatrixList.size());
+ for (Matrix aWeightMatrixList : weightMatrixList) {
+ MatrixWritable.writeMatrix(output, aWeightMatrixList);
+ }
+ }
+
+ /**
+ * Read the fields of the model from input.
+ *
+ * @param input The input instance.
+ * @throws IOException
+ */
+ public void readFields(DataInput input) throws IOException {
+ // Read model type
+ modelType = WritableUtils.readString(input);
+ if (!modelType.equals(this.getClass().getSimpleName())) {
+ throw new IllegalArgumentException("The specified location does not contains the valid NeuralNetwork model.");
+ }
+ // Read learning rate
+ learningRate = input.readDouble();
+ // Read model path
+ modelPath = WritableUtils.readString(input);
+ if (modelPath.equals("null")) {
+ modelPath = null;
+ }
+
+ // Read regularization weight
+ regularizationWeight = input.readDouble();
+ // Read momentum weight
+ momentumWeight = input.readDouble();
+
+ // Read cost function
+ costFunctionName = WritableUtils.readString(input);
+
+ // Read layer size list
+ int numLayers = input.readInt();
+ layerSizeList = Lists.newArrayList();
+ for (int i = 0; i < numLayers; i++) {
+ layerSizeList.add(input.readInt());
+ }
+
+ trainingMethod = WritableUtils.readEnum(input, TrainingMethod.class);
+
+ // Read squash functions
+ int squashingFunctionSize = input.readInt();
+ squashingFunctionList = Lists.newArrayList();
+ for (int i = 0; i < squashingFunctionSize; i++) {
+ squashingFunctionList.add(WritableUtils.readString(input));
+ }
+
+ // Read weights and construct matrices of previous updates
+ int numOfMatrices = input.readInt();
+ weightMatrixList = Lists.newArrayList();
+ prevWeightUpdatesList = Lists.newArrayList();
+ for (int i = 0; i < numOfMatrices; i++) {
+ Matrix matrix = MatrixWritable.readMatrix(input);
+ weightMatrixList.add(matrix);
+ prevWeightUpdatesList.add(new DenseMatrix(matrix.rowSize(), matrix.columnSize()));
+ }
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
new file mode 100644
index 0000000..8fd0176
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/NeuralNetworkFunctions.java
@@ -0,0 +1,150 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * The functions that will be used by NeuralNetwork.
+ */
+public class NeuralNetworkFunctions {
+
+ /**
+ * The derivation of identity function (f(x) = x).
+ */
+ public static DoubleFunction derivativeIdentityFunction = new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return 1;
+ }
+ };
+
+ /**
+ * The derivation of minus squared function (f(t, o) = (o - t)^2).
+ */
+ public static DoubleDoubleFunction derivativeMinusSquared = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double target, double output) {
+ return 2 * (output - target);
+ }
+ };
+
+ /**
+ * The cross entropy function (f(t, o) = -t * log(o) - (1 - t) * log(1 - o)).
+ */
+ public static DoubleDoubleFunction crossEntropy = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double target, double output) {
+ return -target * Math.log(output) - (1 - target) * Math.log(1 - output);
+ }
+ };
+
+ /**
+ * The derivation of cross entropy function (f(t, o) = -t * log(o) - (1 - t) *
+ * log(1 - o)).
+ */
+ public static DoubleDoubleFunction derivativeCrossEntropy = new DoubleDoubleFunction() {
+ @Override
+ public double apply(double target, double output) {
+ double adjustedTarget = target;
+ double adjustedActual = output;
+ if (adjustedActual == 1) {
+ adjustedActual = 0.999;
+ } else if (output == 0) {
+ adjustedActual = 0.001;
+ }
+ if (adjustedTarget == 1) {
+ adjustedTarget = 0.999;
+ } else if (adjustedTarget == 0) {
+ adjustedTarget = 0.001;
+ }
+ return -adjustedTarget / adjustedActual + (1 - adjustedTarget) / (1 - adjustedActual);
+ }
+ };
+
+ /**
+ * Get the corresponding function by its name.
+ * Currently supports: "Identity", "Sigmoid".
+ *
+ * @param function The name of the function.
+ * @return The corresponding double function.
+ */
+ public static DoubleFunction getDoubleFunction(String function) {
+ if (function.equalsIgnoreCase("Identity")) {
+ return Functions.IDENTITY;
+ } else if (function.equalsIgnoreCase("Sigmoid")) {
+ return Functions.SIGMOID;
+ } else {
+ throw new IllegalArgumentException("Function not supported.");
+ }
+ }
+
+ /**
+ * Get the derivation double function by the name.
+ * Currently supports: "Identity", "Sigmoid".
+ *
+ * @param function The name of the function.
+ * @return The double function.
+ */
+ public static DoubleFunction getDerivativeDoubleFunction(String function) {
+ if (function.equalsIgnoreCase("Identity")) {
+ return derivativeIdentityFunction;
+ } else if (function.equalsIgnoreCase("Sigmoid")) {
+ return Functions.SIGMOIDGRADIENT;
+ } else {
+ throw new IllegalArgumentException("Function not supported.");
+ }
+ }
+
+ /**
+ * Get the corresponding double-double function by the name.
+ * Currently supports: "Minus_Squared", "Cross_Entropy".
+ *
+ * @param function The name of the function.
+ * @return The double-double function.
+ */
+ public static DoubleDoubleFunction getDoubleDoubleFunction(String function) {
+ if (function.equalsIgnoreCase("Minus_Squared")) {
+ return Functions.MINUS_SQUARED;
+ } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+ return derivativeCrossEntropy;
+ } else {
+ throw new IllegalArgumentException("Function not supported.");
+ }
+ }
+
+ /**
+ * Get the corresponding derivation of double double function by the name.
+ * Currently supports: "Minus_Squared", "Cross_Entropy".
+ *
+ * @param function The name of the function.
+ * @return The double-double-function.
+ */
+ public static DoubleDoubleFunction getDerivativeDoubleDoubleFunction(String function) {
+ if (function.equalsIgnoreCase("Minus_Squared")) {
+ return derivativeMinusSquared;
+ } else if (function.equalsIgnoreCase("Cross_Entropy")) {
+ return derivativeCrossEntropy;
+ } else {
+ throw new IllegalArgumentException("Function not supported.");
+ }
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java
new file mode 100644
index 0000000..36d6792
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptron.java
@@ -0,0 +1,227 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.mlp;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.commons.csv.CSVUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/** Run {@link MultilayerPerceptron} classification. */
+public class RunMultilayerPerceptron {
+
+ private static final Logger log = LoggerFactory.getLogger(RunMultilayerPerceptron.class);
+
+ static class Parameters {
+ String inputFilePathStr;
+ String inputFileFormat;
+ String modelFilePathStr;
+ String outputFilePathStr;
+ int columnStart;
+ int columnEnd;
+ boolean skipHeader;
+ }
+
+ public static void main(String[] args) throws Exception {
+
+ Parameters parameters = new Parameters();
+
+ if (parseArgs(args, parameters)) {
+ log.info("Load model from {}.", parameters.modelFilePathStr);
+ MultilayerPerceptron mlp = new MultilayerPerceptron(parameters.modelFilePathStr);
+
+ log.info("Topology of MLP: {}.", Arrays.toString(mlp.getLayerSizeList().toArray()));
+
+ // validate the data
+ log.info("Read the data...");
+ Path inputFilePath = new Path(parameters.inputFilePathStr);
+ FileSystem inputFS = inputFilePath.getFileSystem(new Configuration());
+ if (!inputFS.exists(inputFilePath)) {
+ log.error("Input file '{}' does not exists!", parameters.inputFilePathStr);
+ mlp.close();
+ return;
+ }
+
+ Path outputFilePath = new Path(parameters.outputFilePathStr);
+ FileSystem outputFS = inputFilePath.getFileSystem(new Configuration());
+ if (outputFS.exists(outputFilePath)) {
+ log.error("Output file '{}' already exists!", parameters.outputFilePathStr);
+ mlp.close();
+ return;
+ }
+
+ if (!parameters.inputFileFormat.equals("csv")) {
+ log.error("Currently only supports for csv format.");
+ mlp.close();
+ return; // current only supports csv format
+ }
+
+ log.info("Read from column {} to column {}.", parameters.columnStart, parameters.columnEnd);
+
+ BufferedWriter writer = null;
+ BufferedReader reader = null;
+
+ try {
+ writer = new BufferedWriter(new OutputStreamWriter(outputFS.create(outputFilePath)));
+ reader = new BufferedReader(new InputStreamReader(inputFS.open(inputFilePath)));
+
+ String line;
+
+ if (parameters.skipHeader) {
+ reader.readLine();
+ }
+
+ while ((line = reader.readLine()) != null) {
+ String[] tokens = CSVUtils.parseLine(line);
+ double[] features = new double[Math.min(parameters.columnEnd, tokens.length) - parameters.columnStart + 1];
+
+ for (int i = parameters.columnStart, j = 0; i < Math.min(parameters.columnEnd + 1, tokens.length); ++i, ++j) {
+ features[j] = Double.parseDouble(tokens[i]);
+ }
+ Vector featureVec = new DenseVector(features);
+ Vector res = mlp.getOutput(featureVec);
+ int mostProbablyLabelIndex = res.maxValueIndex();
+ writer.write(String.valueOf(mostProbablyLabelIndex));
+ }
+ mlp.close();
+ log.info("Labeling finished.");
+ } finally {
+ Closeables.close(reader, true);
+ Closeables.close(writer, true);
+ }
+ }
+ }
+
+ /**
+ * Parse the arguments.
+ *
+ * @param args The input arguments.
+ * @param parameters The parameters need to be filled.
+ * @return true or false
+ * @throws Exception
+ */
+ private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
+ // build the options
+ log.info("Validate and parse arguments...");
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ GroupBuilder groupBuilder = new GroupBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ Option inputFileFormatOption = optionBuilder
+ .withLongName("format")
+ .withShortName("f")
+ .withArgument(argumentBuilder.withName("file type").withDefault("csv").withMinimum(1).withMaximum(1).create())
+ .withDescription("type of input file, currently support 'csv'")
+ .create();
+
+ List<Integer> columnRangeDefault = Lists.newArrayList();
+ columnRangeDefault.add(0);
+ columnRangeDefault.add(Integer.MAX_VALUE);
+
+ Option skipHeaderOption = optionBuilder.withLongName("skipHeader")
+ .withShortName("sh").withRequired(false)
+ .withDescription("whether to skip the first row of the input file")
+ .create();
+
+ Option inputColumnRangeOption = optionBuilder
+ .withLongName("columnRange")
+ .withShortName("cr")
+ .withDescription("the column range of the input file, start from 0")
+ .withArgument(argumentBuilder.withName("range").withMinimum(2).withMaximum(2)
+ .withDefaults(columnRangeDefault).create()).create();
+
+ Group inputFileTypeGroup = groupBuilder.withOption(skipHeaderOption)
+ .withOption(inputColumnRangeOption).withOption(inputFileFormatOption)
+ .create();
+
+ Option inputOption = optionBuilder
+ .withLongName("input")
+ .withShortName("i")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("file path").withMinimum(1).withMaximum(1).create())
+ .withDescription("the file path of unlabelled dataset")
+ .withChildren(inputFileTypeGroup).create();
+
+ Option modelOption = optionBuilder
+ .withLongName("model")
+ .withShortName("mo")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("model file").withMinimum(1).withMaximum(1).create())
+ .withDescription("the file path of the model").create();
+
+ Option labelsOption = optionBuilder
+ .withLongName("labels")
+ .withShortName("labels")
+ .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
+ .withDescription("an ordered list of label names").create();
+
+ Group labelsGroup = groupBuilder.withOption(labelsOption).create();
+
+ Option outputOption = optionBuilder
+ .withLongName("output")
+ .withShortName("o")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withConsumeRemaining("file path").withMinimum(1).withMaximum(1).create())
+ .withDescription("the file path of labelled results").withChildren(labelsGroup).create();
+
+ // parse the input
+ Parser parser = new Parser();
+ Group normalOption = groupBuilder.withOption(inputOption).withOption(modelOption).withOption(outputOption).create();
+ parser.setGroup(normalOption);
+ CommandLine commandLine = parser.parseAndHelp(args);
+ if (commandLine == null) {
+ return false;
+ }
+
+ // obtain the arguments
+ parameters.inputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, inputOption);
+ parameters.inputFileFormat = TrainMultilayerPerceptron.getString(commandLine, inputFileFormatOption);
+ parameters.skipHeader = commandLine.hasOption(skipHeaderOption);
+ parameters.modelFilePathStr = TrainMultilayerPerceptron.getString(commandLine, modelOption);
+ parameters.outputFilePathStr = TrainMultilayerPerceptron.getString(commandLine, outputOption);
+
+ List<?> columnRange = commandLine.getValues(inputColumnRangeOption);
+ parameters.columnStart = Integer.parseInt(columnRange.get(0).toString());
+ parameters.columnEnd = Integer.parseInt(columnRange.get(1).toString());
+
+ return true;
+ }
+
+}
[15/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/HighDFWordsPruner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/HighDFWordsPruner.java b/mr/src/main/java/org/apache/mahout/vectorizer/HighDFWordsPruner.java
new file mode 100644
index 0000000..c3813c3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/HighDFWordsPruner.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.common.PartialVectorMerger;
+import org.apache.mahout.vectorizer.pruner.PrunedPartialVectorMergeReducer;
+import org.apache.mahout.vectorizer.pruner.WordsPrunerReducer;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+public final class HighDFWordsPruner {
+
+ public static final String STD_CALC_DIR = "stdcalc";
+ public static final String MAX_DF = "max.df";
+ public static final String MIN_DF = "min.df";
+
+ private HighDFWordsPruner() {
+ }
+
+ public static void pruneVectors(Path tfDir, Path prunedTFDir, Path prunedPartialTFDir, long maxDF,
+ long minDF, Configuration baseConf,
+ Pair<Long[], List<Path>> docFrequenciesFeatures,
+ float normPower,
+ boolean logNormalize,
+ int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+
+ int partialVectorIndex = 0;
+ List<Path> partialVectorPaths = new ArrayList<>();
+ for (Path path : docFrequenciesFeatures.getSecond()) {
+ Path partialVectorOutputPath = new Path(prunedPartialTFDir, "partial-" + partialVectorIndex++);
+ partialVectorPaths.add(partialVectorOutputPath);
+ pruneVectorsPartial(tfDir, partialVectorOutputPath, path, maxDF, minDF, baseConf);
+ }
+
+ mergePartialVectors(partialVectorPaths, prunedTFDir, baseConf, normPower, logNormalize, numReducers);
+ HadoopUtil.delete(new Configuration(baseConf), prunedPartialTFDir);
+ }
+
+ private static void pruneVectorsPartial(Path input, Path output, Path dictionaryFilePath, long maxDF,
+ long minDF, Configuration baseConf) throws IOException, InterruptedException,
+ ClassNotFoundException {
+
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf
+ // values
+ conf.set("io.serializations",
+ "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ conf.setLong(MAX_DF, maxDF);
+ conf.setLong(MIN_DF, minDF);
+ DistributedCache.addCacheFile(dictionaryFilePath.toUri(), conf);
+
+ Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class,
+ Mapper.class, null, null, WordsPrunerReducer.class,
+ Text.class, VectorWritable.class, SequenceFileOutputFormat.class,
+ conf);
+ job.setJobName(": Prune Vectors: input-folder: " + input
+ + ", dictionary-file: " + dictionaryFilePath.toString());
+
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ public static void mergePartialVectors(Iterable<Path> partialVectorPaths,
+ Path output,
+ Configuration baseConf,
+ float normPower,
+ boolean logNormalize,
+ int numReducers)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf values
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ conf.setFloat(PartialVectorMerger.NORMALIZATION_POWER, normPower);
+ conf.setBoolean(PartialVectorMerger.LOG_NORMALIZE, logNormalize);
+
+ Job job = new Job(conf);
+ job.setJobName("PrunerPartialVectorMerger::MergePartialVectors");
+ job.setJarByClass(PartialVectorMerger.class);
+
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ FileInputFormat.setInputPaths(job, getCommaSeparatedPaths(partialVectorPaths));
+
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setMapperClass(Mapper.class);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setReducerClass(PrunedPartialVectorMergeReducer.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setNumReduceTasks(numReducers);
+
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ private static String getCommaSeparatedPaths(Iterable<Path> paths) {
+ StringBuilder commaSeparatedPaths = new StringBuilder(100);
+ String sep = "";
+ for (Path path : paths) {
+ commaSeparatedPaths.append(sep).append(path.toString());
+ sep = ",";
+ }
+ return commaSeparatedPaths.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/SimpleTextEncodingVectorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/SimpleTextEncodingVectorizer.java b/mr/src/main/java/org/apache/mahout/vectorizer/SimpleTextEncodingVectorizer.java
new file mode 100644
index 0000000..e6339a1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/SimpleTextEncodingVectorizer.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * <p>Runs a Map/Reduce job that encodes {@link org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder} the
+ * input and writes it to the output as a sequence file.</p>
+ *
+ * <p>Only works on basic text, where the value in the SequenceFile is a blob of text.</p>
+ */
+//TODO: find commonalities w/ DictionaryVectorizer and abstract them out
+public class SimpleTextEncodingVectorizer implements Vectorizer {
+
+ private static final Logger log = LoggerFactory.getLogger(SimpleTextEncodingVectorizer.class);
+
+ @Override
+ public void createVectors(Path input, Path output, VectorizerConfig config)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ //do this for convenience of using prepareJob
+ Job job = HadoopUtil.prepareJob(input, output,
+ SequenceFileInputFormat.class,
+ EncodingMapper.class,
+ Text.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class,
+ config.getConf());
+ Configuration conf = job.getConfiguration();
+ conf.set(EncodingMapper.USE_SEQUENTIAL, String.valueOf(config.isSequentialAccess()));
+ conf.set(EncodingMapper.USE_NAMED_VECTORS, String.valueOf(config.isNamedVectors()));
+ conf.set(EncodingMapper.ANALYZER_NAME, config.getAnalyzerClassName());
+ conf.set(EncodingMapper.ENCODER_FIELD_NAME, config.getEncoderName());
+ conf.set(EncodingMapper.ENCODER_CLASS, config.getEncoderClass());
+ conf.set(EncodingMapper.CARDINALITY, String.valueOf(config.getCardinality()));
+ job.setNumReduceTasks(0);
+ boolean finished = job.waitForCompletion(true);
+
+ log.info("result of run: {}", finished);
+ if (!finished) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java b/mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java
new file mode 100644
index 0000000..ee56124
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java
@@ -0,0 +1,369 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.lucene.AnalyzerUtils;
+import org.apache.mahout.math.hadoop.stats.BasicStats;
+import org.apache.mahout.vectorizer.collocations.llr.LLRReducer;
+import org.apache.mahout.vectorizer.common.PartialVectorMerger;
+import org.apache.mahout.vectorizer.tfidf.TFIDFConverter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+
+/**
+ * Converts a given set of sequence files into SparseVectors
+ */
+public final class SparseVectorsFromSequenceFiles extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(SparseVectorsFromSequenceFiles.class);
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new SparseVectorsFromSequenceFiles(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option inputDirOpt = DefaultOptionCreator.inputOption().create();
+
+ Option outputDirOpt = DefaultOptionCreator.outputOption().create();
+
+ Option minSupportOpt = obuilder.withLongName("minSupport").withArgument(
+ abuilder.withName("minSupport").withMinimum(1).withMaximum(1).create()).withDescription(
+ "(Optional) Minimum Support. Default Value: 2").withShortName("s").create();
+
+ Option analyzerNameOpt = obuilder.withLongName("analyzerName").withArgument(
+ abuilder.withName("analyzerName").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The class name of the analyzer").withShortName("a").create();
+
+ Option chunkSizeOpt = obuilder.withLongName("chunkSize").withArgument(
+ abuilder.withName("chunkSize").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The chunkSize in MegaBytes. Default Value: 100MB").withShortName("chunk").create();
+
+ Option weightOpt = obuilder.withLongName("weight").withRequired(false).withArgument(
+ abuilder.withName("weight").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The kind of weight to use. Currently TF or TFIDF. Default: TFIDF").withShortName("wt").create();
+
+ Option minDFOpt = obuilder.withLongName("minDF").withRequired(false).withArgument(
+ abuilder.withName("minDF").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The minimum document frequency. Default is 1").withShortName("md").create();
+
+ Option maxDFPercentOpt = obuilder.withLongName("maxDFPercent").withRequired(false).withArgument(
+ abuilder.withName("maxDFPercent").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The max percentage of docs for the DF. Can be used to remove really high frequency terms."
+ + " Expressed as an integer between 0 and 100. Default is 99. If maxDFSigma is also set, "
+ + "it will override this value.").withShortName("x").create();
+
+ Option maxDFSigmaOpt = obuilder.withLongName("maxDFSigma").withRequired(false).withArgument(
+ abuilder.withName("maxDFSigma").withMinimum(1).withMaximum(1).create()).withDescription(
+ "What portion of the tf (tf-idf) vectors to be used, expressed in times the standard deviation (sigma) "
+ + "of the document frequencies of these vectors. Can be used to remove really high frequency terms."
+ + " Expressed as a double value. Good value to be specified is 3.0. In case the value is less "
+ + "than 0 no vectors will be filtered out. Default is -1.0. Overrides maxDFPercent")
+ .withShortName("xs").create();
+
+ Option minLLROpt = obuilder.withLongName("minLLR").withRequired(false).withArgument(
+ abuilder.withName("minLLR").withMinimum(1).withMaximum(1).create()).withDescription(
+ "(Optional)The minimum Log Likelihood Ratio(Float) Default is " + LLRReducer.DEFAULT_MIN_LLR)
+ .withShortName("ml").create();
+
+ Option numReduceTasksOpt = obuilder.withLongName("numReducers").withArgument(
+ abuilder.withName("numReducers").withMinimum(1).withMaximum(1).create()).withDescription(
+ "(Optional) Number of reduce tasks. Default Value: 1").withShortName("nr").create();
+
+ Option powerOpt = obuilder.withLongName("norm").withRequired(false).withArgument(
+ abuilder.withName("norm").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The norm to use, expressed as either a float or \"INF\" if you want to use the Infinite norm. "
+ + "Must be greater or equal to 0. The default is not to normalize").withShortName("n").create();
+
+ Option logNormalizeOpt = obuilder.withLongName("logNormalize").withRequired(false)
+ .withDescription(
+ "(Optional) Whether output vectors should be logNormalize. If set true else false")
+ .withShortName("lnorm").create();
+
+ Option maxNGramSizeOpt = obuilder.withLongName("maxNGramSize").withRequired(false).withArgument(
+ abuilder.withName("ngramSize").withMinimum(1).withMaximum(1).create())
+ .withDescription(
+ "(Optional) The maximum size of ngrams to create"
+ + " (2 = bigrams, 3 = trigrams, etc) Default Value:1").withShortName("ng").create();
+
+ Option sequentialAccessVectorOpt = obuilder.withLongName("sequentialAccessVector").withRequired(false)
+ .withDescription(
+ "(Optional) Whether output vectors should be SequentialAccessVectors. If set true else false")
+ .withShortName("seq").create();
+
+ Option namedVectorOpt = obuilder.withLongName("namedVector").withRequired(false)
+ .withDescription(
+ "(Optional) Whether output vectors should be NamedVectors. If set true else false")
+ .withShortName("nv").create();
+
+ Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
+ "If set, overwrite the output directory").withShortName("ow").create();
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(minSupportOpt).withOption(analyzerNameOpt)
+ .withOption(chunkSizeOpt).withOption(outputDirOpt).withOption(inputDirOpt).withOption(minDFOpt)
+ .withOption(maxDFSigmaOpt).withOption(maxDFPercentOpt).withOption(weightOpt).withOption(powerOpt)
+ .withOption(minLLROpt).withOption(numReduceTasksOpt).withOption(maxNGramSizeOpt).withOption(overwriteOutput)
+ .withOption(helpOpt).withOption(sequentialAccessVectorOpt).withOption(namedVectorOpt)
+ .withOption(logNormalizeOpt)
+ .create();
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ parser.setHelpOption(helpOpt);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return -1;
+ }
+
+ Path inputDir = new Path((String) cmdLine.getValue(inputDirOpt));
+ Path outputDir = new Path((String) cmdLine.getValue(outputDirOpt));
+
+ int chunkSize = 100;
+ if (cmdLine.hasOption(chunkSizeOpt)) {
+ chunkSize = Integer.parseInt((String) cmdLine.getValue(chunkSizeOpt));
+ }
+ int minSupport = 2;
+ if (cmdLine.hasOption(minSupportOpt)) {
+ String minSupportString = (String) cmdLine.getValue(minSupportOpt);
+ minSupport = Integer.parseInt(minSupportString);
+ }
+
+ int maxNGramSize = 1;
+
+ if (cmdLine.hasOption(maxNGramSizeOpt)) {
+ try {
+ maxNGramSize = Integer.parseInt(cmdLine.getValue(maxNGramSizeOpt).toString());
+ } catch (NumberFormatException ex) {
+ log.warn("Could not parse ngram size option");
+ }
+ }
+ log.info("Maximum n-gram size is: {}", maxNGramSize);
+
+ if (cmdLine.hasOption(overwriteOutput)) {
+ HadoopUtil.delete(getConf(), outputDir);
+ }
+
+ float minLLRValue = LLRReducer.DEFAULT_MIN_LLR;
+ if (cmdLine.hasOption(minLLROpt)) {
+ minLLRValue = Float.parseFloat(cmdLine.getValue(minLLROpt).toString());
+ }
+ log.info("Minimum LLR value: {}", minLLRValue);
+
+ int reduceTasks = 1;
+ if (cmdLine.hasOption(numReduceTasksOpt)) {
+ reduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString());
+ }
+ log.info("Number of reduce tasks: {}", reduceTasks);
+
+ Class<? extends Analyzer> analyzerClass = StandardAnalyzer.class;
+ if (cmdLine.hasOption(analyzerNameOpt)) {
+ String className = cmdLine.getValue(analyzerNameOpt).toString();
+ analyzerClass = Class.forName(className).asSubclass(Analyzer.class);
+ // try instantiating it, b/c there isn't any point in setting it if
+ // you can't instantiate it
+ AnalyzerUtils.createAnalyzer(analyzerClass);
+ }
+
+ boolean processIdf;
+
+ if (cmdLine.hasOption(weightOpt)) {
+ String wString = cmdLine.getValue(weightOpt).toString();
+ if ("tf".equalsIgnoreCase(wString)) {
+ processIdf = false;
+ } else if ("tfidf".equalsIgnoreCase(wString)) {
+ processIdf = true;
+ } else {
+ throw new OptionException(weightOpt);
+ }
+ } else {
+ processIdf = true;
+ }
+
+ int minDf = 1;
+ if (cmdLine.hasOption(minDFOpt)) {
+ minDf = Integer.parseInt(cmdLine.getValue(minDFOpt).toString());
+ }
+ int maxDFPercent = 99;
+ if (cmdLine.hasOption(maxDFPercentOpt)) {
+ maxDFPercent = Integer.parseInt(cmdLine.getValue(maxDFPercentOpt).toString());
+ }
+ double maxDFSigma = -1.0;
+ if (cmdLine.hasOption(maxDFSigmaOpt)) {
+ maxDFSigma = Double.parseDouble(cmdLine.getValue(maxDFSigmaOpt).toString());
+ }
+
+ float norm = PartialVectorMerger.NO_NORMALIZING;
+ if (cmdLine.hasOption(powerOpt)) {
+ String power = cmdLine.getValue(powerOpt).toString();
+ if ("INF".equals(power)) {
+ norm = Float.POSITIVE_INFINITY;
+ } else {
+ norm = Float.parseFloat(power);
+ }
+ }
+
+ boolean logNormalize = false;
+ if (cmdLine.hasOption(logNormalizeOpt)) {
+ logNormalize = true;
+ }
+ log.info("Tokenizing documents in {}", inputDir);
+ Configuration conf = getConf();
+ Path tokenizedPath = new Path(outputDir, DocumentProcessor.TOKENIZED_DOCUMENT_OUTPUT_FOLDER);
+ //TODO: move this into DictionaryVectorizer , and then fold SparseVectorsFrom with EncodedVectorsFrom
+ // to have one framework for all of this.
+ DocumentProcessor.tokenizeDocuments(inputDir, analyzerClass, tokenizedPath, conf);
+
+ boolean sequentialAccessOutput = false;
+ if (cmdLine.hasOption(sequentialAccessVectorOpt)) {
+ sequentialAccessOutput = true;
+ }
+
+ boolean namedVectors = false;
+ if (cmdLine.hasOption(namedVectorOpt)) {
+ namedVectors = true;
+ }
+ boolean shouldPrune = maxDFSigma >= 0.0 || maxDFPercent > 0.00;
+ String tfDirName = shouldPrune
+ ? DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER + "-toprune"
+ : DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER;
+ log.info("Creating Term Frequency Vectors");
+ if (processIdf) {
+ DictionaryVectorizer.createTermFrequencyVectors(tokenizedPath,
+ outputDir,
+ tfDirName,
+ conf,
+ minSupport,
+ maxNGramSize,
+ minLLRValue,
+ -1.0f,
+ false,
+ reduceTasks,
+ chunkSize,
+ sequentialAccessOutput,
+ namedVectors);
+ } else {
+ DictionaryVectorizer.createTermFrequencyVectors(tokenizedPath,
+ outputDir,
+ tfDirName,
+ conf,
+ minSupport,
+ maxNGramSize,
+ minLLRValue,
+ norm,
+ logNormalize,
+ reduceTasks,
+ chunkSize,
+ sequentialAccessOutput,
+ namedVectors);
+ }
+
+ Pair<Long[], List<Path>> docFrequenciesFeatures = null;
+ // Should document frequency features be processed
+ if (shouldPrune || processIdf) {
+ log.info("Calculating IDF");
+ docFrequenciesFeatures =
+ TFIDFConverter.calculateDF(new Path(outputDir, tfDirName), outputDir, conf, chunkSize);
+ }
+
+ long maxDF = maxDFPercent; //if we are pruning by std dev, then this will get changed
+ if (shouldPrune) {
+ long vectorCount = docFrequenciesFeatures.getFirst()[1];
+ if (maxDFSigma >= 0.0) {
+ Path dfDir = new Path(outputDir, TFIDFConverter.WORDCOUNT_OUTPUT_FOLDER);
+ Path stdCalcDir = new Path(outputDir, HighDFWordsPruner.STD_CALC_DIR);
+
+ // Calculate the standard deviation
+ double stdDev = BasicStats.stdDevForGivenMean(dfDir, stdCalcDir, 0.0, conf);
+ maxDF = (int) (100.0 * maxDFSigma * stdDev / vectorCount);
+ }
+
+ long maxDFThreshold = (long) (vectorCount * (maxDF / 100.0f));
+
+ // Prune the term frequency vectors
+ Path tfDir = new Path(outputDir, tfDirName);
+ Path prunedTFDir = new Path(outputDir, DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER);
+ Path prunedPartialTFDir =
+ new Path(outputDir, DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER + "-partial");
+ log.info("Pruning");
+ if (processIdf) {
+ HighDFWordsPruner.pruneVectors(tfDir,
+ prunedTFDir,
+ prunedPartialTFDir,
+ maxDFThreshold,
+ minDf,
+ conf,
+ docFrequenciesFeatures,
+ -1.0f,
+ false,
+ reduceTasks);
+ } else {
+ HighDFWordsPruner.pruneVectors(tfDir,
+ prunedTFDir,
+ prunedPartialTFDir,
+ maxDFThreshold,
+ minDf,
+ conf,
+ docFrequenciesFeatures,
+ norm,
+ logNormalize,
+ reduceTasks);
+ }
+ HadoopUtil.delete(new Configuration(conf), tfDir);
+ }
+ if (processIdf) {
+ TFIDFConverter.processTfIdf(
+ new Path(outputDir, DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER),
+ outputDir, conf, docFrequenciesFeatures, minDf, maxDF, norm, logNormalize,
+ sequentialAccessOutput, namedVectors, reduceTasks);
+ }
+ } catch (OptionException e) {
+ log.error("Exception", e);
+ CommandLineUtil.printHelp(group);
+ }
+ return 0;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/TF.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/TF.java b/mr/src/main/java/org/apache/mahout/vectorizer/TF.java
new file mode 100644
index 0000000..1818580
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/TF.java
@@ -0,0 +1,30 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+/**
+ * {@link Weight} based on term frequency only
+ */
+public class TF implements Weight {
+
+ @Override
+ public double calculate(int tf, int df, int length, int numDocs) {
+ //ignore length
+ return tf;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java b/mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java
new file mode 100644
index 0000000..0a537eb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import org.apache.lucene.search.similarities.DefaultSimilarity;
+//TODO: add a new class that supports arbitrary Lucene similarity implementations
+public class TFIDF implements Weight {
+
+ private final DefaultSimilarity sim = new DefaultSimilarity();
+
+ @Override
+ public double calculate(int tf, int df, int length, int numDocs) {
+ // ignore length
+ return sim.tf(tf) * sim.idf(df, numDocs);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java b/mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java
new file mode 100644
index 0000000..45f0043
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java
@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import org.apache.hadoop.fs.Path;
+
+import java.io.IOException;
+
+public interface Vectorizer {
+
+ void createVectors(Path input, Path output, VectorizerConfig config)
+ throws IOException, ClassNotFoundException, InterruptedException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java b/mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java
new file mode 100644
index 0000000..edaf2f3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+import org.apache.hadoop.conf.Configuration;
+
+/**
+ * The config for a Vectorizer. Not all implementations need use all variables.
+ */
+public final class VectorizerConfig {
+
+ private Configuration conf;
+ private String analyzerClassName;
+ private String encoderName;
+ private boolean sequentialAccess;
+ private boolean namedVectors;
+ private int cardinality;
+ private String encoderClass;
+ private String tfDirName;
+ private int minSupport;
+ private int maxNGramSize;
+ private float minLLRValue;
+ private float normPower;
+ private boolean logNormalize;
+ private int numReducers;
+ private int chunkSizeInMegabytes;
+
+ public VectorizerConfig(Configuration conf,
+ String analyzerClassName,
+ String encoderClass,
+ String encoderName,
+ boolean sequentialAccess,
+ boolean namedVectors,
+ int cardinality) {
+ this.conf = conf;
+ this.analyzerClassName = analyzerClassName;
+ this.encoderClass = encoderClass;
+ this.encoderName = encoderName;
+ this.sequentialAccess = sequentialAccess;
+ this.namedVectors = namedVectors;
+ this.cardinality = cardinality;
+ }
+
+ public Configuration getConf() {
+ return conf;
+ }
+
+ public void setConf(Configuration conf) {
+ this.conf = conf;
+ }
+
+ public String getAnalyzerClassName() {
+ return analyzerClassName;
+ }
+
+ public void setAnalyzerClassName(String analyzerClassName) {
+ this.analyzerClassName = analyzerClassName;
+ }
+
+ public String getEncoderName() {
+ return encoderName;
+ }
+
+ public void setEncoderName(String encoderName) {
+ this.encoderName = encoderName;
+ }
+
+ public boolean isSequentialAccess() {
+ return sequentialAccess;
+ }
+
+ public void setSequentialAccess(boolean sequentialAccess) {
+ this.sequentialAccess = sequentialAccess;
+ }
+
+
+ public String getTfDirName() {
+ return tfDirName;
+ }
+
+ public void setTfDirName(String tfDirName) {
+ this.tfDirName = tfDirName;
+ }
+
+ public boolean isNamedVectors() {
+ return namedVectors;
+ }
+
+ public void setNamedVectors(boolean namedVectors) {
+ this.namedVectors = namedVectors;
+ }
+
+ public int getCardinality() {
+ return cardinality;
+ }
+
+ public void setCardinality(int cardinality) {
+ this.cardinality = cardinality;
+ }
+
+ public String getEncoderClass() {
+ return encoderClass;
+ }
+
+ public void setEncoderClass(String encoderClass) {
+ this.encoderClass = encoderClass;
+ }
+
+ public int getMinSupport() {
+ return minSupport;
+ }
+
+ public void setMinSupport(int minSupport) {
+ this.minSupport = minSupport;
+ }
+
+ public int getMaxNGramSize() {
+ return maxNGramSize;
+ }
+
+ public void setMaxNGramSize(int maxNGramSize) {
+ this.maxNGramSize = maxNGramSize;
+ }
+
+ public float getMinLLRValue() {
+ return minLLRValue;
+ }
+
+ public void setMinLLRValue(float minLLRValue) {
+ this.minLLRValue = minLLRValue;
+ }
+
+ public float getNormPower() {
+ return normPower;
+ }
+
+ public void setNormPower(float normPower) {
+ this.normPower = normPower;
+ }
+
+ public boolean isLogNormalize() {
+ return logNormalize;
+ }
+
+ public void setLogNormalize(boolean logNormalize) {
+ this.logNormalize = logNormalize;
+ }
+
+ public int getNumReducers() {
+ return numReducers;
+ }
+
+ public void setNumReducers(int numReducers) {
+ this.numReducers = numReducers;
+ }
+
+ public int getChunkSizeInMegabytes() {
+ return chunkSizeInMegabytes;
+ }
+
+ public void setChunkSizeInMegabytes(int chunkSizeInMegabytes) {
+ this.chunkSizeInMegabytes = chunkSizeInMegabytes;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/Weight.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/Weight.java b/mr/src/main/java/org/apache/mahout/vectorizer/Weight.java
new file mode 100644
index 0000000..e36159d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/Weight.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer;
+
+public interface Weight {
+
+ /**
+ * Experimental
+ *
+ * @param tf term freq
+ * @param df doc freq
+ * @param length Length of the document
+ * @param numDocs the total number of docs
+ * @return The weight
+ */
+ double calculate(int tf, int df, int length, int numDocs);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java
new file mode 100644
index 0000000..54cadbd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.IOException;
+
+import org.apache.hadoop.mapreduce.Reducer;
+
+/** Combiner for pass1 of the CollocationDriver. Combines frequencies for values for the same key */
+public class CollocCombiner extends Reducer<GramKey, Gram, GramKey, Gram> {
+
+ @Override
+ protected void reduce(GramKey key, Iterable<Gram> values, Context context) throws IOException, InterruptedException {
+
+ int freq = 0;
+ Gram value = null;
+
+ // accumulate frequencies from values, preserve the last value
+ // to write to the context.
+ for (Gram value1 : values) {
+ value = value1;
+ freq += value.getFrequency();
+ }
+
+ if (value != null) {
+ value.setFrequency(freq);
+ context.write(key, value);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java
new file mode 100644
index 0000000..a07ddbd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.lucene.AnalyzerUtils;
+import org.apache.mahout.vectorizer.DocumentProcessor;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Driver for LLR Collocation discovery mapreduce job */
+public final class CollocDriver extends AbstractJob {
+ //public static final String DEFAULT_OUTPUT_DIRECTORY = "output";
+
+ public static final String SUBGRAM_OUTPUT_DIRECTORY = "subgrams";
+
+ public static final String NGRAM_OUTPUT_DIRECTORY = "ngrams";
+
+ public static final String EMIT_UNIGRAMS = "emit-unigrams";
+
+ public static final boolean DEFAULT_EMIT_UNIGRAMS = false;
+
+ private static final int DEFAULT_MAX_NGRAM_SIZE = 2;
+
+ private static final int DEFAULT_PASS1_NUM_REDUCE_TASKS = 1;
+
+ private static final Logger log = LoggerFactory.getLogger(CollocDriver.class);
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new CollocDriver(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.numReducersOption().create());
+
+ addOption("maxNGramSize",
+ "ng",
+ "(Optional) The max size of ngrams to create (2 = bigrams, 3 = trigrams, etc) default: 2",
+ String.valueOf(DEFAULT_MAX_NGRAM_SIZE));
+ addOption("minSupport", "s", "(Optional) Minimum Support. Default Value: "
+ + CollocReducer.DEFAULT_MIN_SUPPORT, String.valueOf(CollocReducer.DEFAULT_MIN_SUPPORT));
+ addOption("minLLR", "ml", "(Optional)The minimum Log Likelihood Ratio(Float) Default is "
+ + LLRReducer.DEFAULT_MIN_LLR, String.valueOf(LLRReducer.DEFAULT_MIN_LLR));
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption("analyzerName", "a", "The class name of the analyzer to use for preprocessing", null);
+
+ addFlag("preprocess", "p", "If set, input is SequenceFile<Text,Text> where the value is the document, "
+ + " which will be tokenized using the specified analyzer.");
+ addFlag("unigram", "u", "If set, unigrams will be emitted in the final output alongside collocations");
+
+ Map<String, List<String>> argMap = parseArguments(args);
+
+ if (argMap == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+
+ int maxNGramSize = DEFAULT_MAX_NGRAM_SIZE;
+ if (hasOption("maxNGramSize")) {
+ try {
+ maxNGramSize = Integer.parseInt(getOption("maxNGramSize"));
+ } catch (NumberFormatException ex) {
+ log.warn("Could not parse ngram size option");
+ }
+ }
+ log.info("Maximum n-gram size is: {}", maxNGramSize);
+
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+
+ int minSupport = CollocReducer.DEFAULT_MIN_SUPPORT;
+ if (getOption("minSupport") != null) {
+ minSupport = Integer.parseInt(getOption("minSupport"));
+ }
+ log.info("Minimum Support value: {}", minSupport);
+
+ float minLLRValue = LLRReducer.DEFAULT_MIN_LLR;
+ if (getOption("minLLR") != null) {
+ minLLRValue = Float.parseFloat(getOption("minLLR"));
+ }
+ log.info("Minimum LLR value: {}", minLLRValue);
+
+ int reduceTasks = DEFAULT_PASS1_NUM_REDUCE_TASKS;
+ if (getOption("maxRed") != null) {
+ reduceTasks = Integer.parseInt(getOption("maxRed"));
+ }
+ log.info("Number of pass1 reduce tasks: {}", reduceTasks);
+
+ boolean emitUnigrams = argMap.containsKey("emitUnigrams");
+
+ if (argMap.containsKey("preprocess")) {
+ log.info("Input will be preprocessed");
+ Class<? extends Analyzer> analyzerClass = StandardAnalyzer.class;
+ if (getOption("analyzerName") != null) {
+ String className = getOption("analyzerName");
+ analyzerClass = Class.forName(className).asSubclass(Analyzer.class);
+ // try instantiating it, b/c there isn't any point in setting it if
+ // you can't instantiate it
+ AnalyzerUtils.createAnalyzer(analyzerClass);
+ }
+
+ Path tokenizedPath = new Path(output, DocumentProcessor.TOKENIZED_DOCUMENT_OUTPUT_FOLDER);
+
+ DocumentProcessor.tokenizeDocuments(input, analyzerClass, tokenizedPath, getConf());
+ input = tokenizedPath;
+ } else {
+ log.info("Input will NOT be preprocessed");
+ }
+
+ // parse input and extract collocations
+ long ngramCount =
+ generateCollocations(input, output, getConf(), emitUnigrams, maxNGramSize, reduceTasks, minSupport);
+
+ // tally collocations and perform LLR calculation
+ computeNGramsPruneByLLR(output, getConf(), ngramCount, emitUnigrams, minLLRValue, reduceTasks);
+
+ return 0;
+ }
+
+ /**
+ * Generate all ngrams for the {@link org.apache.mahout.vectorizer.DictionaryVectorizer} job
+ *
+ * @param input
+ * input path containing tokenized documents
+ * @param output
+ * output path where ngrams are generated including unigrams
+ * @param baseConf
+ * job configuration
+ * @param maxNGramSize
+ * minValue = 2.
+ * @param minSupport
+ * minimum support to prune ngrams including unigrams
+ * @param minLLRValue
+ * minimum threshold to prune ngrams
+ * @param reduceTasks
+ * number of reducers used
+ */
+ public static void generateAllGrams(Path input,
+ Path output,
+ Configuration baseConf,
+ int maxNGramSize,
+ int minSupport,
+ float minLLRValue,
+ int reduceTasks)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ // parse input and extract collocations
+ long ngramCount = generateCollocations(input, output, baseConf, true, maxNGramSize, reduceTasks, minSupport);
+
+ // tally collocations and perform LLR calculation
+ computeNGramsPruneByLLR(output, baseConf, ngramCount, true, minLLRValue, reduceTasks);
+ }
+
+ /**
+ * pass1: generate collocations, ngrams
+ */
+ private static long generateCollocations(Path input,
+ Path output,
+ Configuration baseConf,
+ boolean emitUnigrams,
+ int maxNGramSize,
+ int reduceTasks,
+ int minSupport)
+ throws IOException, ClassNotFoundException, InterruptedException {
+
+ Configuration con = new Configuration(baseConf);
+ con.setBoolean(EMIT_UNIGRAMS, emitUnigrams);
+ con.setInt(CollocMapper.MAX_SHINGLE_SIZE, maxNGramSize);
+ con.setInt(CollocReducer.MIN_SUPPORT, minSupport);
+
+ Job job = new Job(con);
+ job.setJobName(CollocDriver.class.getSimpleName() + ".generateCollocations:" + input);
+ job.setJarByClass(CollocDriver.class);
+
+ job.setMapOutputKeyClass(GramKey.class);
+ job.setMapOutputValueClass(Gram.class);
+ job.setPartitionerClass(GramKeyPartitioner.class);
+ job.setGroupingComparatorClass(GramKeyGroupComparator.class);
+
+ job.setOutputKeyClass(Gram.class);
+ job.setOutputValueClass(Gram.class);
+
+ job.setCombinerClass(CollocCombiner.class);
+
+ FileInputFormat.setInputPaths(job, input);
+
+ Path outputPath = new Path(output, SUBGRAM_OUTPUT_DIRECTORY);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setMapperClass(CollocMapper.class);
+
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setReducerClass(CollocReducer.class);
+ job.setNumReduceTasks(reduceTasks);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ return job.getCounters().findCounter(CollocMapper.Count.NGRAM_TOTAL).getValue();
+ }
+
+ /**
+ * pass2: perform the LLR calculation
+ */
+ private static void computeNGramsPruneByLLR(Path output,
+ Configuration baseConf,
+ long nGramTotal,
+ boolean emitUnigrams,
+ float minLLRValue,
+ int reduceTasks)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Configuration conf = new Configuration(baseConf);
+ conf.setLong(LLRReducer.NGRAM_TOTAL, nGramTotal);
+ conf.setBoolean(EMIT_UNIGRAMS, emitUnigrams);
+ conf.setFloat(LLRReducer.MIN_LLR, minLLRValue);
+
+ Job job = new Job(conf);
+ job.setJobName(CollocDriver.class.getSimpleName() + ".computeNGrams: " + output);
+ job.setJarByClass(CollocDriver.class);
+
+ job.setMapOutputKeyClass(Gram.class);
+ job.setMapOutputValueClass(Gram.class);
+
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(DoubleWritable.class);
+
+ FileInputFormat.setInputPaths(job, new Path(output, SUBGRAM_OUTPUT_DIRECTORY));
+ Path outPath = new Path(output, NGRAM_OUTPUT_DIRECTORY);
+ FileOutputFormat.setOutputPath(job, outPath);
+
+ job.setMapperClass(Mapper.class);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setReducerClass(LLRReducer.class);
+ job.setNumReduceTasks(reduceTasks);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java
new file mode 100644
index 0000000..fd99293
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java
@@ -0,0 +1,178 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.lucene.analysis.shingle.ShingleFilter;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.lucene.analysis.tokenattributes.TypeAttribute;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.lucene.IteratorTokenStream;
+import org.apache.mahout.math.function.ObjectIntProcedure;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Pass 1 of the Collocation discovery job which generated ngrams and emits ngrams an their component n-1grams.
+ * Input is a SequeceFile<Text,StringTuple>, where the key is a document id and the value is the tokenized documents.
+ * <p/>
+ */
+public class CollocMapper extends Mapper<Text, StringTuple, GramKey, Gram> {
+
+ private static final byte[] EMPTY = new byte[0];
+
+ public static final String MAX_SHINGLE_SIZE = "maxShingleSize";
+
+ private static final int DEFAULT_MAX_SHINGLE_SIZE = 2;
+
+ public enum Count {
+ NGRAM_TOTAL
+ }
+
+ private static final Logger log = LoggerFactory.getLogger(CollocMapper.class);
+
+ private int maxShingleSize;
+
+ private boolean emitUnigrams;
+
+ /**
+ * Collocation finder: pass 1 map phase.
+ * <p/>
+ * Receives a token stream which gets passed through a Lucene ShingleFilter. The ShingleFilter delivers ngrams of
+ * the appropriate size which are then decomposed into head and tail subgrams which are collected in the
+ * following manner
+ * <p/>
+ * <pre>
+ * k:head_key, v:head_subgram
+ * k:head_key,ngram_key, v:ngram
+ * k:tail_key, v:tail_subgram
+ * k:tail_key,ngram_key, v:ngram
+ * </pre>
+ * <p/>
+ * The 'head' or 'tail' prefix is used to specify whether the subgram in question is the head or tail of the
+ * ngram. In this implementation the head of the ngram is a (n-1)gram, and the tail is a (1)gram.
+ * <p/>
+ * For example, given 'click and clack' and an ngram length of 3:
+ * <pre>
+ * k: head_'click and' v:head_'click and'
+ * k: head_'click and',ngram_'click and clack' v:ngram_'click and clack'
+ * k: tail_'clack', v:tail_'clack'
+ * k: tail_'clack',ngram_'click and clack' v:ngram_'click and clack'
+ * </pre>
+ * <p/>
+ * Also counts the total number of ngrams encountered and adds it to the counter
+ * CollocDriver.Count.NGRAM_TOTAL
+ * </p>
+ *
+ * @throws IOException if there's a problem with the ShingleFilter reading data or the collector collecting output.
+ */
+ @Override
+ protected void map(Text key, StringTuple value, final Context context) throws IOException, InterruptedException {
+
+ try (ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(value.getEntries().iterator()), maxShingleSize)){
+ sf.reset();
+ int count = 0; // ngram count
+
+ OpenObjectIntHashMap<String> ngrams =
+ new OpenObjectIntHashMap<>(value.getEntries().size() * (maxShingleSize - 1));
+ OpenObjectIntHashMap<String> unigrams = new OpenObjectIntHashMap<>(value.getEntries().size());
+
+ do {
+ String term = sf.getAttribute(CharTermAttribute.class).toString();
+ String type = sf.getAttribute(TypeAttribute.class).type();
+ if ("shingle".equals(type)) {
+ count++;
+ ngrams.adjustOrPutValue(term, 1, 1);
+ } else if (emitUnigrams && !term.isEmpty()) { // unigram
+ unigrams.adjustOrPutValue(term, 1, 1);
+ }
+ } while (sf.incrementToken());
+
+ final GramKey gramKey = new GramKey();
+
+ ngrams.forEachPair(new ObjectIntProcedure<String>() {
+ @Override
+ public boolean apply(String term, int frequency) {
+ // obtain components, the leading (n-1)gram and the trailing unigram.
+ int i = term.lastIndexOf(' '); // TODO: fix for non-whitespace delimited languages.
+ if (i != -1) { // bigram, trigram etc
+
+ try {
+ Gram ngram = new Gram(term, frequency, Gram.Type.NGRAM);
+ Gram head = new Gram(term.substring(0, i), frequency, Gram.Type.HEAD);
+ Gram tail = new Gram(term.substring(i + 1), frequency, Gram.Type.TAIL);
+
+ gramKey.set(head, EMPTY);
+ context.write(gramKey, head);
+
+ gramKey.set(head, ngram.getBytes());
+ context.write(gramKey, ngram);
+
+ gramKey.set(tail, EMPTY);
+ context.write(gramKey, tail);
+
+ gramKey.set(tail, ngram.getBytes());
+ context.write(gramKey, ngram);
+
+ } catch (IOException | InterruptedException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ return true;
+ }
+ });
+
+ unigrams.forEachPair(new ObjectIntProcedure<String>() {
+ @Override
+ public boolean apply(String term, int frequency) {
+ try {
+ Gram unigram = new Gram(term, frequency, Gram.Type.UNIGRAM);
+ gramKey.set(unigram, EMPTY);
+ context.write(gramKey, unigram);
+ } catch (IOException | InterruptedException e) {
+ throw new IllegalStateException(e);
+ }
+ return true;
+ }
+ });
+
+ context.getCounter(Count.NGRAM_TOTAL).increment(count);
+ sf.end();
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ this.maxShingleSize = conf.getInt(MAX_SHINGLE_SIZE, DEFAULT_MAX_SHINGLE_SIZE);
+
+ this.emitUnigrams = conf.getBoolean(CollocDriver.EMIT_UNIGRAMS, CollocDriver.DEFAULT_EMIT_UNIGRAMS);
+
+ if (log.isInfoEnabled()) {
+ log.info("Max Ngram size is {}", this.maxShingleSize);
+ log.info("Emit Unitgrams is {}", emitUnigrams);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java
new file mode 100644
index 0000000..1fe13e3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Reducer for Pass 1 of the collocation identification job. Generates counts for ngrams and subgrams.
+ */
+public class CollocReducer extends Reducer<GramKey, Gram, Gram, Gram> {
+
+ private static final Logger log = LoggerFactory.getLogger(CollocReducer.class);
+
+ public static final String MIN_SUPPORT = "minSupport";
+
+ public static final int DEFAULT_MIN_SUPPORT = 2;
+
+ public enum Skipped {
+ LESS_THAN_MIN_SUPPORT, MALFORMED_KEY_TUPLE, MALFORMED_TUPLE, MALFORMED_TYPES, MALFORMED_UNIGRAM
+ }
+
+ private int minSupport;
+
+ /**
+ * collocation finder: pass 1 reduce phase:
+ * <p/>
+ * given input from the mapper,
+ *
+ * <pre>
+ * k:head_subgram,ngram, v:ngram:partial freq
+ * k:head_subgram v:head_subgram:partial freq
+ * k:tail_subgram,ngram, v:ngram:partial freq
+ * k:tail_subgram v:tail_subgram:partial freq
+ * k:unigram v:unigram:partial freq
+ * </pre>
+ * sum gram frequencies and output for llr calculation
+ * <p/>
+ * output is:
+ * <pre>
+ * k:ngram:ngramfreq v:head_subgram:head_subgramfreq
+ * k:ngram:ngramfreq v:tail_subgram:tail_subgramfreq
+ * k:unigram:unigramfreq v:unigram:unigramfreq
+ * </pre>
+ * Each ngram's frequency is essentially counted twice, once for head, once for tail.
+ * frequency should be the same for the head and tail. Fix this to count only for the
+ * head and move the count into the value?
+ */
+ @Override
+ protected void reduce(GramKey key, Iterable<Gram> values, Context context) throws IOException, InterruptedException {
+
+ Gram.Type keyType = key.getType();
+
+ if (keyType == Gram.Type.UNIGRAM) {
+ // sum frequencies for unigrams.
+ processUnigram(values.iterator(), context);
+ } else if (keyType == Gram.Type.HEAD || keyType == Gram.Type.TAIL) {
+ // sum frequencies for subgrams, ngram and collect for each ngram.
+ processSubgram(values.iterator(), context);
+ } else {
+ context.getCounter(Skipped.MALFORMED_TYPES).increment(1);
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ this.minSupport = conf.getInt(MIN_SUPPORT, DEFAULT_MIN_SUPPORT);
+
+ boolean emitUnigrams = conf.getBoolean(CollocDriver.EMIT_UNIGRAMS, CollocDriver.DEFAULT_EMIT_UNIGRAMS);
+
+ log.info("Min support is {}", minSupport);
+ log.info("Emit Unitgrams is {}", emitUnigrams);
+ }
+
+ /**
+ * Sum frequencies for unigrams and deliver to the collector
+ */
+ protected void processUnigram(Iterator<Gram> values, Context context)
+ throws IOException, InterruptedException {
+
+ int freq = 0;
+ Gram value = null;
+
+ // accumulate frequencies from values.
+ while (values.hasNext()) {
+ value = values.next();
+ freq += value.getFrequency();
+ }
+
+ if (freq < minSupport) {
+ context.getCounter(Skipped.LESS_THAN_MIN_SUPPORT).increment(1);
+ return;
+ }
+
+ value.setFrequency(freq);
+ context.write(value, value);
+
+ }
+
+ /** Sum frequencies for subgram, ngrams and deliver ngram, subgram pairs to the collector.
+ * <p/>
+ * Sort order guarantees that the subgram/subgram pairs will be seen first and then
+ * subgram/ngram1 pairs, subgram/ngram2 pairs ... subgram/ngramN pairs, so frequencies for
+ * ngrams can be calcualted here as well.
+ * <p/>
+ * We end up calculating frequencies for ngrams for each sugram (head, tail) here, which is
+ * some extra work.
+ * @throws InterruptedException
+ */
+ protected void processSubgram(Iterator<Gram> values, Context context)
+ throws IOException, InterruptedException {
+
+ Gram subgram = null;
+ Gram currentNgram = null;
+
+ while (values.hasNext()) {
+ Gram value = values.next();
+
+ if (value.getType() == Gram.Type.HEAD || value.getType() == Gram.Type.TAIL) {
+ // collect frequency for subgrams.
+ if (subgram == null) {
+ subgram = new Gram(value);
+ } else {
+ subgram.incrementFrequency(value.getFrequency());
+ }
+ } else if (!value.equals(currentNgram)) {
+ // we've collected frequency for all subgrams and we've encountered a new ngram.
+ // collect the old ngram if there was one and we have sufficient support and
+ // create the new ngram.
+ if (currentNgram != null) {
+ if (currentNgram.getFrequency() < minSupport) {
+ context.getCounter(Skipped.LESS_THAN_MIN_SUPPORT).increment(1);
+ } else {
+ context.write(currentNgram, subgram);
+ }
+ }
+
+ currentNgram = new Gram(value);
+ } else {
+ currentNgram.incrementFrequency(value.getFrequency());
+ }
+ }
+
+ // collect last ngram.
+ if (currentNgram != null) {
+ if (currentNgram.getFrequency() < minSupport) {
+ context.getCounter(Skipped.LESS_THAN_MIN_SUPPORT).increment(1);
+ return;
+ }
+
+ context.write(currentNgram, subgram);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java
new file mode 100644
index 0000000..58234b3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java
@@ -0,0 +1,239 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.charset.CharacterCodingException;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.io.BinaryComparable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.mahout.math.Varint;
+
+/**
+ * Writable for holding data generated from the collocation discovery jobs. Depending on the job configuration
+ * gram may be one or more words. In some contexts this is used to hold a complete ngram, while in others it
+ * holds a part of an existing ngram (subgram). Tracks the frequency of the gram and its position in the ngram
+ * in which is was found.
+ */
+public class Gram extends BinaryComparable implements WritableComparable<BinaryComparable> {
+
+ public enum Type {
+ HEAD('h'),
+ TAIL('t'),
+ UNIGRAM('u'),
+ NGRAM('n');
+
+ private final char x;
+
+ Type(char c) {
+ this.x = c;
+ }
+
+ @Override
+ public String toString() {
+ return String.valueOf(x);
+ }
+ }
+
+ private byte[] bytes;
+ private int length;
+ private int frequency;
+
+ public Gram() {
+
+ }
+
+ /**
+ * Copy constructor
+ */
+ public Gram(Gram other) {
+ frequency = other.frequency;
+ length = other.length;
+ bytes = other.bytes.clone();
+ }
+
+ /**
+ * Create an gram with a frequency of 1
+ *
+ * @param ngram
+ * the gram string
+ * @param type
+ * whether the gram is at the head or tail of its text unit or it is a unigram
+ */
+ public Gram(String ngram, Type type) {
+ this(ngram, 1, type);
+ }
+
+
+ /**
+ *
+ * Create a gram with the specified frequency.
+ *
+ * @param ngram
+ * the gram string
+ * @param frequency
+ * the gram frequency
+ * @param type
+ * whether the gram is at the head of its text unit or tail or unigram
+ */
+ public Gram(String ngram, int frequency, Type type) {
+ Preconditions.checkNotNull(ngram);
+ try {
+ // extra character is used for storing type which is part
+ // of the sort key.
+ ByteBuffer bb = Text.encode('\0' + ngram, true);
+ bytes = bb.array();
+ length = bb.limit();
+ } catch (CharacterCodingException e) {
+ throw new IllegalStateException("Should not have happened ",e);
+ }
+
+ encodeType(type, bytes, 0);
+ this.frequency = frequency;
+ }
+
+
+ @Override
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ @Override
+ public int getLength() {
+ return length;
+ }
+
+ /**
+ * @return the gram is at the head of its text unit or tail or unigram.
+ */
+ public Type getType() {
+ return decodeType(bytes, 0);
+ }
+
+ /**
+ * @return gram term string
+ */
+ public String getString() {
+ try {
+ return Text.decode(bytes, 1, length - 1);
+ } catch (CharacterCodingException e) {
+ throw new IllegalStateException("Should not have happened " + e);
+ }
+ }
+
+ /**
+ * @return gram frequency
+ */
+ public int getFrequency() {
+ return frequency;
+ }
+
+ /**
+ * @param frequency
+ * gram's frequency
+ */
+ public void setFrequency(int frequency) {
+ this.frequency = frequency;
+ }
+
+ public void incrementFrequency(int i) {
+ this.frequency += i;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int newLength = Varint.readUnsignedVarInt(in);
+ setCapacity(newLength, false);
+ in.readFully(bytes, 0, newLength);
+ int newFrequency = Varint.readUnsignedVarInt(in);
+ length = newLength;
+ frequency = newFrequency;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ Varint.writeUnsignedVarInt(length, out);
+ out.write(bytes, 0, length);
+ Varint.writeUnsignedVarInt(frequency, out);
+ }
+
+ /* Cribbed from o.a.hadoop.io.Text:
+ * Sets the capacity of this object to <em>at least</em>
+ * {@code len} bytes. If the current buffer is longer,
+ * then the capacity and existing content of the buffer are
+ * unchanged. If {@code len} is larger
+ * than the current capacity, this object's capacity is
+ * increased to match.
+ * @param len the number of bytes we need
+ * @param keepData should the old data be kept
+ */
+ private void setCapacity(int len, boolean keepData) {
+ len++; // extra byte to hold type
+ if (bytes == null || bytes.length < len) {
+ byte[] newBytes = new byte[len];
+ if (bytes != null && keepData) {
+ System.arraycopy(bytes, 0, newBytes, 0, length);
+ }
+ bytes = newBytes;
+ }
+ }
+
+ @Override
+ public String toString() {
+ return '\'' + getString() + "'[" + getType() + "]:" + frequency;
+ }
+
+ public static void encodeType(Type type, byte[] buf, int offset) {
+ switch (type) {
+ case HEAD:
+ buf[offset] = 0x1;
+ break;
+ case TAIL:
+ buf[offset] = 0x2;
+ break;
+ case UNIGRAM:
+ buf[offset] = 0x3;
+ break;
+ case NGRAM:
+ buf[offset] = 0x4;
+ break;
+ default:
+ throw new IllegalStateException("switch/case problem in encodeType");
+ }
+ }
+
+ public static Type decodeType(byte[] buf, int offset) {
+ switch (buf[offset]) {
+ case 0x1:
+ return Type.HEAD;
+ case 0x2:
+ return Type.TAIL;
+ case 0x3:
+ return Type.UNIGRAM;
+ case 0x4:
+ return Type.NGRAM;
+ default:
+ throw new IllegalStateException("switch/case problem in decodeType");
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java
new file mode 100644
index 0000000..e33ed51
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java
@@ -0,0 +1,133 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.nio.charset.CharacterCodingException;
+
+import org.apache.hadoop.io.BinaryComparable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.vectorizer.collocations.llr.Gram.Type;
+
+/** A GramKey, based on the identity fields of Gram (type, string) plus a byte[] used for secondary ordering */
+public final class GramKey extends BinaryComparable implements WritableComparable<BinaryComparable> {
+
+ private int primaryLength;
+ private int length;
+ private byte[] bytes;
+
+ public GramKey() {
+
+ }
+
+ /** create a GramKey based on the specified Gram and order
+ *
+ * @param gram
+ * @param order
+ */
+ public GramKey(Gram gram, byte[] order) {
+ set(gram, order);
+ }
+
+ /** set the gram held by this key */
+ public void set(Gram gram, byte[] order) {
+ primaryLength = gram.getLength();
+ length = primaryLength + order.length;
+ setCapacity(length, false);
+ System.arraycopy(gram.getBytes(), 0, bytes, 0, primaryLength);
+ if (order.length > 0) {
+ System.arraycopy(order, 0, bytes, primaryLength, order.length);
+ }
+ }
+
+ @Override
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ @Override
+ public int getLength() {
+ return length;
+ }
+
+ public int getPrimaryLength() {
+ return primaryLength;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int newLength = Varint.readUnsignedVarInt(in);
+ int newPrimaryLength = Varint.readUnsignedVarInt(in);
+ setCapacity(newLength, false);
+ in.readFully(bytes, 0, newLength);
+ length = newLength;
+ primaryLength = newPrimaryLength;
+
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ Varint.writeUnsignedVarInt(length, out);
+ Varint.writeUnsignedVarInt(primaryLength, out);
+ out.write(bytes, 0, length);
+ }
+
+ /* Cribbed from o.a.hadoop.io.Text:
+ * Sets the capacity of this object to <em>at least</em>
+ * {@code len} bytes. If the current buffer is longer,
+ * then the capacity and existing content of the buffer are
+ * unchanged. If {@code len} is larger
+ * than the current capacity, this object's capacity is
+ * increased to match.
+ * @param len the number of bytes we need
+ * @param keepData should the old data be kept
+ */
+ private void setCapacity(int len, boolean keepData) {
+ if (bytes == null || bytes.length < len) {
+ byte[] newBytes = new byte[len];
+ if (bytes != null && keepData) {
+ System.arraycopy(bytes, 0, newBytes, 0, length);
+ }
+ bytes = newBytes;
+ }
+ }
+
+ /**
+ * @return the gram is at the head of its text unit or tail or unigram.
+ */
+ public Type getType() {
+ return Gram.decodeType(bytes, 0);
+ }
+
+ public String getPrimaryString() {
+ try {
+ return Text.decode(bytes, 1, primaryLength - 1);
+ } catch (CharacterCodingException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return '\'' + getPrimaryString() + "'[" + getType() + ']';
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java
new file mode 100644
index 0000000..7b73d71
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+
+import java.io.Serializable;
+
+/** Group GramKeys based on their Gram, ignoring the secondary sort key, so that all keys with the same Gram are sent
+ * to the same call of the reduce method, sorted in natural order (for GramKeys).
+ */
+class GramKeyGroupComparator extends WritableComparator implements Serializable {
+
+ GramKeyGroupComparator() {
+ super(GramKey.class, true);
+ }
+
+ @Override
+ public int compare(WritableComparable a, WritableComparable b) {
+ GramKey gka = (GramKey) a;
+ GramKey gkb = (GramKey) b;
+
+ return WritableComparator.compareBytes(gka.getBytes(), 0, gka.getPrimaryLength(),
+ gkb.getBytes(), 0, gkb.getPrimaryLength());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java
new file mode 100644
index 0000000..a68339f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import org.apache.hadoop.mapreduce.Partitioner;
+
+/**
+ * Partition GramKeys based on their Gram, ignoring the secondary sort key so that all GramKeys with the same
+ * gram are sent to the same partition.
+ */
+public final class GramKeyPartitioner extends Partitioner<GramKey, Gram> {
+
+ @Override
+ public int getPartition(GramKey key, Gram value, int numPartitions) {
+ int hash = 1;
+ byte[] bytes = key.getBytes();
+ int length = key.getPrimaryLength();
+ // Copied from WritableComparator.hashBytes(); skips first byte, type byte
+ for (int i = 1; i < length; i++) {
+ hash = (31 * hash) + bytes[i];
+ }
+ return (hash & Integer.MAX_VALUE) % numPartitions;
+ }
+
+}
[30/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java b/mr/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
new file mode 100644
index 0000000..930fd44
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/canopy/Canopy.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+/**
+ * This class models a canopy as a center point, the number of points that are contained within it according
+ * to the application of some distance metric, and a point total which is the sum of all the points and is
+ * used to compute the centroid when needed.
+ */
+@Deprecated
+public class Canopy extends DistanceMeasureCluster {
+
+ /** Used for deserialization as a writable */
+ public Canopy() { }
+
+ /**
+ * Create a new Canopy containing the given point and canopyId
+ *
+ * @param center a point in vector space
+ * @param canopyId an int identifying the canopy local to this process only
+ * @param measure a DistanceMeasure to use
+ */
+ public Canopy(Vector center, int canopyId, DistanceMeasure measure) {
+ super(center, canopyId, measure);
+ observe(center);
+ }
+
+ public String asFormatString() {
+ return "C" + this.getId() + ": " + this.computeCentroid().asFormatString();
+ }
+
+ @Override
+ public String toString() {
+ return getIdentifier() + ": " + getCenter().asFormatString();
+ }
+
+ @Override
+ public String getIdentifier() {
+ return "C-" + getId();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
new file mode 100644
index 0000000..3ce4757
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyClusterer.java
@@ -0,0 +1,220 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.collect.Lists;
+
+@Deprecated
+public class CanopyClusterer {
+
+ private static final Logger log = LoggerFactory.getLogger(CanopyClusterer.class);
+
+ private int nextCanopyId;
+
+ // the T1 distance threshold
+ private double t1;
+
+ // the T2 distance threshold
+ private double t2;
+
+ // the T3 distance threshold
+ private double t3;
+
+ // the T4 distance threshold
+ private double t4;
+
+ // the distance measure
+ private DistanceMeasure measure;
+
+ public CanopyClusterer(DistanceMeasure measure, double t1, double t2) {
+ this.t1 = t1;
+ this.t2 = t2;
+ this.t3 = t1;
+ this.t4 = t2;
+ this.measure = measure;
+ }
+
+ public double getT1() {
+ return t1;
+ }
+
+ public double getT2() {
+ return t2;
+ }
+
+ public double getT3() {
+ return t3;
+ }
+
+ public double getT4() {
+ return t4;
+ }
+
+ /**
+ * Used by CanopyReducer to set t1=t3 and t2=t4 configuration values
+ */
+ public void useT3T4() {
+ t1 = t3;
+ t2 = t4;
+ }
+
+ /**
+ * This is the same algorithm as the reference but inverted to iterate over
+ * existing canopies instead of the points. Because of this it does not need
+ * to actually store the points, instead storing a total points vector and
+ * the number of points. From this a centroid can be computed.
+ * <p/>
+ * This method is used by the CanopyMapper, CanopyReducer and CanopyDriver.
+ *
+ * @param point
+ * the point to be added
+ * @param canopies
+ * the List<Canopy> to be appended
+ */
+ public void addPointToCanopies(Vector point, Collection<Canopy> canopies) {
+ boolean pointStronglyBound = false;
+ for (Canopy canopy : canopies) {
+ double dist = measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point);
+ if (dist < t1) {
+ if (log.isDebugEnabled()) {
+ log.debug("Added point: {} to canopy: {}", AbstractCluster.formatVector(point, null), canopy.getIdentifier());
+ }
+ canopy.observe(point);
+ }
+ pointStronglyBound = pointStronglyBound || dist < t2;
+ }
+ if (!pointStronglyBound) {
+ if (log.isDebugEnabled()) {
+ log.debug("Created new Canopy:{} at center:{}", nextCanopyId, AbstractCluster.formatVector(point, null));
+ }
+ canopies.add(new Canopy(point, nextCanopyId++, measure));
+ }
+ }
+
+ /**
+ * Return if the point is covered by the canopy
+ *
+ * @param point
+ * a point
+ * @return if the point is covered
+ */
+ public boolean canopyCovers(Canopy canopy, Vector point) {
+ return measure.distance(canopy.getCenter().getLengthSquared(), canopy.getCenter(), point) < t1;
+ }
+
+ /**
+ * Iterate through the points, adding new canopies. Return the canopies.
+ *
+ * @param points
+ * a list<Vector> defining the points to be clustered
+ * @param measure
+ * a DistanceMeasure to use
+ * @param t1
+ * the T1 distance threshold
+ * @param t2
+ * the T2 distance threshold
+ * @return the List<Canopy> created
+ */
+ public static List<Canopy> createCanopies(List<Vector> points,
+ DistanceMeasure measure,
+ double t1,
+ double t2) {
+ List<Canopy> canopies = Lists.newArrayList();
+ /**
+ * Reference Implementation: Given a distance metric, one can create
+ * canopies as follows: Start with a list of the data points in any
+ * order, and with two distance thresholds, T1 and T2, where T1 > T2.
+ * (These thresholds can be set by the user, or selected by
+ * cross-validation.) Pick a point on the list and measure its distance
+ * to all other points. Put all points that are within distance
+ * threshold T1 into a canopy. Remove from the list all points that are
+ * within distance threshold T2. Repeat until the list is empty.
+ */
+ int nextCanopyId = 0;
+ while (!points.isEmpty()) {
+ Iterator<Vector> ptIter = points.iterator();
+ Vector p1 = ptIter.next();
+ ptIter.remove();
+ Canopy canopy = new Canopy(p1, nextCanopyId++, measure);
+ canopies.add(canopy);
+ while (ptIter.hasNext()) {
+ Vector p2 = ptIter.next();
+ double dist = measure.distance(p1, p2);
+ // Put all points that are within distance threshold T1 into the
+ // canopy
+ if (dist < t1) {
+ canopy.observe(p2);
+ }
+ // Remove from the list all points that are within distance
+ // threshold T2
+ if (dist < t2) {
+ ptIter.remove();
+ }
+ }
+ for (Canopy c : canopies) {
+ c.computeParameters();
+ }
+ }
+ return canopies;
+ }
+
+ /**
+ * Iterate through the canopies, adding their centroids to a list
+ *
+ * @param canopies
+ * a List<Canopy>
+ * @return the List<Vector>
+ */
+ public static List<Vector> getCenters(Iterable<Canopy> canopies) {
+ List<Vector> result = Lists.newArrayList();
+ for (Canopy canopy : canopies) {
+ result.add(canopy.getCenter());
+ }
+ return result;
+ }
+
+ /**
+ * Iterate through the canopies, resetting their center to their centroids
+ *
+ * @param canopies
+ * a List<Canopy>
+ */
+ public static void updateCentroids(Iterable<Canopy> canopies) {
+ for (Canopy canopy : canopies) {
+ canopy.computeParameters();
+ }
+ }
+
+ public void setT3(double t3) {
+ this.t3 = t3;
+ }
+
+ public void setT4(double t4) {
+ this.t4 = t4;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java
new file mode 100644
index 0000000..2f24026
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyConfigKeys.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+
+@Deprecated
+public final class CanopyConfigKeys {
+
+ private CanopyConfigKeys() {}
+
+ public static final String T1_KEY = "org.apache.mahout.clustering.canopy.t1";
+
+ public static final String T2_KEY = "org.apache.mahout.clustering.canopy.t2";
+
+ public static final String T3_KEY = "org.apache.mahout.clustering.canopy.t3";
+
+ public static final String T4_KEY = "org.apache.mahout.clustering.canopy.t4";
+
+ // keys used by Driver, Mapper, Combiner & Reducer
+ public static final String DISTANCE_MEASURE_KEY = "org.apache.mahout.clustering.canopy.measure";
+
+ public static final String CF_KEY = "org.apache.mahout.clustering.canopy.canopyFilter";
+
+ /**
+ * Create a {@link CanopyClusterer} from the Hadoop configuration.
+ *
+ * @param configuration Hadoop configuration
+ *
+ * @return CanopyClusterer
+ */
+ public static CanopyClusterer configureCanopyClusterer(Configuration configuration) {
+ double t1 = Double.parseDouble(configuration.get(T1_KEY));
+ double t2 = Double.parseDouble(configuration.get(T2_KEY));
+
+ DistanceMeasure measure = ClassUtils.instantiateAs(configuration.get(DISTANCE_MEASURE_KEY), DistanceMeasure.class);
+ measure.configure(configuration);
+
+ CanopyClusterer canopyClusterer = new CanopyClusterer(measure, t1, t2);
+
+ String d = configuration.get(T3_KEY);
+ if (d != null) {
+ canopyClusterer.setT3(Double.parseDouble(d));
+ }
+
+ d = configuration.get(T4_KEY);
+ if (d != null) {
+ canopyClusterer.setT4(Double.parseDouble(d));
+ }
+ return canopyClusterer;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
new file mode 100644
index 0000000..06dc947
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyDriver.java
@@ -0,0 +1,379 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import java.io.IOException;
+import java.util.Collection;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassificationDriver;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.topdown.PathDirectory;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+@Deprecated
+public class CanopyDriver extends AbstractJob {
+
+ public static final String DEFAULT_CLUSTERED_POINTS_DIRECTORY = "clusteredPoints";
+
+ private static final Logger log = LoggerFactory.getLogger(CanopyDriver.class);
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new CanopyDriver(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+ addOption(DefaultOptionCreator.t1Option().create());
+ addOption(DefaultOptionCreator.t2Option().create());
+ addOption(DefaultOptionCreator.t3Option().create());
+ addOption(DefaultOptionCreator.t4Option().create());
+ addOption(DefaultOptionCreator.clusterFilterOption().create());
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption(DefaultOptionCreator.clusteringOption().create());
+ addOption(DefaultOptionCreator.methodOption().create());
+ addOption(DefaultOptionCreator.outlierThresholdOption().create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+ Configuration conf = getConf();
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(conf, output);
+ }
+ String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ double t1 = Double.parseDouble(getOption(DefaultOptionCreator.T1_OPTION));
+ double t2 = Double.parseDouble(getOption(DefaultOptionCreator.T2_OPTION));
+ double t3 = t1;
+ if (hasOption(DefaultOptionCreator.T3_OPTION)) {
+ t3 = Double.parseDouble(getOption(DefaultOptionCreator.T3_OPTION));
+ }
+ double t4 = t2;
+ if (hasOption(DefaultOptionCreator.T4_OPTION)) {
+ t4 = Double.parseDouble(getOption(DefaultOptionCreator.T4_OPTION));
+ }
+ int clusterFilter = 0;
+ if (hasOption(DefaultOptionCreator.CLUSTER_FILTER_OPTION)) {
+ clusterFilter = Integer
+ .parseInt(getOption(DefaultOptionCreator.CLUSTER_FILTER_OPTION));
+ }
+ boolean runClustering = hasOption(DefaultOptionCreator.CLUSTERING_OPTION);
+ boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION)
+ .equalsIgnoreCase(DefaultOptionCreator.SEQUENTIAL_METHOD);
+ DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
+ double clusterClassificationThreshold = 0.0;
+ if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
+ clusterClassificationThreshold = Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+ }
+ run(conf, input, output, measure, t1, t2, t3, t4, clusterFilter,
+ runClustering, clusterClassificationThreshold, runSequential);
+ return 0;
+ }
+
+ /**
+ * Build a directory of Canopy clusters from the input arguments and, if
+ * requested, cluster the input vectors using these clusters
+ *
+ * @param conf
+ * the Configuration
+ * @param input
+ * the Path to the directory containing input vectors
+ * @param output
+ * the Path for all output directories
+ * @param measure
+ * the DistanceMeasure
+ * @param t1
+ * the double T1 distance metric
+ * @param t2
+ * the double T2 distance metric
+ * @param t3
+ * the reducer's double T1 distance metric
+ * @param t4
+ * the reducer's double T2 distance metric
+ * @param clusterFilter
+ * the minimum canopy size output by the mappers
+ * @param runClustering
+ * cluster the input vectors if true
+ * @param clusterClassificationThreshold
+ * vectors having pdf below this value will not be clustered. Its value should be between 0 and 1.
+ * @param runSequential
+ * execute sequentially if true
+ */
+ public static void run(Configuration conf, Path input, Path output,
+ DistanceMeasure measure, double t1, double t2, double t3, double t4,
+ int clusterFilter, boolean runClustering, double clusterClassificationThreshold, boolean runSequential)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Path clustersOut = buildClusters(conf, input, output, measure, t1, t2, t3,
+ t4, clusterFilter, runSequential);
+ if (runClustering) {
+ clusterData(conf, input, clustersOut, output, clusterClassificationThreshold, runSequential);
+ }
+ }
+
+ /**
+ * Convenience method to provide backward compatibility
+ */
+ public static void run(Configuration conf, Path input, Path output,
+ DistanceMeasure measure, double t1, double t2, boolean runClustering,
+ double clusterClassificationThreshold, boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ run(conf, input, output, measure, t1, t2, t1, t2, 0, runClustering,
+ clusterClassificationThreshold, runSequential);
+ }
+
+ /**
+ * Convenience method creates new Configuration() Build a directory of Canopy
+ * clusters from the input arguments and, if requested, cluster the input
+ * vectors using these clusters
+ *
+ * @param input
+ * the Path to the directory containing input vectors
+ * @param output
+ * the Path for all output directories
+ * @param t1
+ * the double T1 distance metric
+ * @param t2
+ * the double T2 distance metric
+ * @param runClustering
+ * cluster the input vectors if true
+ * @param clusterClassificationThreshold
+ * vectors having pdf below this value will not be clustered. Its value should be between 0 and 1.
+ * @param runSequential
+ * execute sequentially if true
+ */
+ public static void run(Path input, Path output, DistanceMeasure measure,
+ double t1, double t2, boolean runClustering, double clusterClassificationThreshold, boolean runSequential)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ run(new Configuration(), input, output, measure, t1, t2, runClustering,
+ clusterClassificationThreshold, runSequential);
+ }
+
+ /**
+ * Convenience method for backwards compatibility
+ *
+ */
+ public static Path buildClusters(Configuration conf, Path input, Path output,
+ DistanceMeasure measure, double t1, double t2, int clusterFilter,
+ boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ return buildClusters(conf, input, output, measure, t1, t2, t1, t2,
+ clusterFilter, runSequential);
+ }
+
+ /**
+ * Build a directory of Canopy clusters from the input vectors and other
+ * arguments. Run sequential or mapreduce execution as requested
+ *
+ * @param conf
+ * the Configuration to use
+ * @param input
+ * the Path to the directory containing input vectors
+ * @param output
+ * the Path for all output directories
+ * @param measure
+ * the DistanceMeasure
+ * @param t1
+ * the double T1 distance metric
+ * @param t2
+ * the double T2 distance metric
+ * @param t3
+ * the reducer's double T1 distance metric
+ * @param t4
+ * the reducer's double T2 distance metric
+ * @param clusterFilter
+ * the int minimum size of canopies produced
+ * @param runSequential
+ * a boolean indicates to run the sequential (reference) algorithm
+ * @return the canopy output directory Path
+ */
+ public static Path buildClusters(Configuration conf, Path input, Path output,
+ DistanceMeasure measure, double t1, double t2, double t3, double t4,
+ int clusterFilter, boolean runSequential) throws IOException,
+ InterruptedException, ClassNotFoundException {
+ log.info("Build Clusters Input: {} Out: {} Measure: {} t1: {} t2: {}",
+ input, output, measure, t1, t2);
+ if (runSequential) {
+ return buildClustersSeq(input, output, measure, t1, t2, clusterFilter);
+ } else {
+ return buildClustersMR(conf, input, output, measure, t1, t2, t3, t4,
+ clusterFilter);
+ }
+ }
+
+ /**
+ * Build a directory of Canopy clusters from the input vectors and other
+ * arguments. Run sequential execution
+ *
+ * @param input
+ * the Path to the directory containing input vectors
+ * @param output
+ * the Path for all output directories
+ * @param measure
+ * the DistanceMeasure
+ * @param t1
+ * the double T1 distance metric
+ * @param t2
+ * the double T2 distance metric
+ * @param clusterFilter
+ * the int minimum size of canopies produced
+ * @return the canopy output directory Path
+ */
+ private static Path buildClustersSeq(Path input, Path output,
+ DistanceMeasure measure, double t1, double t2, int clusterFilter)
+ throws IOException {
+ CanopyClusterer clusterer = new CanopyClusterer(measure, t1, t2);
+ Collection<Canopy> canopies = Lists.newArrayList();
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(input.toUri(), conf);
+
+ for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
+ input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
+ clusterer.addPointToCanopies(vw.get(), canopies);
+ }
+
+ Path canopyOutputDir = new Path(output, Cluster.CLUSTERS_DIR + '0' + Cluster.FINAL_ITERATION_SUFFIX);
+ Path path = new Path(canopyOutputDir, "part-r-00000");
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path,
+ Text.class, ClusterWritable.class);
+ try {
+ ClusterWritable clusterWritable = new ClusterWritable();
+ for (Canopy canopy : canopies) {
+ canopy.computeParameters();
+ if (log.isDebugEnabled()) {
+ log.debug("Writing Canopy:{} center:{} numPoints:{} radius:{}",
+ canopy.getIdentifier(),
+ AbstractCluster.formatVector(canopy.getCenter(), null),
+ canopy.getNumObservations(),
+ AbstractCluster.formatVector(canopy.getRadius(), null));
+ }
+ if (canopy.getNumObservations() > clusterFilter) {
+ clusterWritable.setValue(canopy);
+ writer.append(new Text(canopy.getIdentifier()), clusterWritable);
+ }
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ return canopyOutputDir;
+ }
+
+ /**
+ * Build a directory of Canopy clusters from the input vectors and other
+ * arguments. Run mapreduce execution
+ *
+ * @param conf
+ * the Configuration
+ * @param input
+ * the Path to the directory containing input vectors
+ * @param output
+ * the Path for all output directories
+ * @param measure
+ * the DistanceMeasure
+ * @param t1
+ * the double T1 distance metric
+ * @param t2
+ * the double T2 distance metric
+ * @param t3
+ * the reducer's double T1 distance metric
+ * @param t4
+ * the reducer's double T2 distance metric
+ * @param clusterFilter
+ * the int minimum size of canopies produced
+ * @return the canopy output directory Path
+ */
+ private static Path buildClustersMR(Configuration conf, Path input,
+ Path output, DistanceMeasure measure, double t1, double t2, double t3,
+ double t4, int clusterFilter) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, measure.getClass()
+ .getName());
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(t1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(t2));
+ conf.set(CanopyConfigKeys.T3_KEY, String.valueOf(t3));
+ conf.set(CanopyConfigKeys.T4_KEY, String.valueOf(t4));
+ conf.set(CanopyConfigKeys.CF_KEY, String.valueOf(clusterFilter));
+
+ Job job = new Job(conf, "Canopy Driver running buildClusters over input: "
+ + input);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(CanopyMapper.class);
+ job.setMapOutputKeyClass(Text.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+ job.setReducerClass(CanopyReducer.class);
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(ClusterWritable.class);
+ job.setNumReduceTasks(1);
+ job.setJarByClass(CanopyDriver.class);
+
+ FileInputFormat.addInputPath(job, input);
+ Path canopyOutputDir = new Path(output, Cluster.CLUSTERS_DIR + '0' + Cluster.FINAL_ITERATION_SUFFIX);
+ FileOutputFormat.setOutputPath(job, canopyOutputDir);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("Canopy Job failed processing " + input);
+ }
+ return canopyOutputDir;
+ }
+
+ private static void clusterData(Configuration conf,
+ Path points,
+ Path canopies,
+ Path output,
+ double clusterClassificationThreshold,
+ boolean runSequential)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusterClassifier.writePolicy(new CanopyClusteringPolicy(), canopies);
+ ClusterClassificationDriver.run(conf, points, output, new Path(output, PathDirectory.CLUSTERED_POINTS_DIRECTORY),
+ clusterClassificationThreshold, true, runSequential);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java
new file mode 100644
index 0000000..265d3da
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyMapper.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import java.io.IOException;
+import java.util.Collection;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VectorWritable;
+
+@Deprecated
+class CanopyMapper extends
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
+
+ private final Collection<Canopy> canopies = Lists.newArrayList();
+
+ private CanopyClusterer canopyClusterer;
+
+ private int clusterFilter;
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable point,
+ Context context) throws IOException, InterruptedException {
+ canopyClusterer.addPointToCanopies(point.get(), canopies);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ super.setup(context);
+ canopyClusterer = CanopyConfigKeys.configureCanopyClusterer(context.getConfiguration());
+ clusterFilter = Integer.parseInt(context.getConfiguration().get(
+ CanopyConfigKeys.CF_KEY));
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ for (Canopy canopy : canopies) {
+ canopy.computeParameters();
+ if (canopy.getNumObservations() > clusterFilter) {
+ context.write(new Text("centroid"), new VectorWritable(canopy
+ .getCenter()));
+ }
+ }
+ super.cleanup(context);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java
new file mode 100644
index 0000000..cdd7d5e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/canopy/CanopyReducer.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import java.io.IOException;
+import java.util.Collection;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+@Deprecated
+public class CanopyReducer extends Reducer<Text, VectorWritable, Text, ClusterWritable> {
+
+ private final Collection<Canopy> canopies = Lists.newArrayList();
+
+ private CanopyClusterer canopyClusterer;
+
+ private int clusterFilter;
+
+ CanopyClusterer getCanopyClusterer() {
+ return canopyClusterer;
+ }
+
+ @Override
+ protected void reduce(Text arg0, Iterable<VectorWritable> values,
+ Context context) throws IOException, InterruptedException {
+ for (VectorWritable value : values) {
+ Vector point = value.get();
+ canopyClusterer.addPointToCanopies(point, canopies);
+ }
+ for (Canopy canopy : canopies) {
+ canopy.computeParameters();
+ if (canopy.getNumObservations() > clusterFilter) {
+ ClusterWritable clusterWritable = new ClusterWritable();
+ clusterWritable.setValue(canopy);
+ context.write(new Text(canopy.getIdentifier()), clusterWritable);
+ }
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ super.setup(context);
+ canopyClusterer = CanopyConfigKeys.configureCanopyClusterer(context.getConfiguration());
+ canopyClusterer.useT3T4();
+ clusterFilter = Integer.parseInt(context.getConfiguration().get(
+ CanopyConfigKeys.CF_KEY));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
new file mode 100644
index 0000000..6b88388
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationConfigKeys.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+/**
+ * Constants used in Cluster Classification.
+ */
+public final class ClusterClassificationConfigKeys {
+
+ public static final String CLUSTERS_IN = "clusters_in";
+
+ public static final String OUTLIER_REMOVAL_THRESHOLD = "pdf_threshold";
+
+ public static final String EMIT_MOST_LIKELY = "emit_most_likely";
+
+ private ClusterClassificationConfigKeys() {
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
new file mode 100644
index 0000000..6e2c3cf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
@@ -0,0 +1,313 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.iterator.ClusteringPolicy;
+import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Classifies the vectors into different clusters found by the clustering
+ * algorithm.
+ */
+public final class ClusterClassificationDriver extends AbstractJob {
+
+ /**
+ * CLI to run Cluster Classification Driver.
+ */
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.methodOption().create());
+ addOption(DefaultOptionCreator.clustersInOption()
+ .withDescription("The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy.")
+ .create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+
+ if (getConf() == null) {
+ setConf(new Configuration());
+ }
+ Path clustersIn = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+ boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+ DefaultOptionCreator.SEQUENTIAL_METHOD);
+
+ double clusterClassificationThreshold = 0.0;
+ if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
+ clusterClassificationThreshold = Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+ }
+
+ run(getConf(), input, clustersIn, output, clusterClassificationThreshold, true, runSequential);
+
+ return 0;
+ }
+
+ /**
+ * Constructor to be used by the ToolRunner.
+ */
+ private ClusterClassificationDriver() {
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new ClusterClassificationDriver(), args);
+ }
+
+ /**
+ * Uses {@link ClusterClassifier} to classify input vectors into their
+ * respective clusters.
+ *
+ * @param input
+ * the input vectors
+ * @param clusteringOutputPath
+ * the output path of clustering ( it reads clusters-*-final file
+ * from here )
+ * @param output
+ * the location to store the classified vectors
+ * @param clusterClassificationThreshold
+ * the threshold value of probability distribution function from 0.0
+ * to 1.0. Any vector with pdf less that this threshold will not be
+ * classified for the cluster.
+ * @param runSequential
+ * Run the process sequentially or in a mapreduce way.
+ * @throws IOException
+ * @throws InterruptedException
+ * @throws ClassNotFoundException
+ */
+ public static void run(Configuration conf, Path input, Path clusteringOutputPath, Path output, Double clusterClassificationThreshold,
+ boolean emitMostLikely, boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException {
+ if (runSequential) {
+ classifyClusterSeq(conf, input, clusteringOutputPath, output, clusterClassificationThreshold, emitMostLikely);
+ } else {
+ classifyClusterMR(conf, input, clusteringOutputPath, output, clusterClassificationThreshold, emitMostLikely);
+ }
+
+ }
+
+ private static void classifyClusterSeq(Configuration conf, Path input, Path clusters, Path output,
+ Double clusterClassificationThreshold, boolean emitMostLikely) throws IOException {
+ List<Cluster> clusterModels = populateClusterModels(clusters, conf);
+ ClusteringPolicy policy = ClusterClassifier.readPolicy(finalClustersPath(conf, clusters));
+ ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels, policy);
+ selectCluster(input, clusterModels, clusterClassifier, output, clusterClassificationThreshold, emitMostLikely);
+
+ }
+
+ /**
+ * Populates a list with clusters present in clusters-*-final directory.
+ *
+ * @param clusterOutputPath
+ * The output path of the clustering.
+ * @param conf
+ * The Hadoop Configuration
+ * @return The list of clusters found by the clustering.
+ * @throws IOException
+ */
+ private static List<Cluster> populateClusterModels(Path clusterOutputPath, Configuration conf) throws IOException {
+ List<Cluster> clusterModels = Lists.newArrayList();
+ Path finalClustersPath = finalClustersPath(conf, clusterOutputPath);
+ Iterator<?> it = new SequenceFileDirValueIterator<Writable>(finalClustersPath, PathType.LIST,
+ PathFilters.partFilter(), null, false, conf);
+ while (it.hasNext()) {
+ ClusterWritable next = (ClusterWritable) it.next();
+ Cluster cluster = next.getValue();
+ cluster.configure(conf);
+ clusterModels.add(cluster);
+ }
+ return clusterModels;
+ }
+
+ private static Path finalClustersPath(Configuration conf, Path clusterOutputPath) throws IOException {
+ FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+ return clusterFiles[0].getPath();
+ }
+
+ /**
+ * Classifies the vector into its respective cluster.
+ *
+ * @param input
+ * the path containing the input vector.
+ * @param clusterModels
+ * the clusters
+ * @param clusterClassifier
+ * used to classify the vectors into different clusters
+ * @param output
+ * the path to store classified data
+ * @param clusterClassificationThreshold
+ * the threshold value of probability distribution function from 0.0
+ * to 1.0. Any vector with pdf less that this threshold will not be
+ * classified for the cluster
+ * @param emitMostLikely
+ * emit the vectors with the max pdf values per cluster
+ * @throws IOException
+ */
+ private static void selectCluster(Path input, List<Cluster> clusterModels, ClusterClassifier clusterClassifier,
+ Path output, Double clusterClassificationThreshold, boolean emitMostLikely) throws IOException {
+ Configuration conf = new Configuration();
+ SequenceFile.Writer writer = new SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(output,
+ "part-m-" + 0), IntWritable.class, WeightedPropertyVectorWritable.class);
+ for (Pair<Writable, VectorWritable> vw : new SequenceFileDirIterable<Writable, VectorWritable>(input, PathType.LIST,
+ PathFilters.logsCRCFilter(), conf)) {
+ // Converting to NamedVectors to preserve the vectorId else its not obvious as to which point
+ // belongs to which cluster - fix for MAHOUT-1410
+ Class<? extends Writable> keyClass = vw.getFirst().getClass();
+ Vector vector = vw.getSecond().get();
+ if (!keyClass.equals(NamedVector.class)) {
+ if (keyClass.equals(Text.class)) {
+ vector = new NamedVector(vector, vw.getFirst().toString());
+ } else if (keyClass.equals(IntWritable.class)) {
+ vector = new NamedVector(vector, Integer.toString(((IntWritable) vw.getFirst()).get()));
+ }
+ }
+ Vector pdfPerCluster = clusterClassifier.classify(vector);
+ if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
+ classifyAndWrite(clusterModels, clusterClassificationThreshold, emitMostLikely, writer, new VectorWritable(vector), pdfPerCluster);
+ }
+ }
+ writer.close();
+ }
+
+ private static void classifyAndWrite(List<Cluster> clusterModels, Double clusterClassificationThreshold,
+ boolean emitMostLikely, SequenceFile.Writer writer, VectorWritable vw, Vector pdfPerCluster) throws IOException {
+ Map<Text, Text> props = Maps.newHashMap();
+ if (emitMostLikely) {
+ int maxValueIndex = pdfPerCluster.maxValueIndex();
+ WeightedPropertyVectorWritable weightedPropertyVectorWritable =
+ new WeightedPropertyVectorWritable(pdfPerCluster.maxValue(), vw.get(), props);
+ write(clusterModels, writer, weightedPropertyVectorWritable, maxValueIndex);
+ } else {
+ writeAllAboveThreshold(clusterModels, clusterClassificationThreshold, writer, vw, pdfPerCluster);
+ }
+ }
+
+ private static void writeAllAboveThreshold(List<Cluster> clusterModels, Double clusterClassificationThreshold,
+ SequenceFile.Writer writer, VectorWritable vw, Vector pdfPerCluster) throws IOException {
+ Map<Text, Text> props = Maps.newHashMap();
+ for (Element pdf : pdfPerCluster.nonZeroes()) {
+ if (pdf.get() >= clusterClassificationThreshold) {
+ WeightedPropertyVectorWritable wvw = new WeightedPropertyVectorWritable(pdf.get(), vw.get(), props);
+ int clusterIndex = pdf.index();
+ write(clusterModels, writer, wvw, clusterIndex);
+ }
+ }
+ }
+
+ private static void write(List<Cluster> clusterModels, SequenceFile.Writer writer,
+ WeightedPropertyVectorWritable weightedPropertyVectorWritable,
+ int maxValueIndex) throws IOException {
+ Cluster cluster = clusterModels.get(maxValueIndex);
+
+ DistanceMeasureCluster distanceMeasureCluster = (DistanceMeasureCluster) cluster;
+ DistanceMeasure distanceMeasure = distanceMeasureCluster.getMeasure();
+ double distance = distanceMeasure.distance(cluster.getCenter(), weightedPropertyVectorWritable.getVector());
+
+ weightedPropertyVectorWritable.getProperties().put(new Text("distance"), new Text(Double.toString(distance)));
+ writer.append(new IntWritable(cluster.getId()), weightedPropertyVectorWritable);
+ }
+
+ /**
+ * Decides whether the vector should be classified or not based on the max pdf
+ * value of the clusters and threshold value.
+ *
+ * @return whether the vector should be classified or not.
+ */
+ private static boolean shouldClassify(Vector pdfPerCluster, Double clusterClassificationThreshold) {
+ return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+ }
+
+ private static void classifyClusterMR(Configuration conf, Path input, Path clustersIn, Path output,
+ Double clusterClassificationThreshold, boolean emitMostLikely) throws IOException, InterruptedException,
+ ClassNotFoundException {
+
+ conf.setFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD,
+ clusterClassificationThreshold.floatValue());
+ conf.setBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, emitMostLikely);
+ conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, clustersIn.toUri().toString());
+
+ Job job = new Job(conf, "Cluster Classification Driver running over input: " + input);
+ job.setJarByClass(ClusterClassificationDriver.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ job.setMapperClass(ClusterClassificationMapper.class);
+ job.setNumReduceTasks(0);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(WeightedPropertyVectorWritable.class);
+
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("Cluster Classification Driver Job failed processing " + input);
+ }
+ }
+
+ public static void run(Configuration conf, Path input, Path clusteringOutputPath, Path output,
+ double clusterClassificationThreshold, boolean emitMostLikely, boolean runSequential) throws IOException,
+ InterruptedException, ClassNotFoundException {
+ if (runSequential) {
+ classifyClusterSeq(conf, input, clusteringOutputPath, output, clusterClassificationThreshold, emitMostLikely);
+ } else {
+ classifyClusterMR(conf, input, clusteringOutputPath, output, clusterClassificationThreshold, emitMostLikely);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
new file mode 100644
index 0000000..9edbd8e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.iterator.ClusteringPolicy;
+import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Mapper for classifying vectors into clusters.
+ */
+public class ClusterClassificationMapper extends
+ Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable> {
+
+ private double threshold;
+ private List<Cluster> clusterModels;
+ private ClusterClassifier clusterClassifier;
+ private IntWritable clusterId;
+ private boolean emitMostLikely;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+ String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
+ threshold = conf.getFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD, 0.0f);
+ emitMostLikely = conf.getBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, false);
+
+ clusterModels = Lists.newArrayList();
+
+ if (clustersIn != null && !clustersIn.isEmpty()) {
+ Path clustersInPath = new Path(clustersIn);
+ clusterModels = populateClusterModels(clustersInPath, conf);
+ ClusteringPolicy policy = ClusterClassifier
+ .readPolicy(finalClustersPath(clustersInPath));
+ clusterClassifier = new ClusterClassifier(clusterModels, policy);
+ }
+ clusterId = new IntWritable();
+ }
+
+ /**
+ * Mapper which classifies the vectors to respective clusters.
+ */
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable vw, Context context)
+ throws IOException, InterruptedException {
+ if (!clusterModels.isEmpty()) {
+ // Converting to NamedVectors to preserve the vectorId else its not obvious as to which point
+ // belongs to which cluster - fix for MAHOUT-1410
+ Class<? extends Vector> vectorClass = vw.get().getClass();
+ Vector vector = vw.get();
+ if (!vectorClass.equals(NamedVector.class)) {
+ if (key.getClass().equals(Text.class)) {
+ vector = new NamedVector(vector, key.toString());
+ } else if (key.getClass().equals(IntWritable.class)) {
+ vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get()));
+ }
+ }
+ Vector pdfPerCluster = clusterClassifier.classify(vector);
+ if (shouldClassify(pdfPerCluster)) {
+ if (emitMostLikely) {
+ int maxValueIndex = pdfPerCluster.maxValueIndex();
+ write(new VectorWritable(vector), context, maxValueIndex, 1.0);
+ } else {
+ writeAllAboveThreshold(new VectorWritable(vector), context, pdfPerCluster);
+ }
+ }
+ }
+ }
+
+ private void writeAllAboveThreshold(VectorWritable vw, Context context,
+ Vector pdfPerCluster) throws IOException, InterruptedException {
+ for (Element pdf : pdfPerCluster.nonZeroes()) {
+ if (pdf.get() >= threshold) {
+ int clusterIndex = pdf.index();
+ write(vw, context, clusterIndex, pdf.get());
+ }
+ }
+ }
+
+ private void write(VectorWritable vw, Context context, int clusterIndex, double weight)
+ throws IOException, InterruptedException {
+ Cluster cluster = clusterModels.get(clusterIndex);
+ clusterId.set(cluster.getId());
+
+ DistanceMeasureCluster distanceMeasureCluster = (DistanceMeasureCluster) cluster;
+ DistanceMeasure distanceMeasure = distanceMeasureCluster.getMeasure();
+ double distance = distanceMeasure.distance(cluster.getCenter(), vw.get());
+
+ Map<Text, Text> props = Maps.newHashMap();
+ props.put(new Text("distance"), new Text(Double.toString(distance)));
+ context.write(clusterId, new WeightedPropertyVectorWritable(weight, vw.get(), props));
+ }
+
+ public static List<Cluster> populateClusterModels(Path clusterOutputPath, Configuration conf) throws IOException {
+ List<Cluster> clusters = Lists.newArrayList();
+ FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+ Iterator<?> it = new SequenceFileDirValueIterator<Writable>(
+ clusterFiles[0].getPath(), PathType.LIST, PathFilters.partFilter(),
+ null, false, conf);
+ while (it.hasNext()) {
+ ClusterWritable next = (ClusterWritable) it.next();
+ Cluster cluster = next.getValue();
+ cluster.configure(conf);
+ clusters.add(cluster);
+ }
+ return clusters;
+ }
+
+ private boolean shouldClassify(Vector pdfPerCluster) {
+ return pdfPerCluster.maxValue() >= threshold;
+ }
+
+ private static Path finalClustersPath(Path clusterOutputPath) throws IOException {
+ FileSystem fileSystem = clusterOutputPath.getFileSystem(new Configuration());
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+ return clusterFiles[0].getPath();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java
new file mode 100644
index 0000000..d5f8d64
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java
@@ -0,0 +1,240 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.classify;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.iterator.ClusteringPolicy;
+import org.apache.mahout.clustering.iterator.ClusteringPolicyWritable;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * This classifier works with any ClusteringPolicy and its associated Clusters.
+ * It is initialized with a policy and a list of compatible clusters and
+ * thereafter it can classify any new Vector into one or more of the clusters
+ * based upon the pdf() function which each cluster supports.
+ *
+ * In addition, it is an OnlineLearner and can be trained. Training amounts to
+ * asking the actual model to observe the vector and closing the classifier
+ * causes all the models to computeParameters.
+ *
+ * Because a ClusterClassifier implements Writable, it can be written-to and
+ * read-from a sequence file as a single entity. For sequential and MapReduce
+ * clustering in conjunction with a ClusterIterator; however, it utilizes an
+ * exploded file format. In this format, the iterator writes the policy to a
+ * single POLICY_FILE_NAME file in the clustersOut directory and the models are
+ * written to one or more part-n files so that multiple reducers may employed to
+ * produce them.
+ */
+public class ClusterClassifier extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+ private static final String POLICY_FILE_NAME = "_policy";
+
+ private List<Cluster> models;
+
+ private String modelClass;
+
+ private ClusteringPolicy policy;
+
+ /**
+ * The public constructor accepts a list of clusters to become the models
+ *
+ * @param models
+ * a List<Cluster>
+ * @param policy
+ * a ClusteringPolicy
+ */
+ public ClusterClassifier(List<Cluster> models, ClusteringPolicy policy) {
+ this.models = models;
+ modelClass = models.get(0).getClass().getName();
+ this.policy = policy;
+ }
+
+ // needed for serialization/De-serialization
+ public ClusterClassifier() {}
+
+ // only used by MR ClusterIterator
+ protected ClusterClassifier(ClusteringPolicy policy) {
+ this.policy = policy;
+ }
+
+ @Override
+ public Vector classify(Vector instance) {
+ return policy.classify(instance, this);
+ }
+
+ @Override
+ public double classifyScalar(Vector instance) {
+ if (models.size() == 2) {
+ double pdf0 = models.get(0).pdf(new VectorWritable(instance));
+ double pdf1 = models.get(1).pdf(new VectorWritable(instance));
+ return pdf0 / (pdf0 + pdf1);
+ }
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public int numCategories() {
+ return models.size();
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(models.size());
+ out.writeUTF(modelClass);
+ new ClusteringPolicyWritable(policy).write(out);
+ for (Cluster cluster : models) {
+ cluster.write(out);
+ }
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int size = in.readInt();
+ modelClass = in.readUTF();
+ models = Lists.newArrayList();
+ ClusteringPolicyWritable clusteringPolicyWritable = new ClusteringPolicyWritable();
+ clusteringPolicyWritable.readFields(in);
+ policy = clusteringPolicyWritable.getValue();
+ for (int i = 0; i < size; i++) {
+ Cluster element = ClassUtils.instantiateAs(modelClass, Cluster.class);
+ element.readFields(in);
+ models.add(element);
+ }
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ models.get(actual).observe(new VectorWritable(instance));
+ }
+
+ /**
+ * Train the models given an additional weight. Unique to ClusterClassifier
+ *
+ * @param actual
+ * the int index of a model
+ * @param data
+ * a data Vector
+ * @param weight
+ * a double weighting factor
+ */
+ public void train(int actual, Vector data, double weight) {
+ models.get(actual).observe(new VectorWritable(data), weight);
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ models.get(actual).observe(new VectorWritable(instance));
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ models.get(actual).observe(new VectorWritable(instance));
+ }
+
+ @Override
+ public void close() {
+ policy.close(this);
+ }
+
+ public List<Cluster> getModels() {
+ return models;
+ }
+
+ public ClusteringPolicy getPolicy() {
+ return policy;
+ }
+
+ public void writeToSeqFiles(Path path) throws IOException {
+ writePolicy(policy, path);
+ Configuration config = new Configuration();
+ FileSystem fs = FileSystem.get(path.toUri(), config);
+ SequenceFile.Writer writer = null;
+ ClusterWritable cw = new ClusterWritable();
+ for (int i = 0; i < models.size(); i++) {
+ try {
+ Cluster cluster = models.get(i);
+ cw.setValue(cluster);
+ writer = new SequenceFile.Writer(fs, config,
+ new Path(path, "part-" + String.format(Locale.ENGLISH, "%05d", i)), IntWritable.class,
+ ClusterWritable.class);
+ Writable key = new IntWritable(i);
+ writer.append(key, cw);
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+ }
+
+ public void readFromSeqFiles(Configuration conf, Path path) throws IOException {
+ Configuration config = new Configuration();
+ List<Cluster> clusters = Lists.newArrayList();
+ for (ClusterWritable cw : new SequenceFileDirValueIterable<ClusterWritable>(path, PathType.LIST,
+ PathFilters.logsCRCFilter(), config)) {
+ Cluster cluster = cw.getValue();
+ cluster.configure(conf);
+ clusters.add(cluster);
+ }
+ this.models = clusters;
+ modelClass = models.get(0).getClass().getName();
+ this.policy = readPolicy(path);
+ }
+
+ public static ClusteringPolicy readPolicy(Path path) throws IOException {
+ Path policyPath = new Path(path, POLICY_FILE_NAME);
+ Configuration config = new Configuration();
+ FileSystem fs = FileSystem.get(policyPath.toUri(), config);
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, policyPath, config);
+ Text key = new Text();
+ ClusteringPolicyWritable cpw = new ClusteringPolicyWritable();
+ reader.next(key, cpw);
+ Closeables.close(reader, true);
+ return cpw.getValue();
+ }
+
+ public static void writePolicy(ClusteringPolicy policy, Path path) throws IOException {
+ Path policyPath = new Path(path, POLICY_FILE_NAME);
+ Configuration config = new Configuration();
+ FileSystem fs = FileSystem.get(policyPath.toUri(), config);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, policyPath, Text.class,
+ ClusteringPolicyWritable.class);
+ writer.append(new Text(), new ClusteringPolicyWritable(policy));
+ Closeables.close(writer, false);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java b/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java
new file mode 100644
index 0000000..567659b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.math.Vector;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+public class WeightedPropertyVectorWritable extends WeightedVectorWritable {
+
+ private Map<Text, Text> properties;
+
+ public WeightedPropertyVectorWritable() {
+ }
+
+ public WeightedPropertyVectorWritable(Map<Text, Text> properties) {
+ this.properties = properties;
+ }
+
+ public WeightedPropertyVectorWritable(double weight, Vector vector, Map<Text, Text> properties) {
+ super(weight, vector);
+ this.properties = properties;
+ }
+
+ public Map<Text, Text> getProperties() {
+ return properties;
+ }
+
+ public void setProperties(Map<Text, Text> properties) {
+ this.properties = properties;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ super.readFields(in);
+ int size = in.readInt();
+ if (size > 0) {
+ properties = new HashMap<>();
+ for (int i = 0; i < size; i++) {
+ Text key = new Text(in.readUTF());
+ Text val = new Text(in.readUTF());
+ properties.put(key, val);
+ }
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ super.write(out);
+ out.writeInt(properties != null ? properties.size() : 0);
+ if (properties != null) {
+ for (Map.Entry<Text, Text> entry : properties.entrySet()) {
+ out.writeUTF(entry.getKey().toString());
+ out.writeUTF(entry.getValue().toString());
+ }
+ }
+ }
+
+ @Override
+ public String toString() {
+ Vector vector = getVector();
+ StringBuilder bldr = new StringBuilder("wt: ").append(getWeight()).append(' ');
+ if (properties != null && !properties.isEmpty()) {
+ for (Map.Entry<Text, Text> entry : properties.entrySet()) {
+ bldr.append(entry.getKey().toString()).append(": ").append(entry.getValue().toString()).append(' ');
+ }
+ }
+ bldr.append(" vec: ").append(vector == null ? "null" : AbstractCluster.formatVector(vector, null));
+ return bldr.toString();
+ }
+
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java b/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java
new file mode 100644
index 0000000..510dd39
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public class WeightedVectorWritable implements Writable {
+
+ private final VectorWritable vectorWritable = new VectorWritable();
+ private double weight;
+
+ public WeightedVectorWritable() {
+ }
+
+ public WeightedVectorWritable(double weight, Vector vector) {
+ this.vectorWritable.set(vector);
+ this.weight = weight;
+ }
+
+ public Vector getVector() {
+ return vectorWritable.get();
+ }
+
+ public void setVector(Vector vector) {
+ vectorWritable.set(vector);
+ }
+
+ public double getWeight() {
+ return weight;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ vectorWritable.readFields(in);
+ weight = in.readDouble();
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ vectorWritable.write(out);
+ out.writeDouble(weight);
+ }
+
+ @Override
+ public String toString() {
+ Vector vector = vectorWritable.get();
+ return weight + ": " + (vector == null ? "null" : AbstractCluster.formatVector(vector, null));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
new file mode 100644
index 0000000..ff02a4c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.fuzzykmeans;
+
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+public class FuzzyKMeansClusterer {
+
+ private static final double MINIMAL_VALUE = 0.0000000001;
+
+ private double m = 2.0; // default value
+
+ public Vector computePi(Collection<SoftCluster> clusters, List<Double> clusterDistanceList) {
+ Vector pi = new DenseVector(clusters.size());
+ for (int i = 0; i < clusters.size(); i++) {
+ double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList);
+ pi.set(i, probWeight);
+ }
+ return pi;
+ }
+
+ /** Computes the probability of a point belonging to a cluster */
+ public double computeProbWeight(double clusterDistance, Iterable<Double> clusterDistanceList) {
+ if (clusterDistance == 0) {
+ clusterDistance = MINIMAL_VALUE;
+ }
+ double denom = 0.0;
+ for (double eachCDist : clusterDistanceList) {
+ if (eachCDist == 0.0) {
+ eachCDist = MINIMAL_VALUE;
+ }
+ denom += Math.pow(clusterDistance / eachCDist, 2.0 / (m - 1));
+ }
+ return 1.0 / denom;
+ }
+
+ public void setM(double m) {
+ this.m = m;
+ }
+}
[28/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java
new file mode 100644
index 0000000..46fcc7f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class CVB0DocInferenceMapper extends CachingCVB0Mapper {
+
+ private final VectorWritable topics = new VectorWritable();
+
+ @Override
+ public void map(IntWritable docId, VectorWritable doc, Context context)
+ throws IOException, InterruptedException {
+ int numTopics = getNumTopics();
+ Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics);
+ Matrix docModel = new SparseRowMatrix(numTopics, doc.get().size());
+ int maxIters = getMaxIters();
+ ModelTrainer modelTrainer = getModelTrainer();
+ for (int i = 0; i < maxIters; i++) {
+ modelTrainer.getReadModel().trainDocTopicModel(doc.get(), docTopics, docModel);
+ }
+ topics.set(docTopics);
+ context.write(docId, topics);
+ }
+
+ @Override
+ protected void cleanup(Context context) {
+ getModelTrainer().stop();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java
new file mode 100644
index 0000000..3eee446
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java
@@ -0,0 +1,536 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.base.Joiner;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.mapreduce.VectorSumReducer;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.List;
+
+/**
+ * See {@link CachingCVB0Mapper} for more details on scalability and room for improvement.
+ * To try out this LDA implementation without using Hadoop, check out
+ * {@link InMemoryCollapsedVariationalBayes0}. If you want to do training directly in java code
+ * with your own main(), then look to {@link ModelTrainer} and {@link TopicModel}.
+ *
+ * Usage: {@code ./bin/mahout cvb <i>options</i>}
+ * <p>
+ * Valid options include:
+ * <dl>
+ * <dt>{@code --input path}</td>
+ * <dd>Input path for {@code SequenceFile<IntWritable, VectorWritable>} document vectors. See
+ * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles}
+ * for details on how to generate this input format.</dd>
+ * <dt>{@code --dictionary path}</dt>
+ * <dd>Path to dictionary file(s) generated during construction of input document vectors (glob
+ * expression supported). If set, this data is scanned to determine an appropriate value for option
+ * {@code --num_terms}.</dd>
+ * <dt>{@code --output path}</dt>
+ * <dd>Output path for topic-term distributions.</dd>
+ * <dt>{@code --doc_topic_output path}</dt>
+ * <dd>Output path for doc-topic distributions.</dd>
+ * <dt>{@code --num_topics k}</dt>
+ * <dd>Number of latent topics.</dd>
+ * <dt>{@code --num_terms nt}</dt>
+ * <dd>Number of unique features defined by input document vectors. If option {@code --dictionary}
+ * is defined and this option is unspecified, term count is calculated from dictionary.</dd>
+ * <dt>{@code --topic_model_temp_dir path}</dt>
+ * <dd>Path in which to store model state after each iteration.</dd>
+ * <dt>{@code --maxIter i}</dt>
+ * <dd>Maximum number of iterations to perform. If this value is less than or equal to the number of
+ * iteration states found beneath the path specified by option {@code --topic_model_temp_dir}, no
+ * further iterations are performed. Instead, output topic-term and doc-topic distributions are
+ * generated using data from the specified iteration.</dd>
+ * <dt>{@code --max_doc_topic_iters i}</dt>
+ * <dd>Maximum number of iterations per doc for p(topic|doc) learning. Defaults to {@code 10}.</dd>
+ * <dt>{@code --doc_topic_smoothing a}</dt>
+ * <dd>Smoothing for doc-topic distribution. Defaults to {@code 0.0001}.</dd>
+ * <dt>{@code --term_topic_smoothing e}</dt>
+ * <dd>Smoothing for topic-term distribution. Defaults to {@code 0.0001}.</dd>
+ * <dt>{@code --random_seed seed}</dt>
+ * <dd>Integer seed for random number generation.</dd>
+ * <dt>{@code --test_set_percentage p}</dt>
+ * <dd>Fraction of data to hold out for testing. Defaults to {@code 0.0}.</dd>
+ * <dt>{@code --iteration_block_size block}</dt>
+ * <dd>Number of iterations between perplexity checks. Defaults to {@code 10}. This option is
+ * ignored unless option {@code --test_set_percentage} is greater than zero.</dd>
+ * </dl>
+ */
+public class CVB0Driver extends AbstractJob {
+ private static final Logger log = LoggerFactory.getLogger(CVB0Driver.class);
+
+ public static final String NUM_TOPICS = "num_topics";
+ public static final String NUM_TERMS = "num_terms";
+ public static final String DOC_TOPIC_SMOOTHING = "doc_topic_smoothing";
+ public static final String TERM_TOPIC_SMOOTHING = "term_topic_smoothing";
+ public static final String DICTIONARY = "dictionary";
+ public static final String DOC_TOPIC_OUTPUT = "doc_topic_output";
+ public static final String MODEL_TEMP_DIR = "topic_model_temp_dir";
+ public static final String ITERATION_BLOCK_SIZE = "iteration_block_size";
+ public static final String RANDOM_SEED = "random_seed";
+ public static final String TEST_SET_FRACTION = "test_set_fraction";
+ public static final String NUM_TRAIN_THREADS = "num_train_threads";
+ public static final String NUM_UPDATE_THREADS = "num_update_threads";
+ public static final String MAX_ITERATIONS_PER_DOC = "max_doc_topic_iters";
+ public static final String MODEL_WEIGHT = "prev_iter_mult";
+ public static final String NUM_REDUCE_TASKS = "num_reduce_tasks";
+ public static final String BACKFILL_PERPLEXITY = "backfill_perplexity";
+ private static final String MODEL_PATHS = "mahout.lda.cvb.modelPath";
+
+ private static final double DEFAULT_CONVERGENCE_DELTA = 0;
+ private static final double DEFAULT_DOC_TOPIC_SMOOTHING = 0.0001;
+ private static final double DEFAULT_TERM_TOPIC_SMOOTHING = 0.0001;
+ private static final int DEFAULT_ITERATION_BLOCK_SIZE = 10;
+ private static final double DEFAULT_TEST_SET_FRACTION = 0;
+ private static final int DEFAULT_NUM_TRAIN_THREADS = 4;
+ private static final int DEFAULT_NUM_UPDATE_THREADS = 1;
+ private static final int DEFAULT_MAX_ITERATIONS_PER_DOC = 10;
+ private static final int DEFAULT_NUM_REDUCE_TASKS = 10;
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.maxIterationsOption().create());
+ addOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION, "cd", "The convergence delta value",
+ String.valueOf(DEFAULT_CONVERGENCE_DELTA));
+ addOption(DefaultOptionCreator.overwriteOption().create());
+
+ addOption(NUM_TOPICS, "k", "Number of topics to learn", true);
+ addOption(NUM_TERMS, "nt", "Vocabulary size", false);
+ addOption(DOC_TOPIC_SMOOTHING, "a", "Smoothing for document/topic distribution",
+ String.valueOf(DEFAULT_DOC_TOPIC_SMOOTHING));
+ addOption(TERM_TOPIC_SMOOTHING, "e", "Smoothing for topic/term distribution",
+ String.valueOf(DEFAULT_TERM_TOPIC_SMOOTHING));
+ addOption(DICTIONARY, "dict", "Path to term-dictionary file(s) (glob expression supported)", false);
+ addOption(DOC_TOPIC_OUTPUT, "dt", "Output path for the training doc/topic distribution", false);
+ addOption(MODEL_TEMP_DIR, "mt", "Path to intermediate model path (useful for restarting)", false);
+ addOption(ITERATION_BLOCK_SIZE, "block", "Number of iterations per perplexity check",
+ String.valueOf(DEFAULT_ITERATION_BLOCK_SIZE));
+ addOption(RANDOM_SEED, "seed", "Random seed", false);
+ addOption(TEST_SET_FRACTION, "tf", "Fraction of data to hold out for testing",
+ String.valueOf(DEFAULT_TEST_SET_FRACTION));
+ addOption(NUM_TRAIN_THREADS, "ntt", "number of threads per mapper to train with",
+ String.valueOf(DEFAULT_NUM_TRAIN_THREADS));
+ addOption(NUM_UPDATE_THREADS, "nut", "number of threads per mapper to update the model with",
+ String.valueOf(DEFAULT_NUM_UPDATE_THREADS));
+ addOption(MAX_ITERATIONS_PER_DOC, "mipd", "max number of iterations per doc for p(topic|doc) learning",
+ String.valueOf(DEFAULT_MAX_ITERATIONS_PER_DOC));
+ addOption(NUM_REDUCE_TASKS, null, "number of reducers to use during model estimation",
+ String.valueOf(DEFAULT_NUM_REDUCE_TASKS));
+ addOption(buildOption(BACKFILL_PERPLEXITY, null, "enable backfilling of missing perplexity values", false, false,
+ null));
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ int numTopics = Integer.parseInt(getOption(NUM_TOPICS));
+ Path inputPath = getInputPath();
+ Path topicModelOutputPath = getOutputPath();
+ int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
+ int iterationBlockSize = Integer.parseInt(getOption(ITERATION_BLOCK_SIZE));
+ double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
+ double alpha = Double.parseDouble(getOption(DOC_TOPIC_SMOOTHING));
+ double eta = Double.parseDouble(getOption(TERM_TOPIC_SMOOTHING));
+ int numTrainThreads = Integer.parseInt(getOption(NUM_TRAIN_THREADS));
+ int numUpdateThreads = Integer.parseInt(getOption(NUM_UPDATE_THREADS));
+ int maxItersPerDoc = Integer.parseInt(getOption(MAX_ITERATIONS_PER_DOC));
+ Path dictionaryPath = hasOption(DICTIONARY) ? new Path(getOption(DICTIONARY)) : null;
+ int numTerms = hasOption(NUM_TERMS)
+ ? Integer.parseInt(getOption(NUM_TERMS))
+ : getNumTerms(getConf(), dictionaryPath);
+ Path docTopicOutputPath = hasOption(DOC_TOPIC_OUTPUT) ? new Path(getOption(DOC_TOPIC_OUTPUT)) : null;
+ Path modelTempPath = hasOption(MODEL_TEMP_DIR)
+ ? new Path(getOption(MODEL_TEMP_DIR))
+ : getTempPath("topicModelState");
+ long seed = hasOption(RANDOM_SEED)
+ ? Long.parseLong(getOption(RANDOM_SEED))
+ : System.nanoTime() % 10000;
+ float testFraction = hasOption(TEST_SET_FRACTION)
+ ? Float.parseFloat(getOption(TEST_SET_FRACTION))
+ : 0.0f;
+ int numReduceTasks = Integer.parseInt(getOption(NUM_REDUCE_TASKS));
+ boolean backfillPerplexity = hasOption(BACKFILL_PERPLEXITY);
+
+ return run(getConf(), inputPath, topicModelOutputPath, numTopics, numTerms, alpha, eta,
+ maxIterations, iterationBlockSize, convergenceDelta, dictionaryPath, docTopicOutputPath,
+ modelTempPath, seed, testFraction, numTrainThreads, numUpdateThreads, maxItersPerDoc,
+ numReduceTasks, backfillPerplexity);
+ }
+
+ private static int getNumTerms(Configuration conf, Path dictionaryPath) throws IOException {
+ FileSystem fs = dictionaryPath.getFileSystem(conf);
+ Text key = new Text();
+ IntWritable value = new IntWritable();
+ int maxTermId = -1;
+ for (FileStatus stat : fs.globStatus(dictionaryPath)) {
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, stat.getPath(), conf);
+ while (reader.next(key, value)) {
+ maxTermId = Math.max(maxTermId, value.get());
+ }
+ }
+ return maxTermId + 1;
+ }
+
+ public int run(Configuration conf,
+ Path inputPath,
+ Path topicModelOutputPath,
+ int numTopics,
+ int numTerms,
+ double alpha,
+ double eta,
+ int maxIterations,
+ int iterationBlockSize,
+ double convergenceDelta,
+ Path dictionaryPath,
+ Path docTopicOutputPath,
+ Path topicModelStateTempPath,
+ long randomSeed,
+ float testFraction,
+ int numTrainThreads,
+ int numUpdateThreads,
+ int maxItersPerDoc,
+ int numReduceTasks,
+ boolean backfillPerplexity)
+ throws ClassNotFoundException, IOException, InterruptedException {
+
+ setConf(conf);
+
+ // verify arguments
+ Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0,
+ "Expected 'testFraction' value in range [0, 1] but found value '%s'", testFraction);
+ Preconditions.checkArgument(!backfillPerplexity || testFraction > 0.0,
+ "Expected 'testFraction' value in range (0, 1] but found value '%s'", testFraction);
+
+ String infoString = "Will run Collapsed Variational Bayes (0th-derivative approximation) "
+ + "learning for LDA on {} (numTerms: {}), finding {}-topics, with document/topic prior {}, "
+ + "topic/term prior {}. Maximum iterations to run will be {}, unless the change in "
+ + "perplexity is less than {}. Topic model output (p(term|topic) for each topic) will be "
+ + "stored {}. Random initialization seed is {}, holding out {} of the data for perplexity "
+ + "check\n";
+ log.info(infoString, inputPath, numTerms, numTopics, alpha, eta, maxIterations,
+ convergenceDelta, topicModelOutputPath, randomSeed, testFraction);
+ infoString = dictionaryPath == null
+ ? "" : "Dictionary to be used located " + dictionaryPath.toString() + '\n';
+ infoString += docTopicOutputPath == null
+ ? "" : "p(topic|docId) will be stored " + docTopicOutputPath.toString() + '\n';
+ log.info(infoString);
+
+ FileSystem fs = FileSystem.get(topicModelStateTempPath.toUri(), conf);
+ int iterationNumber = getCurrentIterationNumber(conf, topicModelStateTempPath, maxIterations);
+ log.info("Current iteration number: {}", iterationNumber);
+
+ conf.set(NUM_TOPICS, String.valueOf(numTopics));
+ conf.set(NUM_TERMS, String.valueOf(numTerms));
+ conf.set(DOC_TOPIC_SMOOTHING, String.valueOf(alpha));
+ conf.set(TERM_TOPIC_SMOOTHING, String.valueOf(eta));
+ conf.set(RANDOM_SEED, String.valueOf(randomSeed));
+ conf.set(NUM_TRAIN_THREADS, String.valueOf(numTrainThreads));
+ conf.set(NUM_UPDATE_THREADS, String.valueOf(numUpdateThreads));
+ conf.set(MAX_ITERATIONS_PER_DOC, String.valueOf(maxItersPerDoc));
+ conf.set(MODEL_WEIGHT, "1"); // TODO
+ conf.set(TEST_SET_FRACTION, String.valueOf(testFraction));
+
+ List<Double> perplexities = Lists.newArrayList();
+ for (int i = 1; i <= iterationNumber; i++) {
+ // form path to model
+ Path modelPath = modelPath(topicModelStateTempPath, i);
+
+ // read perplexity
+ double perplexity = readPerplexity(conf, topicModelStateTempPath, i);
+ if (Double.isNaN(perplexity)) {
+ if (!(backfillPerplexity && i % iterationBlockSize == 0)) {
+ continue;
+ }
+ log.info("Backfilling perplexity at iteration {}", i);
+ if (!fs.exists(modelPath)) {
+ log.error("Model path '{}' does not exist; Skipping iteration {} perplexity calculation",
+ modelPath.toString(), i);
+ continue;
+ }
+ perplexity = calculatePerplexity(conf, inputPath, modelPath, i);
+ }
+
+ // register and log perplexity
+ perplexities.add(perplexity);
+ log.info("Perplexity at iteration {} = {}", i, perplexity);
+ }
+
+ long startTime = System.currentTimeMillis();
+ while (iterationNumber < maxIterations) {
+ // test convergence
+ if (convergenceDelta > 0.0) {
+ double delta = rateOfChange(perplexities);
+ if (delta < convergenceDelta) {
+ log.info("Convergence achieved at iteration {} with perplexity {} and delta {}",
+ iterationNumber, perplexities.get(perplexities.size() - 1), delta);
+ break;
+ }
+ }
+
+ // update model
+ iterationNumber++;
+ log.info("About to run iteration {} of {}", iterationNumber, maxIterations);
+ Path modelInputPath = modelPath(topicModelStateTempPath, iterationNumber - 1);
+ Path modelOutputPath = modelPath(topicModelStateTempPath, iterationNumber);
+ runIteration(conf, inputPath, modelInputPath, modelOutputPath, iterationNumber,
+ maxIterations, numReduceTasks);
+
+ // calculate perplexity
+ if (testFraction > 0 && iterationNumber % iterationBlockSize == 0) {
+ perplexities.add(calculatePerplexity(conf, inputPath, modelOutputPath, iterationNumber));
+ log.info("Current perplexity = {}", perplexities.get(perplexities.size() - 1));
+ log.info("(p_{} - p_{}) / p_0 = {}; target = {}", iterationNumber, iterationNumber - iterationBlockSize,
+ rateOfChange(perplexities), convergenceDelta);
+ }
+ }
+ log.info("Completed {} iterations in {} seconds", iterationNumber,
+ (System.currentTimeMillis() - startTime) / 1000);
+ log.info("Perplexities: ({})", Joiner.on(", ").join(perplexities));
+
+ // write final topic-term and doc-topic distributions
+ Path finalIterationData = modelPath(topicModelStateTempPath, iterationNumber);
+ Job topicModelOutputJob = topicModelOutputPath != null
+ ? writeTopicModel(conf, finalIterationData, topicModelOutputPath)
+ : null;
+ Job docInferenceJob = docTopicOutputPath != null
+ ? writeDocTopicInference(conf, inputPath, finalIterationData, docTopicOutputPath)
+ : null;
+ if (topicModelOutputJob != null && !topicModelOutputJob.waitForCompletion(true)) {
+ return -1;
+ }
+ if (docInferenceJob != null && !docInferenceJob.waitForCompletion(true)) {
+ return -1;
+ }
+ return 0;
+ }
+
+ private static double rateOfChange(List<Double> perplexities) {
+ int sz = perplexities.size();
+ if (sz < 2) {
+ return Double.MAX_VALUE;
+ }
+ return Math.abs(perplexities.get(sz - 1) - perplexities.get(sz - 2)) / perplexities.get(0);
+ }
+
+ private double calculatePerplexity(Configuration conf, Path corpusPath, Path modelPath, int iteration)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ String jobName = "Calculating perplexity for " + modelPath;
+ log.info("About to run: {}", jobName);
+
+ Path outputPath = perplexityPath(modelPath.getParent(), iteration);
+ Job job = prepareJob(corpusPath, outputPath, CachingCVB0PerplexityMapper.class, DoubleWritable.class,
+ DoubleWritable.class, DualDoubleSumReducer.class, DoubleWritable.class, DoubleWritable.class);
+
+ job.setJobName(jobName);
+ job.setCombinerClass(DualDoubleSumReducer.class);
+ job.setNumReduceTasks(1);
+ setModelPaths(job, modelPath);
+ HadoopUtil.delete(conf, outputPath);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("Failed to calculate perplexity for: " + modelPath);
+ }
+ return readPerplexity(conf, modelPath.getParent(), iteration);
+ }
+
+ /**
+ * Sums keys and values independently.
+ */
+ public static class DualDoubleSumReducer extends
+ Reducer<DoubleWritable, DoubleWritable, DoubleWritable, DoubleWritable> {
+ private final DoubleWritable outKey = new DoubleWritable();
+ private final DoubleWritable outValue = new DoubleWritable();
+
+ @Override
+ public void run(Context context) throws IOException,
+ InterruptedException {
+ double keySum = 0.0;
+ double valueSum = 0.0;
+ while (context.nextKey()) {
+ keySum += context.getCurrentKey().get();
+ for (DoubleWritable value : context.getValues()) {
+ valueSum += value.get();
+ }
+ }
+ outKey.set(keySum);
+ outValue.set(valueSum);
+ context.write(outKey, outValue);
+ }
+ }
+
+ /**
+ * @param topicModelStateTemp
+ * @param iteration
+ * @return {@code double[2]} where first value is perplexity and second is model weight of those
+ * documents sampled during perplexity computation, or {@code null} if no perplexity data
+ * exists for the given iteration.
+ * @throws IOException
+ */
+ public static double readPerplexity(Configuration conf, Path topicModelStateTemp, int iteration)
+ throws IOException {
+ Path perplexityPath = perplexityPath(topicModelStateTemp, iteration);
+ FileSystem fs = FileSystem.get(perplexityPath.toUri(), conf);
+ if (!fs.exists(perplexityPath)) {
+ log.warn("Perplexity path {} does not exist, returning NaN", perplexityPath);
+ return Double.NaN;
+ }
+ double perplexity = 0;
+ double modelWeight = 0;
+ long n = 0;
+ for (Pair<DoubleWritable, DoubleWritable> pair : new SequenceFileDirIterable<DoubleWritable, DoubleWritable>(
+ perplexityPath, PathType.LIST, PathFilters.partFilter(), null, true, conf)) {
+ modelWeight += pair.getFirst().get();
+ perplexity += pair.getSecond().get();
+ n++;
+ }
+ log.info("Read {} entries with total perplexity {} and model weight {}", n,
+ perplexity, modelWeight);
+ return perplexity / modelWeight;
+ }
+
+ private Job writeTopicModel(Configuration conf, Path modelInput, Path output)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ String jobName = String.format("Writing final topic/term distributions from %s to %s", modelInput, output);
+ log.info("About to run: {}", jobName);
+
+ Job job = prepareJob(modelInput, output, SequenceFileInputFormat.class, CVB0TopicTermVectorNormalizerMapper.class,
+ IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName);
+ job.submit();
+ return job;
+ }
+
+ private Job writeDocTopicInference(Configuration conf, Path corpus, Path modelInput, Path output)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ String jobName = String.format("Writing final document/topic inference from %s to %s", corpus, output);
+ log.info("About to run: {}", jobName);
+
+ Job job = prepareJob(corpus, output, SequenceFileInputFormat.class, CVB0DocInferenceMapper.class,
+ IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName);
+
+ FileSystem fs = FileSystem.get(corpus.toUri(), conf);
+ if (modelInput != null && fs.exists(modelInput)) {
+ FileStatus[] statuses = fs.listStatus(modelInput, PathFilters.partFilter());
+ URI[] modelUris = new URI[statuses.length];
+ for (int i = 0; i < statuses.length; i++) {
+ modelUris[i] = statuses[i].getPath().toUri();
+ }
+ DistributedCache.setCacheFiles(modelUris, conf);
+ setModelPaths(job, modelInput);
+ }
+ job.submit();
+ return job;
+ }
+
+ public static Path modelPath(Path topicModelStateTempPath, int iterationNumber) {
+ return new Path(topicModelStateTempPath, "model-" + iterationNumber);
+ }
+
+ public static Path perplexityPath(Path topicModelStateTempPath, int iterationNumber) {
+ return new Path(topicModelStateTempPath, "perplexity-" + iterationNumber);
+ }
+
+ private static int getCurrentIterationNumber(Configuration config, Path modelTempDir, int maxIterations)
+ throws IOException {
+ FileSystem fs = FileSystem.get(modelTempDir.toUri(), config);
+ int iterationNumber = 1;
+ Path iterationPath = modelPath(modelTempDir, iterationNumber);
+ while (fs.exists(iterationPath) && iterationNumber <= maxIterations) {
+ log.info("Found previous state: {}", iterationPath);
+ iterationNumber++;
+ iterationPath = modelPath(modelTempDir, iterationNumber);
+ }
+ return iterationNumber - 1;
+ }
+
+ public void runIteration(Configuration conf, Path corpusInput, Path modelInput, Path modelOutput,
+ int iterationNumber, int maxIterations, int numReduceTasks)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ String jobName = String.format("Iteration %d of %d, input path: %s",
+ iterationNumber, maxIterations, modelInput);
+ log.info("About to run: {}", jobName);
+ Job job = prepareJob(corpusInput, modelOutput, CachingCVB0Mapper.class, IntWritable.class, VectorWritable.class,
+ VectorSumReducer.class, IntWritable.class, VectorWritable.class);
+ job.setCombinerClass(VectorSumReducer.class);
+ job.setNumReduceTasks(numReduceTasks);
+ job.setJobName(jobName);
+ setModelPaths(job, modelInput);
+ HadoopUtil.delete(conf, modelOutput);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException(String.format("Failed to complete iteration %d stage 1",
+ iterationNumber));
+ }
+ }
+
+ private static void setModelPaths(Job job, Path modelPath) throws IOException {
+ Configuration conf = job.getConfiguration();
+ if (modelPath == null || !FileSystem.get(modelPath.toUri(), conf).exists(modelPath)) {
+ return;
+ }
+ FileStatus[] statuses = FileSystem.get(modelPath.toUri(), conf).listStatus(modelPath, PathFilters.partFilter());
+ Preconditions.checkState(statuses.length > 0, "No part files found in model path '%s'", modelPath.toString());
+ String[] modelPaths = new String[statuses.length];
+ for (int i = 0; i < statuses.length; i++) {
+ modelPaths[i] = statuses[i].getPath().toUri().toString();
+ }
+ conf.setStrings(MODEL_PATHS, modelPaths);
+ }
+
+ public static Path[] getModelPaths(Configuration conf) {
+ String[] modelPathNames = conf.getStrings(MODEL_PATHS);
+ if (modelPathNames == null || modelPathNames.length == 0) {
+ return null;
+ }
+ Path[] modelPaths = new Path[modelPathNames.length];
+ for (int i = 0; i < modelPathNames.length; i++) {
+ modelPaths[i] = new Path(modelPathNames[i]);
+ }
+ return modelPaths;
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new CVB0Driver(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java
new file mode 100644
index 0000000..1253942
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+
+/**
+ * Performs L1 normalization of input vectors.
+ */
+public class CVB0TopicTermVectorNormalizerMapper extends
+ Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ @Override
+ protected void map(IntWritable key, VectorWritable value, Context context) throws IOException,
+ InterruptedException {
+ value.get().assign(Functions.div(value.get().norm(1.0)));
+ context.write(key, value);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java
new file mode 100644
index 0000000..96f36d4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java
@@ -0,0 +1,133 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+
+/**
+ * Run ensemble learning via loading the {@link ModelTrainer} with two {@link TopicModel} instances:
+ * one from the previous iteration, the other empty. Inference is done on the first, and the
+ * learning updates are stored in the second, and only emitted at cleanup().
+ * <p/>
+ * In terms of obvious performance improvements still available, the memory footprint in this
+ * Mapper could be dropped by half if we accumulated model updates onto the model we're using
+ * for inference, which might also speed up convergence, as we'd be able to take advantage of
+ * learning <em>during</em> iteration, not just after each one is done. Most likely we don't
+ * really need to accumulate double values in the model either, floats would most likely be
+ * sufficient. Between these two, we could squeeze another factor of 4 in memory efficiency.
+ * <p/>
+ * In terms of CPU, we're re-learning the p(topic|doc) distribution on every iteration, starting
+ * from scratch. This is usually only 10 fixed-point iterations per doc, but that's 10x more than
+ * only 1. To avoid having to do this, we would need to do a map-side join of the unchanging
+ * corpus with the continually-improving p(topic|doc) matrix, and then emit multiple outputs
+ * from the mappers to make sure we can do the reduce model averaging as well. Tricky, but
+ * possibly worth it.
+ * <p/>
+ * {@link ModelTrainer} already takes advantage (in maybe the not-nice way) of multi-core
+ * availability by doing multithreaded learning, see that class for details.
+ */
+public class CachingCVB0Mapper
+ extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(CachingCVB0Mapper.class);
+
+ private ModelTrainer modelTrainer;
+ private TopicModel readModel;
+ private TopicModel writeModel;
+ private int maxIters;
+ private int numTopics;
+
+ protected ModelTrainer getModelTrainer() {
+ return modelTrainer;
+ }
+
+ protected int getMaxIters() {
+ return maxIters;
+ }
+
+ protected int getNumTopics() {
+ return numTopics;
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ log.info("Retrieving configuration");
+ Configuration conf = context.getConfiguration();
+ float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
+ float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
+ long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
+ numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
+ int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
+ int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
+ int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
+ maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
+ float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
+
+ log.info("Initializing read model");
+ Path[] modelPaths = CVB0Driver.getModelPaths(conf);
+ if (modelPaths != null && modelPaths.length > 0) {
+ readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
+ } else {
+ log.info("No model files found");
+ readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
+ numTrainThreads, modelWeight);
+ }
+
+ log.info("Initializing write model");
+ writeModel = modelWeight == 1
+ ? new TopicModel(numTopics, numTerms, eta, alpha, null, numUpdateThreads)
+ : readModel;
+
+ log.info("Initializing model trainer");
+ modelTrainer = new ModelTrainer(readModel, writeModel, numTrainThreads, numTopics, numTerms);
+ modelTrainer.start();
+ }
+
+ @Override
+ public void map(IntWritable docId, VectorWritable document, Context context)
+ throws IOException, InterruptedException {
+ /* where to get docTopics? */
+ Vector topicVector = new DenseVector(numTopics).assign(1.0 / numTopics);
+ modelTrainer.train(document.get(), topicVector, true, maxIters);
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException, InterruptedException {
+ log.info("Stopping model trainer");
+ modelTrainer.stop();
+
+ log.info("Writing model");
+ TopicModel readFrom = modelTrainer.getReadModel();
+ for (MatrixSlice topic : readFrom) {
+ context.write(new IntWritable(topic.index()), new VectorWritable(topic.vector()));
+ }
+ readModel.stop();
+ writeModel.stop();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java
new file mode 100644
index 0000000..da77baf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MemoryUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Random;
+
+public class CachingCVB0PerplexityMapper extends
+ Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable> {
+ /**
+ * Hadoop counters for {@link CachingCVB0PerplexityMapper}, to aid in debugging.
+ */
+ public enum Counters {
+ SAMPLED_DOCUMENTS
+ }
+
+ private static final Logger log = LoggerFactory.getLogger(CachingCVB0PerplexityMapper.class);
+
+ private ModelTrainer modelTrainer;
+ private TopicModel readModel;
+ private int maxIters;
+ private int numTopics;
+ private float testFraction;
+ private Random random;
+ private Vector topicVector;
+ private final DoubleWritable outKey = new DoubleWritable();
+ private final DoubleWritable outValue = new DoubleWritable();
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ MemoryUtil.startMemoryLogger(5000);
+
+ log.info("Retrieving configuration");
+ Configuration conf = context.getConfiguration();
+ float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN);
+ float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN);
+ long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L);
+ random = RandomUtils.getRandom(seed);
+ numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1);
+ int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1);
+ int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1);
+ int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4);
+ maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10);
+ float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f);
+ testFraction = conf.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f);
+
+ log.info("Initializing read model");
+ Path[] modelPaths = CVB0Driver.getModelPaths(conf);
+ if (modelPaths != null && modelPaths.length > 0) {
+ readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths);
+ } else {
+ log.info("No model files found");
+ readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null,
+ numTrainThreads, modelWeight);
+ }
+
+ log.info("Initializing model trainer");
+ modelTrainer = new ModelTrainer(readModel, null, numTrainThreads, numTopics, numTerms);
+
+ log.info("Initializing topic vector");
+ topicVector = new DenseVector(new double[numTopics]);
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException, InterruptedException {
+ readModel.stop();
+ MemoryUtil.stopMemoryLogger();
+ }
+
+ @Override
+ public void map(IntWritable docId, VectorWritable document, Context context)
+ throws IOException, InterruptedException {
+ if (testFraction < 1.0f && random.nextFloat() >= testFraction) {
+ return;
+ }
+ context.getCounter(Counters.SAMPLED_DOCUMENTS).increment(1);
+ outKey.set(document.get().norm(1));
+ outValue.set(modelTrainer.calculatePerplexity(document.get(), topicVector.assign(1.0 / numTopics), maxIters));
+ context.write(outKey, outValue);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
new file mode 100644
index 0000000..07ae100
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java
@@ -0,0 +1,515 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.DistributedRowMatrixWriter;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.NamedVector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Runs the same algorithm as {@link CVB0Driver}, but sequentially, in memory. Memory requirements
+ * are currently: the entire corpus is read into RAM, two copies of the model (each of size
+ * numTerms * numTopics), and another matrix of size numDocs * numTopics is held in memory
+ * (to store p(topic|doc) for all docs).
+ *
+ * But if all this fits in memory, this can be significantly faster than an iterative MR job.
+ */
+public class InMemoryCollapsedVariationalBayes0 extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(InMemoryCollapsedVariationalBayes0.class);
+
+ private int numTopics;
+ private int numTerms;
+ private int numDocuments;
+ private double alpha;
+ private double eta;
+ //private int minDfCt;
+ //private double maxDfPct;
+ private boolean verbose = false;
+ private String[] terms; // of length numTerms;
+ private Matrix corpusWeights; // length numDocs;
+ private double totalCorpusWeight;
+ private double initialModelCorpusFraction;
+ private Matrix docTopicCounts;
+ private int numTrainingThreads;
+ private int numUpdatingThreads;
+ private ModelTrainer modelTrainer;
+
+ private InMemoryCollapsedVariationalBayes0() {
+ // only for main usage
+ }
+
+ public void setVerbose(boolean verbose) {
+ this.verbose = verbose;
+ }
+
+ public InMemoryCollapsedVariationalBayes0(Matrix corpus,
+ String[] terms,
+ int numTopics,
+ double alpha,
+ double eta,
+ int numTrainingThreads,
+ int numUpdatingThreads,
+ double modelCorpusFraction) {
+ //this.seed = seed;
+ this.numTopics = numTopics;
+ this.alpha = alpha;
+ this.eta = eta;
+ //this.minDfCt = 0;
+ //this.maxDfPct = 1.0f;
+ corpusWeights = corpus;
+ numDocuments = corpus.numRows();
+ this.terms = terms;
+ this.initialModelCorpusFraction = modelCorpusFraction;
+ numTerms = terms != null ? terms.length : corpus.numCols();
+ Map<String, Integer> termIdMap = Maps.newHashMap();
+ if (terms != null) {
+ for (int t = 0; t < terms.length; t++) {
+ termIdMap.put(terms[t], t);
+ }
+ }
+ this.numTrainingThreads = numTrainingThreads;
+ this.numUpdatingThreads = numUpdatingThreads;
+ postInitCorpus();
+ initializeModel();
+ }
+
+ private void postInitCorpus() {
+ totalCorpusWeight = 0;
+ int numNonZero = 0;
+ for (int i = 0; i < numDocuments; i++) {
+ Vector v = corpusWeights.viewRow(i);
+ double norm;
+ if (v != null && (norm = v.norm(1)) != 0) {
+ numNonZero += v.getNumNondefaultElements();
+ totalCorpusWeight += norm;
+ }
+ }
+ String s = "Initializing corpus with %d docs, %d terms, %d nonzero entries, total termWeight %f";
+ log.info(String.format(s, numDocuments, numTerms, numNonZero, totalCorpusWeight));
+ }
+
+ private void initializeModel() {
+ TopicModel topicModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(), terms,
+ numUpdatingThreads, initialModelCorpusFraction == 0 ? 1 : initialModelCorpusFraction * totalCorpusWeight);
+ topicModel.setConf(getConf());
+
+ TopicModel updatedModel = initialModelCorpusFraction == 0
+ ? new TopicModel(numTopics, numTerms, eta, alpha, null, terms, numUpdatingThreads, 1)
+ : topicModel;
+ updatedModel.setConf(getConf());
+ docTopicCounts = new DenseMatrix(numDocuments, numTopics);
+ docTopicCounts.assign(1.0 / numTopics);
+ modelTrainer = new ModelTrainer(topicModel, updatedModel, numTrainingThreads, numTopics, numTerms);
+ }
+
+ /*
+ private void inferDocuments(double convergence, int maxIter, boolean recalculate) {
+ for (int docId = 0; docId < corpusWeights.numRows() ; docId++) {
+ Vector inferredDocument = topicModel.infer(corpusWeights.viewRow(docId),
+ docTopicCounts.viewRow(docId));
+ // do what now?
+ }
+ }
+ */
+
+ public void trainDocuments() {
+ trainDocuments(0);
+ }
+
+ public void trainDocuments(double testFraction) {
+ long start = System.nanoTime();
+ modelTrainer.start();
+ for (int docId = 0; docId < corpusWeights.numRows(); docId++) {
+ if (testFraction == 0 || docId % (1 / testFraction) != 0) {
+ Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); // docTopicCounts.getRow(docId)
+ modelTrainer.trainSync(corpusWeights.viewRow(docId), docTopics , true, 10);
+ }
+ }
+ modelTrainer.stop();
+ logTime("train documents", System.nanoTime() - start);
+ }
+
+ /*
+ private double error(int docId) {
+ Vector docTermCounts = corpusWeights.viewRow(docId);
+ if (docTermCounts == null) {
+ return 0;
+ } else {
+ Vector expectedDocTermCounts =
+ topicModel.infer(corpusWeights.viewRow(docId), docTopicCounts.viewRow(docId));
+ double expectedNorm = expectedDocTermCounts.norm(1);
+ return expectedDocTermCounts.times(docTermCounts.norm(1)/expectedNorm)
+ .minus(docTermCounts).norm(1);
+ }
+ }
+
+ private double error() {
+ long time = System.nanoTime();
+ double error = 0;
+ for (int docId = 0; docId < numDocuments; docId++) {
+ error += error(docId);
+ }
+ logTime("error calculation", System.nanoTime() - time);
+ return error / totalCorpusWeight;
+ }
+ */
+
+ public double iterateUntilConvergence(double minFractionalErrorChange,
+ int maxIterations, int minIter) {
+ return iterateUntilConvergence(minFractionalErrorChange, maxIterations, minIter, 0);
+ }
+
+ public double iterateUntilConvergence(double minFractionalErrorChange,
+ int maxIterations, int minIter, double testFraction) {
+ int iter = 0;
+ double oldPerplexity = 0;
+ while (iter < minIter) {
+ trainDocuments(testFraction);
+ if (verbose) {
+ log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
+ }
+ log.info("iteration {} complete", iter);
+ oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts,
+ testFraction);
+ log.info("{} = perplexity", oldPerplexity);
+ iter++;
+ }
+ double newPerplexity = 0;
+ double fractionalChange = Double.MAX_VALUE;
+ while (iter < maxIterations && fractionalChange > minFractionalErrorChange) {
+ trainDocuments();
+ if (verbose) {
+ log.info("model after: {}: {}", iter, modelTrainer.getReadModel());
+ }
+ newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts,
+ testFraction);
+ log.info("{} = perplexity", newPerplexity);
+ iter++;
+ fractionalChange = Math.abs(newPerplexity - oldPerplexity) / oldPerplexity;
+ log.info("{} = fractionalChange", fractionalChange);
+ oldPerplexity = newPerplexity;
+ }
+ if (iter < maxIterations) {
+ log.info(String.format("Converged! fractional error change: %f, error %f",
+ fractionalChange, newPerplexity));
+ } else {
+ log.info(String.format("Reached max iteration count (%d), fractional error change: %f, error: %f",
+ maxIterations, fractionalChange, newPerplexity));
+ }
+ return newPerplexity;
+ }
+
+ public void writeModel(Path outputPath) throws IOException {
+ modelTrainer.persist(outputPath);
+ }
+
+ private static void logTime(String label, long nanos) {
+ log.info("{} time: {}ms", label, nanos / 1.0e6);
+ }
+
+ public static int main2(String[] args, Configuration conf) throws Exception {
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option helpOpt = DefaultOptionCreator.helpOption();
+
+ Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument(
+ abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The Directory on HDFS containing the collapsed, properly formatted files having "
+ + "one doc per line").withShortName("i").create();
+
+ Option dictOpt = obuilder.withLongName("dictionary").withRequired(false).withArgument(
+ abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The path to the term-dictionary format is ... ").withShortName("d").create();
+
+ Option dfsOpt = obuilder.withLongName("dfs").withRequired(false).withArgument(
+ abuilder.withName("dfs").withMinimum(1).withMaximum(1).create()).withDescription(
+ "HDFS namenode URI").withShortName("dfs").create();
+
+ Option numTopicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(abuilder
+ .withName("numTopics").withMinimum(1).withMaximum(1)
+ .create()).withDescription("Number of topics to learn").withShortName("top").create();
+
+ Option outputTopicFileOpt = obuilder.withLongName("topicOutputFile").withRequired(true).withArgument(
+ abuilder.withName("topicOutputFile").withMinimum(1).withMaximum(1).create())
+ .withDescription("File to write out p(term | topic)").withShortName("to").create();
+
+ Option outputDocFileOpt = obuilder.withLongName("docOutputFile").withRequired(true).withArgument(
+ abuilder.withName("docOutputFile").withMinimum(1).withMaximum(1).create())
+ .withDescription("File to write out p(topic | docid)").withShortName("do").create();
+
+ Option alphaOpt = obuilder.withLongName("alpha").withRequired(false).withArgument(abuilder
+ .withName("alpha").withMinimum(1).withMaximum(1).withDefault("0.1").create())
+ .withDescription("Smoothing parameter for p(topic | document) prior").withShortName("a").create();
+
+ Option etaOpt = obuilder.withLongName("eta").withRequired(false).withArgument(abuilder
+ .withName("eta").withMinimum(1).withMaximum(1).withDefault("0.1").create())
+ .withDescription("Smoothing parameter for p(term | topic)").withShortName("e").create();
+
+ Option maxIterOpt = obuilder.withLongName("maxIterations").withRequired(false).withArgument(abuilder
+ .withName("maxIterations").withMinimum(1).withMaximum(1).withDefault("10").create())
+ .withDescription("Maximum number of training passes").withShortName("m").create();
+
+ Option modelCorpusFractionOption = obuilder.withLongName("modelCorpusFraction")
+ .withRequired(false).withArgument(abuilder.withName("modelCorpusFraction").withMinimum(1)
+ .withMaximum(1).withDefault("0.0").create()).withShortName("mcf")
+ .withDescription("For online updates, initial value of |model|/|corpus|").create();
+
+ Option burnInOpt = obuilder.withLongName("burnInIterations").withRequired(false).withArgument(abuilder
+ .withName("burnInIterations").withMinimum(1).withMaximum(1).withDefault("5").create())
+ .withDescription("Minimum number of iterations").withShortName("b").create();
+
+ Option convergenceOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(abuilder
+ .withName("convergence").withMinimum(1).withMaximum(1).withDefault("0.0").create())
+ .withDescription("Fractional rate of perplexity to consider convergence").withShortName("c").create();
+
+ Option reInferDocTopicsOpt = obuilder.withLongName("reInferDocTopics").withRequired(false)
+ .withArgument(abuilder.withName("reInferDocTopics").withMinimum(1).withMaximum(1)
+ .withDefault("no").create())
+ .withDescription("re-infer p(topic | doc) : [no | randstart | continue]")
+ .withShortName("rdt").create();
+
+ Option numTrainThreadsOpt = obuilder.withLongName("numTrainThreads").withRequired(false)
+ .withArgument(abuilder.withName("numTrainThreads").withMinimum(1).withMaximum(1)
+ .withDefault("1").create())
+ .withDescription("number of threads to train with")
+ .withShortName("ntt").create();
+
+ Option numUpdateThreadsOpt = obuilder.withLongName("numUpdateThreads").withRequired(false)
+ .withArgument(abuilder.withName("numUpdateThreads").withMinimum(1).withMaximum(1)
+ .withDefault("1").create())
+ .withDescription("number of threads to update the model with")
+ .withShortName("nut").create();
+
+ Option verboseOpt = obuilder.withLongName("verbose").withRequired(false)
+ .withArgument(abuilder.withName("verbose").withMinimum(1).withMaximum(1)
+ .withDefault("false").create())
+ .withDescription("print verbose information, like top-terms in each topic, during iteration")
+ .withShortName("v").create();
+
+ Group group = gbuilder.withName("Options").withOption(inputDirOpt).withOption(numTopicsOpt)
+ .withOption(alphaOpt).withOption(etaOpt)
+ .withOption(maxIterOpt).withOption(burnInOpt).withOption(convergenceOpt)
+ .withOption(dictOpt).withOption(reInferDocTopicsOpt)
+ .withOption(outputDocFileOpt).withOption(outputTopicFileOpt).withOption(dfsOpt)
+ .withOption(numTrainThreadsOpt).withOption(numUpdateThreadsOpt)
+ .withOption(modelCorpusFractionOption).withOption(verboseOpt).create();
+
+ try {
+ Parser parser = new Parser();
+
+ parser.setGroup(group);
+ parser.setHelpOption(helpOpt);
+ CommandLine cmdLine = parser.parse(args);
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return -1;
+ }
+
+ String inputDirString = (String) cmdLine.getValue(inputDirOpt);
+ String dictDirString = cmdLine.hasOption(dictOpt) ? (String)cmdLine.getValue(dictOpt) : null;
+ int numTopics = Integer.parseInt((String) cmdLine.getValue(numTopicsOpt));
+ double alpha = Double.parseDouble((String)cmdLine.getValue(alphaOpt));
+ double eta = Double.parseDouble((String)cmdLine.getValue(etaOpt));
+ int maxIterations = Integer.parseInt((String)cmdLine.getValue(maxIterOpt));
+ int burnInIterations = Integer.parseInt((String)cmdLine.getValue(burnInOpt));
+ double minFractionalErrorChange = Double.parseDouble((String) cmdLine.getValue(convergenceOpt));
+ int numTrainThreads = Integer.parseInt((String)cmdLine.getValue(numTrainThreadsOpt));
+ int numUpdateThreads = Integer.parseInt((String)cmdLine.getValue(numUpdateThreadsOpt));
+ String topicOutFile = (String)cmdLine.getValue(outputTopicFileOpt);
+ String docOutFile = (String)cmdLine.getValue(outputDocFileOpt);
+ //String reInferDocTopics = (String)cmdLine.getValue(reInferDocTopicsOpt);
+ boolean verbose = Boolean.parseBoolean((String) cmdLine.getValue(verboseOpt));
+ double modelCorpusFraction = Double.parseDouble((String)cmdLine.getValue(modelCorpusFractionOption));
+
+ long start = System.nanoTime();
+
+ if (conf.get("fs.default.name") == null) {
+ String dfsNameNode = (String)cmdLine.getValue(dfsOpt);
+ conf.set("fs.default.name", dfsNameNode);
+ }
+ String[] terms = loadDictionary(dictDirString, conf);
+ logTime("dictionary loading", System.nanoTime() - start);
+ start = System.nanoTime();
+ Matrix corpus = loadVectors(inputDirString, conf);
+ logTime("vector seqfile corpus loading", System.nanoTime() - start);
+ start = System.nanoTime();
+ InMemoryCollapsedVariationalBayes0 cvb0 =
+ new InMemoryCollapsedVariationalBayes0(corpus, terms, numTopics, alpha, eta,
+ numTrainThreads, numUpdateThreads, modelCorpusFraction);
+ logTime("cvb0 init", System.nanoTime() - start);
+
+ start = System.nanoTime();
+ cvb0.setVerbose(verbose);
+ cvb0.iterateUntilConvergence(minFractionalErrorChange, maxIterations, burnInIterations);
+ logTime("total training time", System.nanoTime() - start);
+
+ /*
+ if ("randstart".equalsIgnoreCase(reInferDocTopics)) {
+ cvb0.inferDocuments(0.0, 100, true);
+ } else if ("continue".equalsIgnoreCase(reInferDocTopics)) {
+ cvb0.inferDocuments(0.0, 100, false);
+ }
+ */
+
+ start = System.nanoTime();
+ cvb0.writeModel(new Path(topicOutFile));
+ DistributedRowMatrixWriter.write(new Path(docOutFile), conf, cvb0.docTopicCounts);
+ logTime("printTopics", System.nanoTime() - start);
+ } catch (OptionException e) {
+ log.error("Error while parsing options", e);
+ CommandLineUtil.printHelp(group);
+ }
+ return 0;
+ }
+
+ /*
+ private static Map<Integer, Map<String, Integer>> loadCorpus(String path) throws IOException {
+ List<String> lines = Resources.readLines(Resources.getResource(path), Charsets.UTF_8);
+ Map<Integer, Map<String, Integer>> corpus = Maps.newHashMap();
+ for (int i=0; i<lines.size(); i++) {
+ String line = lines.get(i);
+ Map<String, Integer> doc = Maps.newHashMap();
+ for (String s : line.split(" ")) {
+ s = s.replaceAll("\\W", "").toLowerCase().trim();
+ if (s.length() == 0) {
+ continue;
+ }
+ if (!doc.containsKey(s)) {
+ doc.put(s, 0);
+ }
+ doc.put(s, doc.get(s) + 1);
+ }
+ corpus.put(i, doc);
+ }
+ return corpus;
+ }
+ */
+
+ private static String[] loadDictionary(String dictionaryPath, Configuration conf) {
+ if (dictionaryPath == null) {
+ return null;
+ }
+ Path dictionaryFile = new Path(dictionaryPath);
+ List<Pair<Integer, String>> termList = Lists.newArrayList();
+ int maxTermId = 0;
+ // key is word value is id
+ for (Pair<Writable, IntWritable> record
+ : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
+ termList.add(new Pair<>(record.getSecond().get(),
+ record.getFirst().toString()));
+ maxTermId = Math.max(maxTermId, record.getSecond().get());
+ }
+ String[] terms = new String[maxTermId + 1];
+ for (Pair<Integer, String> pair : termList) {
+ terms[pair.getFirst()] = pair.getSecond();
+ }
+ return terms;
+ }
+
+ @Override
+ public Configuration getConf() {
+ return super.getConf();
+ }
+
+ private static Matrix loadVectors(String vectorPathString, Configuration conf)
+ throws IOException {
+ Path vectorPath = new Path(vectorPathString);
+ FileSystem fs = vectorPath.getFileSystem(conf);
+ List<Path> subPaths = Lists.newArrayList();
+ if (fs.isFile(vectorPath)) {
+ subPaths.add(vectorPath);
+ } else {
+ for (FileStatus fileStatus : fs.listStatus(vectorPath, PathFilters.logsCRCFilter())) {
+ subPaths.add(fileStatus.getPath());
+ }
+ }
+ List<Pair<Integer, Vector>> rowList = Lists.newArrayList();
+ int numRows = Integer.MIN_VALUE;
+ int numCols = -1;
+ boolean sequentialAccess = false;
+ for (Path subPath : subPaths) {
+ for (Pair<IntWritable, VectorWritable> record
+ : new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) {
+ int id = record.getFirst().get();
+ Vector vector = record.getSecond().get();
+ if (vector instanceof NamedVector) {
+ vector = ((NamedVector)vector).getDelegate();
+ }
+ if (numCols < 0) {
+ numCols = vector.size();
+ sequentialAccess = vector.isSequentialAccess();
+ }
+ rowList.add(Pair.of(id, vector));
+ numRows = Math.max(numRows, id);
+ }
+ }
+ numRows++;
+ Vector[] rowVectors = new Vector[numRows];
+ for (Pair<Integer, Vector> pair : rowList) {
+ rowVectors[pair.getFirst()] = pair.getSecond();
+ }
+ return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess);
+
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ return main2(strings, getConf());
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new InMemoryCollapsedVariationalBayes0(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java
new file mode 100644
index 0000000..912b6d5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java
@@ -0,0 +1,301 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Multithreaded LDA model trainer class, which primarily operates by running a "map/reduce"
+ * operation, all in memory locally (ie not a hadoop job!) : the "map" operation is to take
+ * the "read-only" {@link TopicModel} and use it to iteratively learn the p(topic|term, doc)
+ * distribution for documents (this can be done in parallel across many documents, as the
+ * "read-only" model is, well, read-only. Then the outputs of this are "reduced" onto the
+ * "write" model, and these updates are not parallelizable in the same way: individual
+ * documents can't be added to the same entries in different threads at the same time, but
+ * updates across many topics to the same term from the same document can be done in parallel,
+ * so they are.
+ *
+ * Because computation is done asynchronously, when iteration is done, it's important to call
+ * the stop() method, which blocks until work is complete.
+ *
+ * Setting the read model and the write model to be the same object may not quite work yet,
+ * on account of parallelism badness.
+ */
+public class ModelTrainer {
+
+ private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class);
+
+ private final int numTopics;
+ private final int numTerms;
+ private TopicModel readModel;
+ private TopicModel writeModel;
+ private ThreadPoolExecutor threadPool;
+ private BlockingQueue<Runnable> workQueue;
+ private final int numTrainThreads;
+ private final boolean isReadWrite;
+
+ public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel,
+ int numTrainThreads, int numTopics, int numTerms) {
+ this.readModel = initialReadModel;
+ this.writeModel = initialWriteModel;
+ this.numTrainThreads = numTrainThreads;
+ this.numTopics = numTopics;
+ this.numTerms = numTerms;
+ isReadWrite = initialReadModel == initialWriteModel;
+ }
+
+ /**
+ * WARNING: this constructor may not lead to good behavior. What should be verified is that
+ * the model updating process does not conflict with model reading. It might work, but then
+ * again, it might not!
+ * @param model to be used for both reading (inference) and accumulating (learning)
+ * @param numTrainThreads
+ * @param numTopics
+ * @param numTerms
+ */
+ public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) {
+ this(model, model, numTrainThreads, numTopics, numTerms);
+ }
+
+ public TopicModel getReadModel() {
+ return readModel;
+ }
+
+ public void start() {
+ log.info("Starting training threadpool with {} threads", numTrainThreads);
+ workQueue = new ArrayBlockingQueue<>(numTrainThreads * 10);
+ threadPool = new ThreadPoolExecutor(numTrainThreads, numTrainThreads, 0, TimeUnit.SECONDS,
+ workQueue);
+ threadPool.allowCoreThreadTimeOut(false);
+ threadPool.prestartAllCoreThreads();
+ writeModel.reset();
+ }
+
+ public void train(VectorIterable matrix, VectorIterable docTopicCounts) {
+ train(matrix, docTopicCounts, 1);
+ }
+
+ public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) {
+ return calculatePerplexity(matrix, docTopicCounts, 0);
+ }
+
+ public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts,
+ double testFraction) {
+ Iterator<MatrixSlice> docIterator = matrix.iterator();
+ Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
+ double perplexity = 0;
+ double matrixNorm = 0;
+ while (docIterator.hasNext() && docTopicIterator.hasNext()) {
+ MatrixSlice docSlice = docIterator.next();
+ MatrixSlice topicSlice = docTopicIterator.next();
+ int docId = docSlice.index();
+ Vector document = docSlice.vector();
+ Vector topicDist = topicSlice.vector();
+ if (testFraction == 0 || docId % (1 / testFraction) == 0) {
+ trainSync(document, topicDist, false, 10);
+ perplexity += readModel.perplexity(document, topicDist);
+ matrixNorm += document.norm(1);
+ }
+ }
+ return perplexity / matrixNorm;
+ }
+
+ public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) {
+ start();
+ Iterator<MatrixSlice> docIterator = matrix.iterator();
+ Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator();
+ long startTime = System.nanoTime();
+ int i = 0;
+ double[] times = new double[100];
+ Map<Vector, Vector> batch = Maps.newHashMap();
+ int numTokensInBatch = 0;
+ long batchStart = System.nanoTime();
+ while (docIterator.hasNext() && docTopicIterator.hasNext()) {
+ i++;
+ Vector document = docIterator.next().vector();
+ Vector topicDist = docTopicIterator.next().vector();
+ if (isReadWrite) {
+ if (batch.size() < numTrainThreads) {
+ batch.put(document, topicDist);
+ if (log.isDebugEnabled()) {
+ numTokensInBatch += document.getNumNondefaultElements();
+ }
+ } else {
+ batchTrain(batch, true, numDocTopicIters);
+ long time = System.nanoTime();
+ log.debug("trained {} docs with {} tokens, start time {}, end time {}",
+ numTrainThreads, numTokensInBatch, batchStart, time);
+ batchStart = time;
+ numTokensInBatch = 0;
+ }
+ } else {
+ long start = System.nanoTime();
+ train(document, topicDist, true, numDocTopicIters);
+ if (log.isDebugEnabled()) {
+ times[i % times.length] =
+ (System.nanoTime() - start) / (1.0e6 * document.getNumNondefaultElements());
+ if (i % 100 == 0) {
+ long time = System.nanoTime() - startTime;
+ log.debug("trained {} documents in {}ms", i, time / 1.0e6);
+ if (i % 500 == 0) {
+ Arrays.sort(times);
+ log.debug("training took median {}ms per token-instance", times[times.length / 2]);
+ }
+ }
+ }
+ }
+ }
+ stop();
+ }
+
+ public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) {
+ while (true) {
+ try {
+ List<TrainerRunnable> runnables = Lists.newArrayList();
+ for (Map.Entry<Vector, Vector> entry : batch.entrySet()) {
+ runnables.add(new TrainerRunnable(readModel, null, entry.getKey(),
+ entry.getValue(), new SparseRowMatrix(numTopics, numTerms, true),
+ numDocTopicsIters));
+ }
+ threadPool.invokeAll(runnables);
+ if (update) {
+ for (TrainerRunnable runnable : runnables) {
+ writeModel.update(runnable.docTopicModel);
+ }
+ }
+ break;
+ } catch (InterruptedException e) {
+ log.warn("Interrupted during batch training, retrying!", e);
+ }
+ }
+ }
+
+ public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) {
+ while (true) {
+ try {
+ workQueue.put(new TrainerRunnable(readModel, update
+ ? writeModel
+ : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters));
+ return;
+ } catch (InterruptedException e) {
+ log.warn("Interrupted waiting to submit document to work queue: {}", document, e);
+ }
+ }
+ }
+
+ public void trainSync(Vector document, Vector docTopicCounts, boolean update,
+ int numDocTopicIters) {
+ new TrainerRunnable(readModel, update
+ ? writeModel
+ : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters).run();
+ }
+
+ public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) {
+ TrainerRunnable runner = new TrainerRunnable(readModel, null, document, docTopicCounts,
+ new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters);
+ return runner.call();
+ }
+
+ public void stop() {
+ long startTime = System.nanoTime();
+ log.info("Initiating stopping of training threadpool");
+ try {
+ threadPool.shutdown();
+ if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
+ log.warn("Threadpool timed out on await termination - jobs still running!");
+ }
+ long newTime = System.nanoTime();
+ log.info("threadpool took: {}ms", (newTime - startTime) / 1.0e6);
+ startTime = newTime;
+ readModel.stop();
+ newTime = System.nanoTime();
+ log.info("readModel.stop() took {}ms", (newTime - startTime) / 1.0e6);
+ startTime = newTime;
+ writeModel.stop();
+ newTime = System.nanoTime();
+ log.info("writeModel.stop() took {}ms", (newTime - startTime) / 1.0e6);
+ TopicModel tmpModel = writeModel;
+ writeModel = readModel;
+ readModel = tmpModel;
+ } catch (InterruptedException e) {
+ log.error("Interrupted shutting down!", e);
+ }
+ }
+
+ public void persist(Path outputPath) throws IOException {
+ readModel.persist(outputPath, true);
+ }
+
+ private static final class TrainerRunnable implements Runnable, Callable<Double> {
+ private final TopicModel readModel;
+ private final TopicModel writeModel;
+ private final Vector document;
+ private final Vector docTopics;
+ private final Matrix docTopicModel;
+ private final int numDocTopicIters;
+
+ private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document,
+ Vector docTopics, Matrix docTopicModel, int numDocTopicIters) {
+ this.readModel = readModel;
+ this.writeModel = writeModel;
+ this.document = document;
+ this.docTopics = docTopics;
+ this.docTopicModel = docTopicModel;
+ this.numDocTopicIters = numDocTopicIters;
+ }
+
+ @Override
+ public void run() {
+ for (int i = 0; i < numDocTopicIters; i++) {
+ // synchronous read-only call:
+ readModel.trainDocTopicModel(document, docTopics, docTopicModel);
+ }
+ if (writeModel != null) {
+ // parallel call which is read-only on the docTopicModel, and write-only on the writeModel
+ // this method does not return until all rows of the docTopicModel have been submitted
+ // to write work queues
+ writeModel.update(docTopicModel);
+ }
+ }
+
+ @Override
+ public Double call() {
+ run();
+ return readModel.perplexity(document, docTopics);
+ }
+ }
+}
[05/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java b/mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
new file mode 100644
index 0000000..35de87e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java
@@ -0,0 +1,674 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.canopy;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.collect.Iterables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.DummyRecordWriter;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+@Deprecated
+public final class TestCanopyCreation extends MahoutTestCase {
+
+ private static final double[][] RAW = { { 1, 1 }, { 2, 1 }, { 1, 2 },
+ { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } };
+
+ private List<Canopy> referenceManhattan;
+
+ private final DistanceMeasure manhattanDistanceMeasure = new ManhattanDistanceMeasure();
+
+ private List<Vector> manhattanCentroids;
+
+ private List<Canopy> referenceEuclidean;
+
+ private final DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure();
+
+ private List<Vector> euclideanCentroids;
+
+ private FileSystem fs;
+
+ private static List<VectorWritable> getPointsWritable() {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : RAW) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ private static List<Vector> getPoints() {
+ List<Vector> points = Lists.newArrayList();
+ for (double[] fr : RAW) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(vec);
+ }
+ return points;
+ }
+
+ /**
+ * Print the canopies to the transcript
+ *
+ * @param canopies
+ * a List<Canopy>
+ */
+ private static void printCanopies(Iterable<Canopy> canopies) {
+ for (Canopy canopy : canopies) {
+ System.out.println(canopy.asFormatString(null));
+ }
+ }
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ fs = FileSystem.get(getConfiguration());
+ referenceManhattan = CanopyClusterer.createCanopies(getPoints(),
+ manhattanDistanceMeasure, 3.1, 2.1);
+ manhattanCentroids = CanopyClusterer.getCenters(referenceManhattan);
+ referenceEuclidean = CanopyClusterer.createCanopies(getPoints(),
+ euclideanDistanceMeasure, 3.1, 2.1);
+ euclideanCentroids = CanopyClusterer.getCenters(referenceEuclidean);
+ }
+
+ /**
+ * Story: User can cluster points using a ManhattanDistanceMeasure and a
+ * reference implementation
+ */
+ @Test
+ public void testReferenceManhattan() throws Exception {
+ // see setUp for cluster creation
+ printCanopies(referenceManhattan);
+ assertEquals("number of canopies", 3, referenceManhattan.size());
+ for (int canopyIx = 0; canopyIx < referenceManhattan.size(); canopyIx++) {
+ Canopy testCanopy = referenceManhattan.get(canopyIx);
+ int[] expectedNumPoints = { 4, 4, 3 };
+ double[][] expectedCentroids = { { 1.5, 1.5 }, { 4.0, 4.0 },
+ { 4.666666666666667, 4.6666666666666667 } };
+ assertEquals("canopy points " + canopyIx, testCanopy.getNumObservations(),
+ expectedNumPoints[canopyIx]);
+ double[] refCentroid = expectedCentroids[canopyIx];
+ Vector testCentroid = testCanopy.computeCentroid();
+ for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) {
+ assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']',
+ refCentroid[pointIx], testCentroid.get(pointIx), EPSILON);
+ }
+ }
+ }
+
+ /**
+ * Story: User can cluster points using a EuclideanDistanceMeasure and a
+ * reference implementation
+ */
+ @Test
+ public void testReferenceEuclidean() throws Exception {
+ // see setUp for cluster creation
+ printCanopies(referenceEuclidean);
+ assertEquals("number of canopies", 3, referenceEuclidean.size());
+ int[] expectedNumPoints = { 5, 5, 3 };
+ double[][] expectedCentroids = { { 1.8, 1.8 }, { 4.2, 4.2 },
+ { 4.666666666666667, 4.666666666666667 } };
+ for (int canopyIx = 0; canopyIx < referenceEuclidean.size(); canopyIx++) {
+ Canopy testCanopy = referenceEuclidean.get(canopyIx);
+ assertEquals("canopy points " + canopyIx, testCanopy.getNumObservations(),
+ expectedNumPoints[canopyIx]);
+ double[] refCentroid = expectedCentroids[canopyIx];
+ Vector testCentroid = testCanopy.computeCentroid();
+ for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) {
+ assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']',
+ refCentroid[pointIx], testCentroid.get(pointIx), EPSILON);
+ }
+ }
+ }
+
+ /**
+ * Story: User can produce initial canopy centers using a
+ * ManhattanDistanceMeasure and a CanopyMapper which clusters input points to
+ * produce an output set of canopy centroid points.
+ */
+ @Test
+ public void testCanopyMapperManhattan() throws Exception {
+ CanopyMapper mapper = new CanopyMapper();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, manhattanDistanceMeasure
+ .getClass().getName());
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<Text, VectorWritable>();
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter
+ .build(mapper, conf, writer);
+ mapper.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ // map the data
+ for (VectorWritable point : points) {
+ mapper.map(new Text(), point, context);
+ }
+ mapper.cleanup(context);
+ assertEquals("Number of map results", 1, writer.getData().size());
+ // now verify the output
+ List<VectorWritable> data = writer.getValue(new Text("centroid"));
+ assertEquals("Number of centroids", 3, data.size());
+ for (int i = 0; i < data.size(); i++) {
+ assertEquals("Centroid error",
+ manhattanCentroids.get(i).asFormatString(), data.get(i).get()
+ .asFormatString());
+ }
+ }
+
+ /**
+ * Story: User can produce initial canopy centers using a
+ * EuclideanDistanceMeasure and a CanopyMapper/Combiner which clusters input
+ * points to produce an output set of canopy centroid points.
+ */
+ @Test
+ public void testCanopyMapperEuclidean() throws Exception {
+ CanopyMapper mapper = new CanopyMapper();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, euclideanDistanceMeasure
+ .getClass().getName());
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<Text, VectorWritable>();
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter
+ .build(mapper, conf, writer);
+ mapper.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ // map the data
+ for (VectorWritable point : points) {
+ mapper.map(new Text(), point, context);
+ }
+ mapper.cleanup(context);
+ assertEquals("Number of map results", 1, writer.getData().size());
+ // now verify the output
+ List<VectorWritable> data = writer.getValue(new Text("centroid"));
+ assertEquals("Number of centroids", 3, data.size());
+ for (int i = 0; i < data.size(); i++) {
+ assertEquals("Centroid error",
+ euclideanCentroids.get(i).asFormatString(), data.get(i).get()
+ .asFormatString());
+ }
+ }
+
+ /**
+ * Story: User can produce final canopy centers using a
+ * ManhattanDistanceMeasure and a CanopyReducer which clusters input centroid
+ * points to produce an output set of final canopy centroid points.
+ */
+ @Test
+ public void testCanopyReducerManhattan() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY,
+ "org.apache.mahout.common.distance.ManhattanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<Text, ClusterWritable>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter
+ .build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ reducer.reduce(new Text("centroid"), points, context);
+ Iterable<Text> keys = writer.getKeysInInsertionOrder();
+ assertEquals("Number of centroids", 3, Iterables.size(keys));
+ int i = 0;
+ for (Text key : keys) {
+ List<ClusterWritable> data = writer.getValue(key);
+ ClusterWritable clusterWritable = data.get(0);
+ Canopy canopy = (Canopy) clusterWritable.getValue();
+ assertEquals(manhattanCentroids.get(i).asFormatString() + " is not equal to "
+ + canopy.computeCentroid().asFormatString(),
+ manhattanCentroids.get(i), canopy.computeCentroid());
+ i++;
+ }
+ }
+
+ /**
+ * Story: User can produce final canopy centers using a
+ * EuclideanDistanceMeasure and a CanopyReducer which clusters input centroid
+ * points to produce an output set of final canopy centroid points.
+ */
+ @Test
+ public void testCanopyReducerEuclidean() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, "org.apache.mahout.common.distance.EuclideanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<Text, ClusterWritable>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context =
+ DummyRecordWriter.build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ reducer.reduce(new Text("centroid"), points, context);
+ Iterable<Text> keys = writer.getKeysInInsertionOrder();
+ assertEquals("Number of centroids", 3, Iterables.size(keys));
+ int i = 0;
+ for (Text key : keys) {
+ List<ClusterWritable> data = writer.getValue(key);
+ ClusterWritable clusterWritable = data.get(0);
+ Canopy canopy = (Canopy) clusterWritable.getValue();
+ assertEquals(euclideanCentroids.get(i).asFormatString() + " is not equal to "
+ + canopy.computeCentroid().asFormatString(),
+ euclideanCentroids.get(i), canopy.computeCentroid());
+ i++;
+ }
+ }
+
+ /**
+ * Story: User can produce final canopy centers using a Hadoop map/reduce job
+ * and a ManhattanDistanceMeasure.
+ */
+ @Test
+ public void testCanopyGenManhattanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file2"), fs, config);
+ // now run the Canopy Driver
+ Path output = getTestTempDirPath("output");
+ CanopyDriver.run(config, getTestTempDirPath("testdata"), output,
+ manhattanDistanceMeasure, 3.1, 2.1, false, 0.0, false);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+ FileSystem fs = FileSystem.get(path.toUri(), config);
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
+ try {
+ Writable key = new Text();
+ ClusterWritable clusterWritable = new ClusterWritable();
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("1st key", "C-0", key.toString());
+
+ List<Pair<Double,Double>> refCenters = Lists.newArrayList();
+ refCenters.add(new Pair<Double,Double>(1.5,1.5));
+ refCenters.add(new Pair<Double,Double>(4.333333333333334,4.333333333333334));
+ Pair<Double,Double> c = new Pair<Double,Double>(clusterWritable.getValue() .getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON));
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("2nd key", "C-1", key.toString());
+ c = new Pair<Double,Double>(clusterWritable.getValue().getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center " + c + " not found", findAndRemove(c, refCenters, EPSILON));
+ assertFalse("more to come", reader.next(key, clusterWritable));
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+
+ static boolean findAndRemove(Pair<Double, Double> target, Collection<Pair<Double, Double>> list, double epsilon) {
+ for (Pair<Double,Double> curr : list) {
+ if ( (Math.abs(target.getFirst() - curr.getFirst()) < epsilon)
+ && (Math.abs(target.getSecond() - curr.getSecond()) < epsilon) ) {
+ list.remove(curr);
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Story: User can produce final canopy centers using a Hadoop map/reduce job
+ * and a EuclideanDistanceMeasure.
+ */
+ @Test
+ public void testCanopyGenEuclideanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file2"), fs, config);
+ // now run the Canopy Driver
+ Path output = getTestTempDirPath("output");
+ CanopyDriver.run(config, getTestTempDirPath("testdata"), output,
+ euclideanDistanceMeasure, 3.1, 2.1, false, 0.0, false);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+ FileSystem fs = FileSystem.get(path.toUri(), config);
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
+ try {
+ Writable key = new Text();
+ ClusterWritable clusterWritable = new ClusterWritable();
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("1st key", "C-0", key.toString());
+
+ List<Pair<Double,Double>> refCenters = Lists.newArrayList();
+ refCenters.add(new Pair<Double,Double>(1.8,1.8));
+ refCenters.add(new Pair<Double,Double>(4.433333333333334, 4.433333333333334));
+ Pair<Double,Double> c = new Pair<Double,Double>(clusterWritable.getValue().getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON));
+ assertTrue("more to come", reader.next(key, clusterWritable));
+ assertEquals("2nd key", "C-1", key.toString());
+ c = new Pair<Double,Double>(clusterWritable.getValue().getCenter().get(0),
+ clusterWritable.getValue().getCenter().get(1));
+ assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON));
+ assertFalse("more to come", reader.next(key, clusterWritable));
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+
+ /** Story: User can cluster points using sequential execution */
+ @Test
+ public void testClusteringManhattanSeq() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ // now run the Canopy Driver in sequential mode
+ Path output = getTestTempDirPath("output");
+ CanopyDriver.run(config, getTestTempDirPath("testdata"), output,
+ manhattanDistanceMeasure, 3.1, 2.1, true, 0.0, true);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+ int ix = 0;
+ for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true,
+ config)) {
+ assertEquals("Center [" + ix + ']', manhattanCentroids.get(ix), clusterWritable.getValue()
+ .getCenter());
+ ix++;
+ }
+
+ path = new Path(output, "clusteredPoints/part-m-0");
+ long count = HadoopUtil.countRecords(path, config);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /** Story: User can cluster points using sequential execution */
+ @Test
+ public void testClusteringEuclideanSeq() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ // now run the Canopy Driver in sequential mode
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION),
+ optKey(DefaultOptionCreator.METHOD_OPTION),
+ DefaultOptionCreator.SEQUENTIAL_METHOD };
+ ToolRunner.run(config, new CanopyDriver(), args);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+
+ int ix = 0;
+ for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true,
+ config)) {
+ assertEquals("Center [" + ix + ']', euclideanCentroids.get(ix), clusterWritable.getValue()
+ .getCenter());
+ ix++;
+ }
+
+ path = new Path(output, "clusteredPoints/part-m-0");
+ long count = HadoopUtil.countRecords(path, config);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /** Story: User can remove outliers while clustering points using sequential execution */
+ @Test
+ public void testClusteringEuclideanWithOutlierRemovalSeq() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration config = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points,
+ getTestTempFilePath("testdata/file1"), fs, config);
+ // now run the Canopy Driver in sequential mode
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.OUTLIER_THRESHOLD), "0.5",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION),
+ optKey(DefaultOptionCreator.METHOD_OPTION),
+ DefaultOptionCreator.SEQUENTIAL_METHOD };
+ ToolRunner.run(config, new CanopyDriver(), args);
+
+ // verify output from sequence file
+ Path path = new Path(output, "clusters-0-final/part-r-00000");
+
+ int ix = 0;
+ for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true,
+ config)) {
+ assertEquals("Center [" + ix + ']', euclideanCentroids.get(ix), clusterWritable.getValue()
+ .getCenter());
+ ix++;
+ }
+
+ path = new Path(output, "clusteredPoints/part-m-0");
+ long count = HadoopUtil.countRecords(path, config);
+ int expectedPointsHavingPDFGreaterThanThreshold = 6;
+ assertEquals("number of points", expectedPointsHavingPDFGreaterThanThreshold, count);
+ }
+
+
+ /**
+ * Story: User can produce final point clustering using a Hadoop map/reduce
+ * job and a ManhattanDistanceMeasure.
+ */
+ @Test
+ public void testClusteringManhattanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file2"), fs, conf);
+ // now run the Job
+ Path output = getTestTempDirPath("output");
+ CanopyDriver.run(conf, getTestTempDirPath("testdata"), output,
+ manhattanDistanceMeasure, 3.1, 2.1, true, 0.0, false);
+ Path path = new Path(output, "clusteredPoints/part-m-00000");
+ long count = HadoopUtil.countRecords(path, conf);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /**
+ * Story: User can produce final point clustering using a Hadoop map/reduce
+ * job and a EuclideanDistanceMeasure.
+ */
+ @Test
+ public void testClusteringEuclideanMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file2"), fs, conf);
+ // now run the Job using the run() command. Others can use runJob().
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION) };
+ ToolRunner.run(getConfiguration(), new CanopyDriver(), args);
+ Path path = new Path(output, "clusteredPoints/part-m-00000");
+ long count = HadoopUtil.countRecords(path, conf);
+ assertEquals("number of points", points.size(), count);
+ }
+
+ /**
+ * Story: User can produce final point clustering using a Hadoop map/reduce
+ * job and a EuclideanDistanceMeasure and outlier removal threshold.
+ */
+ @Test
+ public void testClusteringEuclideanWithOutlierRemovalMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable();
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true,
+ getTestTempFilePath("testdata/file2"), fs, conf);
+ // now run the Job using the run() command. Others can use runJob().
+ Path output = getTestTempDirPath("output");
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION),
+ getTestTempDirPath("testdata").toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.T1_OPTION), "3.1",
+ optKey(DefaultOptionCreator.T2_OPTION), "2.1",
+ optKey(DefaultOptionCreator.OUTLIER_THRESHOLD), "0.7",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION) };
+ ToolRunner.run(getConfiguration(), new CanopyDriver(), args);
+ Path path = new Path(output, "clusteredPoints/part-m-00000");
+ long count = HadoopUtil.countRecords(path, conf);
+ int expectedPointsAfterOutlierRemoval = 8;
+ assertEquals("number of points", expectedPointsAfterOutlierRemoval, count);
+ }
+
+
+ /**
+ * Story: User can set T3 and T4 values to be used by the reducer for its T1
+ * and T2 thresholds
+ */
+ @Test
+ public void testCanopyReducerT3T4Configuration() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY,
+ "org.apache.mahout.common.distance.ManhattanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.T3_KEY, String.valueOf(1.1));
+ conf.set(CanopyConfigKeys.T4_KEY, String.valueOf(0.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "0");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<Text, ClusterWritable>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter
+ .build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+ assertEquals(1.1, reducer.getCanopyClusterer().getT1(), EPSILON);
+ assertEquals(0.1, reducer.getCanopyClusterer().getT2(), EPSILON);
+ }
+
+ /**
+ * Story: User can specify a clustering limit that prevents output of small
+ * clusters
+ */
+ @Test
+ public void testCanopyMapperClusterFilter() throws Exception {
+ CanopyMapper mapper = new CanopyMapper();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, manhattanDistanceMeasure
+ .getClass().getName());
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "3");
+ DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<Text, VectorWritable>();
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter
+ .build(mapper, conf, writer);
+ mapper.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ // map the data
+ for (VectorWritable point : points) {
+ mapper.map(new Text(), point, context);
+ }
+ mapper.cleanup(context);
+ assertEquals("Number of map results", 1, writer.getData().size());
+ // now verify the output
+ List<VectorWritable> data = writer.getValue(new Text("centroid"));
+ assertEquals("Number of centroids", 2, data.size());
+ }
+
+ /**
+ * Story: User can specify a cluster filter that limits the minimum size of
+ * canopies produced by the reducer
+ */
+ @Test
+ public void testCanopyReducerClusterFilter() throws Exception {
+ CanopyReducer reducer = new CanopyReducer();
+ Configuration conf = getConfiguration();
+ conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY,
+ "org.apache.mahout.common.distance.ManhattanDistanceMeasure");
+ conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1));
+ conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1));
+ conf.set(CanopyConfigKeys.CF_KEY, "3");
+ DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<Text, ClusterWritable>();
+ Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter
+ .build(reducer, conf, writer, Text.class, VectorWritable.class);
+ reducer.setup(context);
+
+ List<VectorWritable> points = getPointsWritable();
+ reducer.reduce(new Text("centroid"), points, context);
+ Set<Text> keys = writer.getKeys();
+ assertEquals("Number of centroids", 2, keys.size());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java b/mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
new file mode 100644
index 0000000..cbf0e55
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
@@ -0,0 +1,255 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.classify;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Set;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+public class ClusterClassificationDriverTest extends MahoutTestCase {
+
+ private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4},
+ {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
+
+ private FileSystem fs;
+ private Path clusteringOutputPath;
+ private Configuration conf;
+ private Path pointsPath;
+ private Path classifiedOutputPath;
+ private List<Vector> firstCluster;
+ private List<Vector> secondCluster;
+ private List<Vector> thirdCluster;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ firstCluster = Lists.newArrayList();
+ secondCluster = Lists.newArrayList();
+ thirdCluster = Lists.newArrayList();
+
+ }
+
+ private static List<VectorWritable> getPointsWritable(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ @Test
+ public void testVectorClassificationWithOutlierRemovalMR() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ pointsPath = getTestTempDirPath("points");
+ clusteringOutputPath = getTestTempDirPath("output");
+ classifiedOutputPath = getTestTempDirPath("classifiedClusters");
+ HadoopUtil.delete(conf, classifiedOutputPath);
+
+ conf = getConfiguration();
+
+ ClusteringTestUtils.writePointsToFile(points, true,
+ new Path(pointsPath, "file1"), fs, conf);
+ runClustering(pointsPath, conf, false);
+ runClassificationWithOutlierRemoval(false);
+ collectVectorsForAssertion();
+ assertVectorsWithOutlierRemoval();
+ }
+
+ @Test
+ public void testVectorClassificationWithoutOutlierRemoval() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ pointsPath = getTestTempDirPath("points");
+ clusteringOutputPath = getTestTempDirPath("output");
+ classifiedOutputPath = getTestTempDirPath("classify");
+
+ conf = getConfiguration();
+
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
+ runClustering(pointsPath, conf, true);
+ runClassificationWithoutOutlierRemoval();
+ collectVectorsForAssertion();
+ assertVectorsWithoutOutlierRemoval();
+ }
+
+ @Test
+ public void testVectorClassificationWithOutlierRemoval() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ pointsPath = getTestTempDirPath("points");
+ clusteringOutputPath = getTestTempDirPath("output");
+ classifiedOutputPath = getTestTempDirPath("classify");
+
+ conf = getConfiguration();
+
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
+ runClustering(pointsPath, conf, true);
+ runClassificationWithOutlierRemoval(true);
+ collectVectorsForAssertion();
+ assertVectorsWithOutlierRemoval();
+ }
+
+ private void runClustering(Path pointsPath, Configuration conf,
+ Boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ CanopyDriver.run(conf, pointsPath, clusteringOutputPath,
+ new ManhattanDistanceMeasure(), 3.1, 2.1, false, 0.0, runSequential);
+ Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final");
+ ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
+ finalClustersPath);
+ }
+
+ private void runClassificationWithoutOutlierRemoval()
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath, classifiedOutputPath, 0.0, true, true);
+ }
+
+ private void runClassificationWithOutlierRemoval(boolean runSequential)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath, classifiedOutputPath, 0.73, true, runSequential);
+ }
+
+ private void collectVectorsForAssertion() throws IOException {
+ Path[] partFilePaths = FileUtil.stat2Paths(fs
+ .globStatus(classifiedOutputPath));
+ FileStatus[] listStatus = fs.listStatus(partFilePaths,
+ PathFilters.partFilter());
+ for (FileStatus partFile : listStatus) {
+ SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs,
+ partFile.getPath(), conf);
+ Writable clusterIdAsKey = new IntWritable();
+ WeightedPropertyVectorWritable point = new WeightedPropertyVectorWritable();
+ while (classifiedVectors.next(clusterIdAsKey, point)) {
+ collectVector(clusterIdAsKey.toString(), point.getVector());
+ }
+ }
+ }
+
+ private void collectVector(String clusterId, Vector vector) {
+ if ("0".equals(clusterId)) {
+ firstCluster.add(vector);
+ } else if ("1".equals(clusterId)) {
+ secondCluster.add(vector);
+ } else if ("2".equals(clusterId)) {
+ thirdCluster.add(vector);
+ }
+ }
+
+ private void assertVectorsWithOutlierRemoval() {
+ checkClustersWithOutlierRemoval();
+ }
+
+ private void assertVectorsWithoutOutlierRemoval() {
+ assertFirstClusterWithoutOutlierRemoval();
+ assertSecondClusterWithoutOutlierRemoval();
+ assertThirdClusterWithoutOutlierRemoval();
+ }
+
+ private void assertThirdClusterWithoutOutlierRemoval() {
+ Assert.assertEquals(2, thirdCluster.size());
+ for (Vector vector : thirdCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:9.0,1:9.0}",
+ "{0:8.0,1:8.0}"}, vector.asFormatString()));
+ }
+ }
+
+ private void assertSecondClusterWithoutOutlierRemoval() {
+ Assert.assertEquals(4, secondCluster.size());
+ for (Vector vector : secondCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:4.0,1:4.0}",
+ "{0:5.0,1:4.0}", "{0:4.0,1:5.0}", "{0:5.0,1:5.0}"},
+ vector.asFormatString()));
+ }
+ }
+
+ private void assertFirstClusterWithoutOutlierRemoval() {
+ Assert.assertEquals(3, firstCluster.size());
+ for (Vector vector : firstCluster) {
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:1.0,1:1.0}",
+ "{0:2.0,1:1.0}", "{0:1.0,1:2.0}"}, vector.asFormatString()));
+ }
+ }
+
+ private void checkClustersWithOutlierRemoval() {
+ Set<String> reference = Sets.newHashSet("{0:9.0,1:9.0}", "{0:1.0,1:1.0}");
+
+ List<List<Vector>> clusters = Lists.newArrayList();
+ clusters.add(firstCluster);
+ clusters.add(secondCluster);
+ clusters.add(thirdCluster);
+
+ int singletonCnt = 0;
+ int emptyCnt = 0;
+ for (List<Vector> vList : clusters) {
+ if (vList.isEmpty()) {
+ emptyCnt++;
+ } else {
+ singletonCnt++;
+ assertEquals("expecting only singleton clusters; got size=" + vList.size(), 1, vList.size());
+ if (vList.get(0).getClass().equals(NamedVector.class)) {
+ Assert.assertTrue("not expecting cluster:" + ((NamedVector) vList.get(0)).getDelegate().asFormatString(),
+ reference.contains(((NamedVector) vList.get(0)).getDelegate().asFormatString()));
+ reference.remove(((NamedVector)vList.get(0)).getDelegate().asFormatString());
+ } else if (vList.get(0).getClass().equals(RandomAccessSparseVector.class)) {
+ Assert.assertTrue("not expecting cluster:" + vList.get(0).asFormatString(),
+ reference.contains(vList.get(0).asFormatString()));
+ reference.remove(vList.get(0).asFormatString());
+ }
+ }
+ }
+ Assert.assertEquals("Different number of empty clusters than expected!", 1, emptyCnt);
+ Assert.assertEquals("Different number of singletons than expected!", 2, singletonCnt);
+ Assert.assertEquals("Didn't match all reference clusters!", 0, reference.size());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java b/mr/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
new file mode 100644
index 0000000..fc71ecf
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/fuzzykmeans/TestFuzzyKmeansClustering.java
@@ -0,0 +1,202 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.fuzzykmeans;
+
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.io.Closeables;
+
+public final class TestFuzzyKmeansClustering extends MahoutTestCase {
+
+ private FileSystem fs;
+ private final DistanceMeasure measure = new EuclideanDistanceMeasure();
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ }
+
+ private static Vector tweakValue(Vector point) {
+ return point.plus(0.1);
+ }
+
+ @Test
+ public void testFuzzyKMeansSeqJob() throws Exception {
+ List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ Path clustersPath = getTestTempDirPath("clusters");
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+
+ for (int k = 0; k < points.size(); k++) {
+ System.out.println("testKFuzzyKMeansMRJob k= " + k);
+ // pick k initial cluster centers at random
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs,
+ conf,
+ new Path(clustersPath, "part-00000"),
+ Text.class,
+ SoftCluster.class);
+ try {
+ for (int i = 0; i < k + 1; i++) {
+ Vector vec = tweakValue(points.get(i).get());
+ SoftCluster cluster = new SoftCluster(vec, i, measure);
+ /* add the center so the centroid will be correct upon output */
+ cluster.observe(cluster.getCenter(), 1);
+ // writer.write(cluster.getIdentifier() + '\t' + SoftCluster.formatCluster(cluster) + '\n');
+ writer.append(new Text(cluster.getIdentifier()), cluster);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ // now run the Job using the run() command line options.
+ Path output = getTestTempDirPath("output" + k);
+ /* FuzzyKMeansDriver.runJob(pointsPath,
+ clustersPath,
+ output,
+ EuclideanDistanceMeasure.class.getName(),
+ 0.001,
+ 2,
+ k + 1,
+ 2,
+ false,
+ true,
+ 0);
+ */
+ String[] args = {
+ optKey(DefaultOptionCreator.INPUT_OPTION), pointsPath.toString(),
+ optKey(DefaultOptionCreator.CLUSTERS_IN_OPTION),
+ clustersPath.toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION),
+ "0.001",
+ optKey(DefaultOptionCreator.MAX_ITERATIONS_OPTION),
+ "2",
+ optKey(FuzzyKMeansDriver.M_OPTION),
+ "2.0",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.EMIT_MOST_LIKELY_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION),
+ optKey(DefaultOptionCreator.METHOD_OPTION),
+ DefaultOptionCreator.SEQUENTIAL_METHOD
+ };
+ FuzzyKMeansDriver.main(args);
+ long count = HadoopUtil.countRecords(new Path(output, "clusteredPoints/part-m-0"), conf);
+ assertTrue(count > 0);
+ }
+
+ }
+
+ @Test
+ public void testFuzzyKMeansMRJob() throws Exception {
+ List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ Path clustersPath = getTestTempDirPath("clusters");
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+
+ for (int k = 0; k < points.size(); k++) {
+ System.out.println("testKFuzzyKMeansMRJob k= " + k);
+ // pick k initial cluster centers at random
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs,
+ conf,
+ new Path(clustersPath, "part-00000"),
+ Text.class,
+ SoftCluster.class);
+ try {
+ for (int i = 0; i < k + 1; i++) {
+ Vector vec = tweakValue(points.get(i).get());
+
+ SoftCluster cluster = new SoftCluster(vec, i, measure);
+ /* add the center so the centroid will be correct upon output */
+ cluster.observe(cluster.getCenter(), 1);
+ // writer.write(cluster.getIdentifier() + '\t' + SoftCluster.formatCluster(cluster) + '\n');
+ writer.append(new Text(cluster.getIdentifier()), cluster);
+
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ // now run the Job using the run() command line options.
+ Path output = getTestTempDirPath("output" + k);
+ /* FuzzyKMeansDriver.runJob(pointsPath,
+ clustersPath,
+ output,
+ EuclideanDistanceMeasure.class.getName(),
+ 0.001,
+ 2,
+ k + 1,
+ 2,
+ false,
+ true,
+ 0);
+ */
+ String[] args = {
+ optKey(DefaultOptionCreator.INPUT_OPTION),
+ pointsPath.toString(),
+ optKey(DefaultOptionCreator.CLUSTERS_IN_OPTION),
+ clustersPath.toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION),
+ "0.001",
+ optKey(DefaultOptionCreator.MAX_ITERATIONS_OPTION),
+ "2",
+ optKey(FuzzyKMeansDriver.M_OPTION),
+ "2.0",
+ optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.EMIT_MOST_LIKELY_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION)
+ };
+ ToolRunner.run(getConfiguration(), new FuzzyKMeansDriver(), args);
+ long count = HadoopUtil.countRecords(new Path(output, "clusteredPoints/part-m-00000"), conf);
+ assertTrue(count > 0);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java b/mr/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java
new file mode 100644
index 0000000..fdcfd64
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/iterator/TestClusterClassifier.java
@@ -0,0 +1,238 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.iterator;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.clustering.kmeans.TestKmeansClustering;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.CosineDistanceMeasure;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public final class TestClusterClassifier extends MahoutTestCase {
+
+ private static ClusterClassifier newDMClassifier() {
+ List<Cluster> models = Lists.newArrayList();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2, measure));
+ return new ClusterClassifier(models, new KMeansClusteringPolicy());
+ }
+
+ private static ClusterClassifier newKlusterClassifier() {
+ List<Cluster> models = Lists.newArrayList();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2), 1, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(-1), 2, measure));
+ return new ClusterClassifier(models, new KMeansClusteringPolicy());
+ }
+
+ private static ClusterClassifier newCosineKlusterClassifier() {
+ List<Cluster> models = Lists.newArrayList();
+ DistanceMeasure measure = new CosineDistanceMeasure();
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2), 1, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Kluster(new DenseVector(2).assign(-1), 2, measure));
+ return new ClusterClassifier(models, new KMeansClusteringPolicy());
+ }
+
+ private static ClusterClassifier newSoftClusterClassifier() {
+ List<Cluster> models = Lists.newArrayList();
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ models.add(new SoftCluster(new DenseVector(2).assign(1), 0, measure));
+ models.add(new SoftCluster(new DenseVector(2), 1, measure));
+ models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure));
+ return new ClusterClassifier(models, new FuzzyKMeansClusteringPolicy());
+ }
+
+ private ClusterClassifier writeAndRead(ClusterClassifier classifier) throws IOException {
+ Path path = new Path(getTestTempDirPath(), "output");
+ classifier.writeToSeqFiles(path);
+ ClusterClassifier newClassifier = new ClusterClassifier();
+ newClassifier.readFromSeqFiles(getConfiguration(), path);
+ return newClassifier;
+ }
+
+ @Test
+ public void testDMClusterClassification() {
+ ClusterClassifier classifier = newDMClassifier();
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.2,0.6,0.2]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.493,0.296,0.211]", AbstractCluster.formatVector(pdf, null));
+ }
+
+ @Test
+ public void testClusterClassification() {
+ ClusterClassifier classifier = newKlusterClassifier();
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.2,0.6,0.2]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.493,0.296,0.211]", AbstractCluster.formatVector(pdf, null));
+ }
+
+ @Test
+ public void testSoftClusterClassification() {
+ ClusterClassifier classifier = newSoftClusterClassifier();
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.0,1.0,0.0]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.735,0.184,0.082]", AbstractCluster.formatVector(pdf, null));
+ }
+
+ @Test
+ public void testDMClassifierSerialization() throws Exception {
+ ClusterClassifier classifier = newDMClassifier();
+ ClusterClassifier classifierOut = writeAndRead(classifier);
+ assertEquals(classifier.getModels().size(), classifierOut.getModels().size());
+ assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass()
+ .getName());
+ }
+
+ @Test
+ public void testClusterClassifierSerialization() throws Exception {
+ ClusterClassifier classifier = newKlusterClassifier();
+ ClusterClassifier classifierOut = writeAndRead(classifier);
+ assertEquals(classifier.getModels().size(), classifierOut.getModels().size());
+ assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass()
+ .getName());
+ }
+
+ @Test
+ public void testSoftClusterClassifierSerialization() throws Exception {
+ ClusterClassifier classifier = newSoftClusterClassifier();
+ ClusterClassifier classifierOut = writeAndRead(classifier);
+ assertEquals(classifier.getModels().size(), classifierOut.getModels().size());
+ assertEquals(classifier.getModels().get(0).getClass().getName(), classifierOut.getModels().get(0).getClass()
+ .getName());
+ }
+
+ @Test
+ public void testClusterIteratorKMeans() {
+ List<Vector> data = TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
+ ClusterClassifier prior = newKlusterClassifier();
+ ClusterClassifier posterior = ClusterIterator.iterate(data, prior, 5);
+ assertEquals(3, posterior.getModels().size());
+ for (Cluster cluster : posterior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+ }
+
+ @Test
+ public void testClusterIteratorDirichlet() {
+ List<Vector> data = TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
+ ClusterClassifier prior = newKlusterClassifier();
+ ClusterClassifier posterior = ClusterIterator.iterate(data, prior, 5);
+ assertEquals(3, posterior.getModels().size());
+ for (Cluster cluster : posterior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+ }
+
+ @Test
+ public void testSeqFileClusterIteratorKMeans() throws IOException {
+ Path pointsPath = getTestTempDirPath("points");
+ Path priorPath = getTestTempDirPath("prior");
+ Path outPath = getTestTempDirPath("output");
+ Configuration conf = getConfiguration();
+ FileSystem fs = FileSystem.get(pointsPath.toUri(), conf);
+ List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+ Path path = new Path(priorPath, "priorClassifier");
+ ClusterClassifier prior = newKlusterClassifier();
+ prior.writeToSeqFiles(path);
+ assertEquals(3, prior.getModels().size());
+ System.out.println("Prior");
+ for (Cluster cluster : prior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+ ClusterIterator.iterateSeq(conf, pointsPath, path, outPath, 5);
+
+ for (int i = 1; i <= 4; i++) {
+ System.out.println("Classifier-" + i);
+ ClusterClassifier posterior = new ClusterClassifier();
+ String name = i == 4 ? "clusters-4-final" : "clusters-" + i;
+ posterior.readFromSeqFiles(conf, new Path(outPath, name));
+ assertEquals(3, posterior.getModels().size());
+ for (Cluster cluster : posterior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+
+ }
+ }
+
+ @Test
+ public void testMRFileClusterIteratorKMeans() throws Exception {
+ Path pointsPath = getTestTempDirPath("points");
+ Path priorPath = getTestTempDirPath("prior");
+ Path outPath = getTestTempDirPath("output");
+ Configuration conf = getConfiguration();
+ FileSystem fs = FileSystem.get(pointsPath.toUri(), conf);
+ List<VectorWritable> points = TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+ Path path = new Path(priorPath, "priorClassifier");
+ ClusterClassifier prior = newKlusterClassifier();
+ prior.writeToSeqFiles(path);
+ ClusteringPolicy policy = new KMeansClusteringPolicy();
+ ClusterClassifier.writePolicy(policy, path);
+ assertEquals(3, prior.getModels().size());
+ System.out.println("Prior");
+ for (Cluster cluster : prior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+ ClusterIterator.iterateMR(conf, pointsPath, path, outPath, 5);
+
+ for (int i = 1; i <= 4; i++) {
+ System.out.println("Classifier-" + i);
+ ClusterClassifier posterior = new ClusterClassifier();
+ String name = i == 4 ? "clusters-4-final" : "clusters-" + i;
+ posterior.readFromSeqFiles(conf, new Path(outPath, name));
+ assertEquals(3, posterior.getModels().size());
+ for (Cluster cluster : posterior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+ }
+ }
+
+ @Test
+ public void testCosineKlusterClassification() {
+ ClusterClassifier classifier = newCosineKlusterClassifier();
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.333,0.333,0.333]", AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.429,0.429,0.143]", AbstractCluster.formatVector(pdf, null));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java b/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
new file mode 100644
index 0000000..5666765
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestKmeansClustering.java
@@ -0,0 +1,385 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.kmeans;
+
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.DummyOutputCollector;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+public final class TestKmeansClustering extends MahoutTestCase {
+
+ public static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5}};
+
+ private static final int[][] EXPECTED_NUM_POINTS = { {9}, {4, 5}, {4, 4, 1}, {1, 2, 1, 5}, {1, 1, 1, 2, 4},
+ {1, 1, 1, 1, 1, 4}, {1, 1, 1, 1, 1, 2, 2}, {1, 1, 1, 1, 1, 1, 2, 1}, {1, 1, 1, 1, 1, 1, 1, 1, 1}};
+
+ private FileSystem fs;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ }
+
+ public static List<VectorWritable> getPointsWritable(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ public static List<VectorWritable> getPointsWritableDenseVector(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new DenseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ public static List<Vector> getPoints(double[][] raw) {
+ List<Vector> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new SequentialAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(vec);
+ }
+ return points;
+ }
+
+ /**
+ * Tests
+ * {@link KMeansClusterer#runKMeansIteration(Iterable, Iterable, DistanceMeasure, double)}
+ * ) single run convergence with a given distance threshold.
+ */
+ /*@Test
+ public void testRunKMeansIterationConvergesInOneRunWithGivenDistanceThreshold() {
+ double[][] rawPoints = { {0, 0}, {0, 0.25}, {0, 0.75}, {0, 1}};
+ List<Vector> points = getPoints(rawPoints);
+
+ ManhattanDistanceMeasure distanceMeasure = new ManhattanDistanceMeasure();
+ List<Kluster> clusters = Arrays.asList(new Kluster(points.get(0), 0, distanceMeasure), new Kluster(points.get(3),
+ 3, distanceMeasure));
+
+ // To converge in a single run, the given distance threshold should be
+ // greater than or equal to 0.125,
+ // since 0.125 will be the distance between center and centroid for the
+ // initial two clusters after one run.
+ double distanceThreshold = 0.25;
+
+ boolean converged = KMeansClusterer.runKMeansIteration(points, clusters, distanceMeasure, distanceThreshold);
+
+ Vector cluster1Center = clusters.get(0).getCenter();
+ assertEquals(0, cluster1Center.get(0), EPSILON);
+ assertEquals(0.125, cluster1Center.get(1), EPSILON);
+
+ Vector cluster2Center = clusters.get(1).getCenter();
+ assertEquals(0, cluster2Center.get(0), EPSILON);
+ assertEquals(0.875, cluster2Center.get(1), EPSILON);
+
+ assertTrue("KMeans iteration should be converged after a single run", converged);
+ }*/
+
+ /** Story: User wishes to run kmeans job on reference data */
+ @Test
+ public void testKMeansSeqJob() throws Exception {
+ DistanceMeasure measure = new EuclideanDistanceMeasure();
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ Path clustersPath = getTestTempDirPath("clusters");
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file2"), fs, conf);
+ for (int k = 1; k < points.size(); k++) {
+ System.out.println("testKMeansMRJob k= " + k);
+ // pick k initial cluster centers at random
+ Path path = new Path(clustersPath, "part-00000");
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, Text.class, Kluster.class);
+ try {
+ for (int i = 0; i < k + 1; i++) {
+ Vector vec = points.get(i).get();
+
+ Kluster cluster = new Kluster(vec, i, measure);
+ // add the center so the centroid will be correct upon output
+ cluster.observe(cluster.getCenter(), 1);
+ writer.append(new Text(cluster.getIdentifier()), cluster);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ // now run the Job
+ Path outputPath = getTestTempDirPath("output" + k);
+ String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), pointsPath.toString(),
+ optKey(DefaultOptionCreator.CLUSTERS_IN_OPTION), clustersPath.toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), outputPath.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION), "0.001",
+ optKey(DefaultOptionCreator.MAX_ITERATIONS_OPTION), "2", optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION), optKey(DefaultOptionCreator.METHOD_OPTION),
+ DefaultOptionCreator.SEQUENTIAL_METHOD};
+ ToolRunner.run(conf, new KMeansDriver(), args);
+
+ // now compare the expected clusters with actual
+ Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
+ int[] expect = EXPECTED_NUM_POINTS[k];
+ DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable> collector = new DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable>();
+ // The key is the clusterId, the value is the weighted vector
+ for (Pair<IntWritable,WeightedPropertyVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedPropertyVectorWritable>(
+ new Path(clusteredPointsPath, "part-m-0"), conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ assertEquals("clusters[" + k + ']', expect.length, collector.getKeys().size());
+ }
+ }
+
+ /** Story: User wishes to run kmeans job on reference data (DenseVector test) */
+ @Test
+ public void testKMeansSeqJobDenseVector() throws Exception {
+ DistanceMeasure measure = new EuclideanDistanceMeasure();
+ List<VectorWritable> points = getPointsWritableDenseVector(REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ Path clustersPath = getTestTempDirPath("clusters");
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file2"), fs, conf);
+ for (int k = 1; k < points.size(); k++) {
+ System.out.println("testKMeansMRJob k= " + k);
+ // pick k initial cluster centers at random
+ Path path = new Path(clustersPath, "part-00000");
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, Text.class, Kluster.class);
+ try {
+ for (int i = 0; i < k + 1; i++) {
+ Vector vec = points.get(i).get();
+
+ Kluster cluster = new Kluster(vec, i, measure);
+ // add the center so the centroid will be correct upon output
+ cluster.observe(cluster.getCenter(), 1);
+ writer.append(new Text(cluster.getIdentifier()), cluster);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ // now run the Job
+ Path outputPath = getTestTempDirPath("output" + k);
+ String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), pointsPath.toString(),
+ optKey(DefaultOptionCreator.CLUSTERS_IN_OPTION), clustersPath.toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), outputPath.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION), "0.001",
+ optKey(DefaultOptionCreator.MAX_ITERATIONS_OPTION), "2", optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION), optKey(DefaultOptionCreator.METHOD_OPTION),
+ DefaultOptionCreator.SEQUENTIAL_METHOD};
+ ToolRunner.run(conf, new KMeansDriver(), args);
+
+ // now compare the expected clusters with actual
+ Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
+ int[] expect = EXPECTED_NUM_POINTS[k];
+ DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable> collector = new DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable>();
+ // The key is the clusterId, the value is the weighted vector
+ for (Pair<IntWritable,WeightedPropertyVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedPropertyVectorWritable>(
+ new Path(clusteredPointsPath, "part-m-0"), conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ assertEquals("clusters[" + k + ']', expect.length, collector.getKeys().size());
+ }
+ }
+
+ /** Story: User wishes to run kmeans job on reference data */
+ @Test
+ public void testKMeansMRJob() throws Exception {
+ DistanceMeasure measure = new EuclideanDistanceMeasure();
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ Path clustersPath = getTestTempDirPath("clusters");
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file2"), fs, conf);
+ for (int k = 1; k < points.size(); k += 3) {
+ System.out.println("testKMeansMRJob k= " + k);
+ // pick k initial cluster centers at random
+ Path path = new Path(clustersPath, "part-00000");
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, Text.class, Kluster.class);
+
+ try {
+ for (int i = 0; i < k + 1; i++) {
+ Vector vec = points.get(i).get();
+
+ Kluster cluster = new Kluster(vec, i, measure);
+ // add the center so the centroid will be correct upon output
+ cluster.observe(cluster.getCenter(), 1);
+ writer.append(new Text(cluster.getIdentifier()), cluster);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ // now run the Job
+ Path outputPath = getTestTempDirPath("output" + k);
+ String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), pointsPath.toString(),
+ optKey(DefaultOptionCreator.CLUSTERS_IN_OPTION), clustersPath.toString(),
+ optKey(DefaultOptionCreator.OUTPUT_OPTION), outputPath.toString(),
+ optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName(),
+ optKey(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION), "0.001",
+ optKey(DefaultOptionCreator.MAX_ITERATIONS_OPTION), "2", optKey(DefaultOptionCreator.CLUSTERING_OPTION),
+ optKey(DefaultOptionCreator.OVERWRITE_OPTION)};
+ ToolRunner.run(getConfiguration(), new KMeansDriver(), args);
+
+ // now compare the expected clusters with actual
+ Path clusteredPointsPath = new Path(outputPath, "clusteredPoints");
+ // assertEquals("output dir files?", 4, outFiles.length);
+ int[] expect = EXPECTED_NUM_POINTS[k];
+ DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable> collector = new DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable>();
+ // The key is the clusterId, the value is the weighted vector
+ for (Pair<IntWritable,WeightedPropertyVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedPropertyVectorWritable>(
+ new Path(clusteredPointsPath, "part-m-00000"), conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ assertEquals("clusters[" + k + ']', expect.length, collector.getKeys().size());
+ }
+ }
+
+ /**
+ * Story: User wants to use canopy clustering to input the initial clusters
+ * for kmeans job.
+ */
+ @Test
+ public void testKMeansWithCanopyClusterInput() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(pointsPath, "file2"), fs, conf);
+
+ Path outputPath = getTestTempDirPath("output");
+ // now run the Canopy job
+ CanopyDriver.run(conf, pointsPath, outputPath, new ManhattanDistanceMeasure(), 3.1, 2.1, false, 0.0, false);
+
+ DummyOutputCollector<Text, ClusterWritable> collector1 =
+ new DummyOutputCollector<Text, ClusterWritable>();
+
+ FileStatus[] outParts = FileSystem.get(conf).globStatus(
+ new Path(outputPath, "clusters-0-final/*-0*"));
+ for (FileStatus outPartStat : outParts) {
+ for (Pair<Text,ClusterWritable> record :
+ new SequenceFileIterable<Text,ClusterWritable>(
+ outPartStat.getPath(), conf)) {
+ collector1.collect(record.getFirst(), record.getSecond());
+ }
+ }
+
+ boolean got15 = false;
+ boolean got43 = false;
+ int count = 0;
+ for (Text k : collector1.getKeys()) {
+ count++;
+ List<ClusterWritable> vl = collector1.getValue(k);
+ assertEquals("non-singleton centroid!", 1, vl.size());
+ ClusterWritable clusterWritable = vl.get(0);
+ Vector v = clusterWritable.getValue().getCenter();
+ assertEquals("cetriod vector is wrong length", 2, v.size());
+ if ( (Math.abs(v.get(0) - 1.5) < EPSILON)
+ && (Math.abs(v.get(1) - 1.5) < EPSILON)
+ && !got15) {
+ got15 = true;
+ } else if ( (Math.abs(v.get(0) - 4.333333333333334) < EPSILON)
+ && (Math.abs(v.get(1) - 4.333333333333334) < EPSILON)
+ && !got43) {
+ got43 = true;
+ } else {
+ fail("got unexpected center: " + v + " [" + v.getClass().toString() + ']');
+ }
+ }
+ assertEquals("got unexpected number of centers", 2, count);
+
+ // now run the KMeans job
+ Path kmeansOutput = new Path(outputPath, "kmeans");
+ KMeansDriver.run(getConfiguration(), pointsPath, new Path(outputPath, "clusters-0-final"), kmeansOutput,
+ 0.001, 10, true, 0.0, false);
+
+ // now compare the expected clusters with actual
+ Path clusteredPointsPath = new Path(kmeansOutput, "clusteredPoints");
+ DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable> collector = new DummyOutputCollector<IntWritable,WeightedPropertyVectorWritable>();
+
+ // The key is the clusterId, the value is the weighted vector
+ for (Pair<IntWritable,WeightedPropertyVectorWritable> record : new SequenceFileIterable<IntWritable,WeightedPropertyVectorWritable>(
+ new Path(clusteredPointsPath, "part-m-00000"), conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+
+ for (IntWritable k : collector.getKeys()) {
+ List<WeightedPropertyVectorWritable> wpvList = collector.getValue(k);
+ assertTrue("empty cluster!", !wpvList.isEmpty());
+ if (wpvList.get(0).getVector().get(0) <= 2.0) {
+ for (WeightedPropertyVectorWritable wv : wpvList) {
+ Vector v = wv.getVector();
+ int idx = v.maxValueIndex();
+ assertTrue("bad cluster!", v.get(idx) <= 2.0);
+ }
+ assertEquals("Wrong size cluster", 4, wpvList.size());
+ } else {
+ for (WeightedPropertyVectorWritable wv : wpvList) {
+ Vector v = wv.getVector();
+ int idx = v.minValueIndex();
+ assertTrue("bad cluster!", v.get(idx) > 2.0);
+ }
+ assertEquals("Wrong size cluster", 5, wpvList.size());
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestRandomSeedGenerator.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestRandomSeedGenerator.java b/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestRandomSeedGenerator.java
new file mode 100644
index 0000000..5cb012a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/kmeans/TestRandomSeedGenerator.java
@@ -0,0 +1,169 @@
+ /**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.kmeans;
+
+import java.util.Collection;
+import java.util.List;
+
+import com.google.common.collect.Sets;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public final class TestRandomSeedGenerator extends MahoutTestCase {
+
+ private static final double[][] RAW = {{1, 1}, {2, 1}, {1, 2}, {2, 2},
+ {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5}};
+
+ private FileSystem fs;
+
+ private static List<VectorWritable> getPoints() {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : RAW) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ }
+
+ /** Story: test random seed generation generates 4 clusters with proper ids and data */
+ @Test
+ public void testRandomSeedGenerator() throws Exception {
+ List<VectorWritable> points = getPoints();
+ Job job = new Job();
+ Configuration conf = job.getConfiguration();
+ job.setMapOutputValueClass(VectorWritable.class);
+ Path input = getTestTempFilePath("random-input");
+ Path output = getTestTempDirPath("random-output");
+ ClusteringTestUtils.writePointsToFile(points, input, fs, conf);
+
+ RandomSeedGenerator.buildRandom(conf, input, output, 4, new ManhattanDistanceMeasure());
+
+ int clusterCount = 0;
+ Collection<Integer> set = Sets.newHashSet();
+ for (ClusterWritable clusterWritable :
+ new SequenceFileValueIterable<ClusterWritable>(new Path(output, "part-randomSeed"), true, conf)) {
+ clusterCount++;
+ Cluster cluster = clusterWritable.getValue();
+ int id = cluster.getId();
+ assertTrue(set.add(id)); // Validate unique id's
+
+ Vector v = cluster.getCenter();
+ assertVectorEquals(RAW[id], v); // Validate values match
+ }
+
+ assertEquals(4, clusterCount); // Validate sample count
+ }
+
+ /** Be sure that the buildRandomSeeded works in the same way as RandomSeedGenerator.buildRandom */
+ @Test
+ public void testRandomSeedGeneratorSeeded() throws Exception {
+ List<VectorWritable> points = getPoints();
+ Job job = new Job();
+ Configuration conf = job.getConfiguration();
+ job.setMapOutputValueClass(VectorWritable.class);
+ Path input = getTestTempFilePath("random-input");
+ Path output = getTestTempDirPath("random-output");
+ ClusteringTestUtils.writePointsToFile(points, input, fs, conf);
+
+ RandomSeedGenerator.buildRandom(conf, input, output, 4, new ManhattanDistanceMeasure(), 1L);
+
+ int clusterCount = 0;
+ Collection<Integer> set = Sets.newHashSet();
+ for (ClusterWritable clusterWritable :
+ new SequenceFileValueIterable<ClusterWritable>(new Path(output, "part-randomSeed"), true, conf)) {
+ clusterCount++;
+ Cluster cluster = clusterWritable.getValue();
+ int id = cluster.getId();
+ assertTrue(set.add(id)); // validate unique id's
+
+ Vector v = cluster.getCenter();
+ assertVectorEquals(RAW[id], v); // validate values match
+ }
+
+ assertEquals(4, clusterCount); // validate sample count
+ }
+
+ /** Test that initial clusters built with same random seed are reproduced */
+ @Test
+ public void testBuildRandomSeededSameInitalClusters() throws Exception {
+ List<VectorWritable> points = getPoints();
+ Job job = new Job();
+ Configuration conf = job.getConfiguration();
+ job.setMapOutputValueClass(VectorWritable.class);
+ Path input = getTestTempFilePath("random-input");
+ Path output = getTestTempDirPath("random-output");
+ ClusteringTestUtils.writePointsToFile(points, input, fs, conf);
+ long randSeed=1;
+
+ RandomSeedGenerator.buildRandom(conf, input, output, 4, new ManhattanDistanceMeasure(), randSeed);
+
+ int[] clusterIDSeq = new int[4];
+
+ /** run through all clusters once and set sequence of IDs */
+ int clusterCount = 0;
+ for (ClusterWritable clusterWritable :
+ new SequenceFileValueIterable<ClusterWritable>(new Path(output, "part-randomSeed"), true, conf)) {
+ Cluster cluster = clusterWritable.getValue();
+ clusterIDSeq[clusterCount] = cluster.getId();
+ clusterCount++;
+ }
+
+ /* Rebuild cluster and run through again making sure all IDs are in the same random sequence
+ * Needs a better test because in this case passes when seeded with 1 and 2 fails with 1, 3
+ * passes when set to two */
+ RandomSeedGenerator.buildRandom(conf, input, output, 4, new ManhattanDistanceMeasure(), randSeed); clusterCount = 0;
+ for (ClusterWritable clusterWritable :
+ new SequenceFileValueIterable<ClusterWritable>(new Path(output, "part-randomSeed"), true, conf)) {
+ Cluster cluster = clusterWritable.getValue();
+ // Make sure cluster ids are in same random sequence
+ assertEquals(clusterIDSeq[clusterCount], cluster.getId());
+ clusterCount++;
+ }
+ }
+
+ private static void assertVectorEquals(double[] raw, Vector v) {
+ assertEquals(raw.length, v.size());
+ for (int i = 0; i < raw.length; i++) {
+ assertEquals(raw[i], v.getQuick(i), EPSILON);
+ }
+ }
+}
[37/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java
new file mode 100644
index 0000000..3d4c6d6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/builder/TreeBuilder.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.node.Node;
+
+import java.util.Random;
+
+/**
+ * Abstract base class for TreeBuilders
+ */
+public interface TreeBuilder {
+
+ /**
+ * Builds a Decision tree using the training data
+ *
+ * @param rng
+ * random-numbers generator
+ * @param data
+ * training data
+ * @return root Node
+ */
+ Node build(Random rng, Data data);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java
new file mode 100644
index 0000000..c1bddd9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/Data.java
@@ -0,0 +1,280 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Holds a list of vectors and their corresponding Dataset. contains various operations that deals with the
+ * vectors (subset, count,...)
+ *
+ */
+public class Data implements Cloneable {
+
+ private final List<Instance> instances;
+
+ private final Dataset dataset;
+
+ public Data(Dataset dataset) {
+ this.dataset = dataset;
+ this.instances = Lists.newArrayList();
+ }
+
+ public Data(Dataset dataset, List<Instance> instances) {
+ this.dataset = dataset;
+ this.instances = Lists.newArrayList(instances);
+ }
+
+ /**
+ * @return the number of elements
+ */
+ public int size() {
+ return instances.size();
+ }
+
+ /**
+ * @return true if this data contains no element
+ */
+ public boolean isEmpty() {
+ return instances.isEmpty();
+ }
+
+ /**
+ * @param v
+ * element whose presence in this list if to be searched
+ * @return true is this data contains the specified element.
+ */
+ public boolean contains(Instance v) {
+ return instances.contains(v);
+ }
+
+ /**
+ * Returns the element at the specified position
+ *
+ * @param index
+ * index of element to return
+ * @return the element at the specified position
+ * @throws IndexOutOfBoundsException
+ * if the index is out of range
+ */
+ public Instance get(int index) {
+ return instances.get(index);
+ }
+
+ /**
+ * @return the subset from this data that matches the given condition
+ */
+ public Data subset(Condition condition) {
+ List<Instance> subset = Lists.newArrayList();
+
+ for (Instance instance : instances) {
+ if (condition.isTrueFor(instance)) {
+ subset.add(instance);
+ }
+ }
+
+ return new Data(dataset, subset);
+ }
+
+ /**
+ * if data has N cases, sample N cases at random -but with replacement.
+ */
+ public Data bagging(Random rng) {
+ int datasize = size();
+ List<Instance> bag = Lists.newArrayListWithCapacity(datasize);
+
+ for (int i = 0; i < datasize; i++) {
+ bag.add(instances.get(rng.nextInt(datasize)));
+ }
+
+ return new Data(dataset, bag);
+ }
+
+ /**
+ * if data has N cases, sample N cases at random -but with replacement.
+ *
+ * @param sampled
+ * indicating which instance has been sampled
+ *
+ * @return sampled data
+ */
+ public Data bagging(Random rng, boolean[] sampled) {
+ int datasize = size();
+ List<Instance> bag = Lists.newArrayListWithCapacity(datasize);
+
+ for (int i = 0; i < datasize; i++) {
+ int index = rng.nextInt(datasize);
+ bag.add(instances.get(index));
+ sampled[index] = true;
+ }
+
+ return new Data(dataset, bag);
+ }
+
+ /**
+ * Splits the data in two, returns one part, and this gets the rest of the data. <b>VERY SLOW!</b>
+ */
+ public Data rsplit(Random rng, int subsize) {
+ List<Instance> subset = Lists.newArrayListWithCapacity(subsize);
+
+ for (int i = 0; i < subsize; i++) {
+ subset.add(instances.remove(rng.nextInt(instances.size())));
+ }
+
+ return new Data(dataset, subset);
+ }
+
+ /**
+ * checks if all the vectors have identical attribute values
+ *
+ * @return true is all the vectors are identical or the data is empty<br>
+ * false otherwise
+ */
+ public boolean isIdentical() {
+ if (isEmpty()) {
+ return true;
+ }
+
+ Instance instance = get(0);
+ for (int attr = 0; attr < dataset.nbAttributes(); attr++) {
+ for (int index = 1; index < size(); index++) {
+ if (get(index).get(attr) != instance.get(attr)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * checks if all the vectors have identical label values
+ */
+ public boolean identicalLabel() {
+ if (isEmpty()) {
+ return true;
+ }
+
+ double label = dataset.getLabel(get(0));
+ for (int index = 1; index < size(); index++) {
+ if (dataset.getLabel(get(index)) != label) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * finds all distinct values of a given attribute
+ */
+ public double[] values(int attr) {
+ Collection<Double> result = Sets.newHashSet();
+
+ for (Instance instance : instances) {
+ result.add(instance.get(attr));
+ }
+
+ double[] values = new double[result.size()];
+
+ int index = 0;
+ for (Double value : result) {
+ values[index++] = value;
+ }
+
+ return values;
+ }
+
+ @Override
+ public Data clone() {
+ return new Data(dataset, Lists.newArrayList(instances));
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof Data)) {
+ return false;
+ }
+
+ Data data = (Data) obj;
+
+ return instances.equals(data.instances) && dataset.equals(data.dataset);
+ }
+
+ @Override
+ public int hashCode() {
+ return instances.hashCode() + dataset.hashCode();
+ }
+
+ /**
+ * extract the labels of all instances
+ */
+ public double[] extractLabels() {
+ double[] labels = new double[size()];
+
+ for (int index = 0; index < labels.length; index++) {
+ labels[index] = dataset.getLabel(get(index));
+ }
+
+ return labels;
+ }
+
+ /**
+ * finds the majority label, breaking ties randomly<br>
+ * This method can be used when the criterion variable is the categorical attribute.
+ *
+ * @return the majority label value
+ */
+ public int majorityLabel(Random rng) {
+ // count the frequency of each label value
+ int[] counts = new int[dataset.nblabels()];
+
+ for (int index = 0; index < size(); index++) {
+ counts[(int) dataset.getLabel(get(index))]++;
+ }
+
+ // find the label values that appears the most
+ return DataUtils.maxindex(rng, counts);
+ }
+
+ /**
+ * Counts the number of occurrences of each label value<br>
+ * This method can be used when the criterion variable is the categorical attribute.
+ *
+ * @param counts
+ * will contain the results, supposed to be initialized at 0
+ */
+ public void countLabels(int[] counts) {
+ for (int index = 0; index < size(); index++) {
+ counts[(int) dataset.getLabel(get(index))]++;
+ }
+ }
+
+ public Dataset getDataset() {
+ return dataset;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java
new file mode 100644
index 0000000..318c0d0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/DataConverter.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import com.google.common.base.Preconditions;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.math.DenseVector;
+
+import java.util.regex.Pattern;
+
+/**
+ * Converts String to Instance using a Dataset
+ */
+public class DataConverter {
+
+ private static final Pattern COMMA_SPACE = Pattern.compile("[, ]");
+
+ private final Dataset dataset;
+
+ public DataConverter(Dataset dataset) {
+ this.dataset = dataset;
+ }
+
+ public Instance convert(CharSequence string) {
+ // all attributes (categorical, numerical, label), ignored
+ int nball = dataset.nbAttributes() + dataset.getIgnored().length;
+
+ String[] tokens = COMMA_SPACE.split(string);
+ Preconditions.checkArgument(tokens.length == nball,
+ "Wrong number of attributes in the string: " + tokens.length + ". Must be " + nball);
+
+ int nbattrs = dataset.nbAttributes();
+ DenseVector vector = new DenseVector(nbattrs);
+
+ int aId = 0;
+ for (int attr = 0; attr < nball; attr++) {
+ if (!ArrayUtils.contains(dataset.getIgnored(), attr)) {
+ String token = tokens[attr].trim();
+
+ if ("?".equals(token)) {
+ // missing value
+ return null;
+ }
+
+ if (dataset.isNumerical(aId)) {
+ vector.set(aId++, Double.parseDouble(token));
+ } else { // CATEGORICAL
+ vector.set(aId, dataset.valueOf(aId, token));
+ aId++;
+ }
+ }
+ }
+
+ return new Instance(vector);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java
new file mode 100644
index 0000000..8eed6cf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/DataLoader.java
@@ -0,0 +1,253 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Scanner;
+import java.util.Set;
+import java.util.regex.Pattern;
+
+/**
+ * Converts the input data to a Vector Array using the information given by the Dataset.<br>
+ * Generates for each line a Vector that contains :<br>
+ * <ul>
+ * <li>double parsed value for NUMERICAL attributes</li>
+ * <li>int value for CATEGORICAL and LABEL attributes</li>
+ * </ul>
+ * <br>
+ * adds an IGNORED first attribute that will contain a unique id for each instance, which is the line number
+ * of the instance in the input data
+ */
+public final class DataLoader {
+
+ private static final Logger log = LoggerFactory.getLogger(DataLoader.class);
+
+ private static final Pattern SEPARATORS = Pattern.compile("[, ]");
+
+ private DataLoader() {}
+
+ /**
+ * Converts a comma-separated String to a Vector.
+ *
+ * @param attrs
+ * attributes description
+ * @param values
+ * used to convert CATEGORICAL attribute values to Integer
+ * @return false if there are missing values '?' or NUMERICAL attribute values is not numeric
+ */
+ private static boolean parseString(Attribute[] attrs, Set<String>[] values, CharSequence string,
+ boolean regression) {
+ String[] tokens = SEPARATORS.split(string);
+ Preconditions.checkArgument(tokens.length == attrs.length,
+ "Wrong number of attributes in the string: " + tokens.length + ". Must be: " + attrs.length);
+
+ // extract tokens and check is there is any missing value
+ for (int attr = 0; attr < attrs.length; attr++) {
+ if (!attrs[attr].isIgnored() && "?".equals(tokens[attr])) {
+ return false; // missing value
+ }
+ }
+
+ for (int attr = 0; attr < attrs.length; attr++) {
+ if (!attrs[attr].isIgnored()) {
+ String token = tokens[attr];
+ if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) {
+ // update values
+ if (values[attr] == null) {
+ values[attr] = Sets.newHashSet();
+ }
+ values[attr].add(token);
+ } else {
+ try {
+ Double.parseDouble(token);
+ } catch (NumberFormatException e) {
+ return false;
+ }
+ }
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * Loads the data from a file
+ *
+ * @param fs
+ * file system
+ * @param fpath
+ * data file path
+ * @throws IOException
+ * if any problem is encountered
+ */
+
+ public static Data loadData(Dataset dataset, FileSystem fs, Path fpath) throws IOException {
+ FSDataInputStream input = fs.open(fpath);
+ Scanner scanner = new Scanner(input, "UTF-8");
+
+ List<Instance> instances = Lists.newArrayList();
+
+ DataConverter converter = new DataConverter(dataset);
+
+ while (scanner.hasNextLine()) {
+ String line = scanner.nextLine();
+ if (!line.isEmpty()) {
+ Instance instance = converter.convert(line);
+ if (instance != null) {
+ instances.add(instance);
+ } else {
+ // missing values found
+ log.warn("{}: missing values", instances.size());
+ }
+ } else {
+ log.warn("{}: empty string", instances.size());
+ }
+ }
+
+ scanner.close();
+ return new Data(dataset, instances);
+ }
+
+
+ /** Loads the data from multiple paths specified by pathes */
+ public static Data loadData(Dataset dataset, FileSystem fs, Path[] pathes) throws IOException {
+ List<Instance> instances = Lists.newArrayList();
+
+ for (Path path : pathes) {
+ Data loadedData = loadData(dataset, fs, path);
+ for (int index = 0; index <= loadedData.size(); index++) {
+ instances.add(loadedData.get(index));
+ }
+ }
+ return new Data(dataset, instances);
+ }
+
+ /** Loads the data from a String array */
+ public static Data loadData(Dataset dataset, String[] data) {
+ List<Instance> instances = Lists.newArrayList();
+
+ DataConverter converter = new DataConverter(dataset);
+
+ for (String line : data) {
+ if (!line.isEmpty()) {
+ Instance instance = converter.convert(line);
+ if (instance != null) {
+ instances.add(instance);
+ } else {
+ // missing values found
+ log.warn("{}: missing values", instances.size());
+ }
+ } else {
+ log.warn("{}: empty string", instances.size());
+ }
+ }
+
+ return new Data(dataset, instances);
+ }
+
+ /**
+ * Generates the Dataset by parsing the entire data
+ *
+ * @param descriptor attributes description
+ * @param regression if true, the label is numerical
+ * @param fs file system
+ * @param path data path
+ */
+ public static Dataset generateDataset(CharSequence descriptor,
+ boolean regression,
+ FileSystem fs,
+ Path path) throws DescriptorException, IOException {
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+ FSDataInputStream input = fs.open(path);
+ Scanner scanner = new Scanner(input, "UTF-8");
+
+ // used to convert CATEGORICAL attribute to Integer
+ @SuppressWarnings("unchecked")
+ Set<String>[] valsets = new Set[attrs.length];
+
+ int size = 0;
+ while (scanner.hasNextLine()) {
+ String line = scanner.nextLine();
+ if (!line.isEmpty()) {
+ if (parseString(attrs, valsets, line, regression)) {
+ size++;
+ }
+ }
+ }
+
+ scanner.close();
+
+ @SuppressWarnings("unchecked")
+ List<String>[] values = new List[attrs.length];
+ for (int i = 0; i < valsets.length; i++) {
+ if (valsets[i] != null) {
+ values[i] = Lists.newArrayList(valsets[i]);
+ }
+ }
+
+ return new Dataset(attrs, values, size, regression);
+ }
+
+ /**
+ * Generates the Dataset by parsing the entire data
+ *
+ * @param descriptor
+ * attributes description
+ */
+ public static Dataset generateDataset(CharSequence descriptor,
+ boolean regression,
+ String[] data) throws DescriptorException {
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+ // used to convert CATEGORICAL attributes to Integer
+ @SuppressWarnings("unchecked")
+ Set<String>[] valsets = new Set[attrs.length];
+
+ int size = 0;
+ for (String aData : data) {
+ if (!aData.isEmpty()) {
+ if (parseString(attrs, valsets, aData, regression)) {
+ size++;
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ List<String>[] values = new List[attrs.length];
+ for (int i = 0; i < valsets.length; i++) {
+ if (valsets[i] != null) {
+ values[i] = Lists.newArrayList(valsets[i]);
+ }
+ }
+
+ return new Dataset(attrs, values, size, regression);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java
new file mode 100644
index 0000000..856d452
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/DataUtils.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Helper methods that deals with data lists and arrays of values
+ */
+public final class DataUtils {
+ private DataUtils() { }
+
+ /**
+ * Computes the sum of the values
+ *
+ */
+ public static int sum(int[] values) {
+ int sum = 0;
+ for (int value : values) {
+ sum += value;
+ }
+
+ return sum;
+ }
+
+ /**
+ * foreach i : array1[i] += array2[i]
+ */
+ public static void add(int[] array1, int[] array2) {
+ Preconditions.checkArgument(array1.length == array2.length, "array1.length != array2.length");
+ for (int index = 0; index < array1.length; index++) {
+ array1[index] += array2[index];
+ }
+ }
+
+ /**
+ * foreach i : array1[i] -= array2[i]
+ */
+ public static void dec(int[] array1, int[] array2) {
+ Preconditions.checkArgument(array1.length == array2.length, "array1.length != array2.length");
+ for (int index = 0; index < array1.length; index++) {
+ array1[index] -= array2[index];
+ }
+ }
+
+ /**
+ * return the index of the maximum of the array, breaking ties randomly
+ *
+ * @param rng
+ * used to break ties
+ * @return index of the maximum
+ */
+ public static int maxindex(Random rng, int[] values) {
+ int max = 0;
+ List<Integer> maxindices = Lists.newArrayList();
+
+ for (int index = 0; index < values.length; index++) {
+ if (values[index] > max) {
+ max = values[index];
+ maxindices.clear();
+ maxindices.add(index);
+ } else if (values[index] == max) {
+ maxindices.add(index);
+ }
+ }
+
+ return maxindices.size() > 1 ? maxindices.get(rng.nextInt(maxindices.size())) : maxindices.get(0);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java
new file mode 100644
index 0000000..d2bec37
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/Dataset.java
@@ -0,0 +1,421 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.codehaus.jackson.type.TypeReference;
+
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+/**
+ * Contains information about the attributes.
+ */
+public class Dataset {
+
+ /**
+ * Attributes type
+ */
+ public enum Attribute {
+ IGNORED,
+ NUMERICAL,
+ CATEGORICAL,
+ LABEL;
+
+ public boolean isNumerical() {
+ return this == NUMERICAL;
+ }
+
+ public boolean isCategorical() {
+ return this == CATEGORICAL;
+ }
+
+ public boolean isLabel() {
+ return this == LABEL;
+ }
+
+ public boolean isIgnored() {
+ return this == IGNORED;
+ }
+
+ private static Attribute fromString(String from) {
+ Attribute toReturn = LABEL;
+ if (NUMERICAL.toString().equalsIgnoreCase(from)) {
+ toReturn = NUMERICAL;
+ } else if (CATEGORICAL.toString().equalsIgnoreCase(from)) {
+ toReturn = CATEGORICAL;
+ } else if (IGNORED.toString().equalsIgnoreCase(from)) {
+ toReturn = IGNORED;
+ }
+ return toReturn;
+ }
+ }
+
+ private Attribute[] attributes;
+
+ /**
+ * list of ignored attributes
+ */
+ private int[] ignored;
+
+ /**
+ * distinct values (CATEGORIAL attributes only)
+ */
+ private String[][] values;
+
+ /**
+ * index of the label attribute in the loaded data (without ignored attributed)
+ */
+ private int labelId;
+
+ /**
+ * number of instances in the dataset
+ */
+ private int nbInstances;
+
+ /** JSON serial/de-serial-izer */
+ private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
+
+ // Some literals for JSON representation
+ static final String TYPE = "type";
+ static final String VALUES = "values";
+ static final String LABEL = "label";
+
+ protected Dataset() {}
+
+ /**
+ * Should only be called by a DataLoader
+ *
+ * @param attrs attributes description
+ * @param values distinct values for all CATEGORICAL attributes
+ */
+ Dataset(Attribute[] attrs, List<String>[] values, int nbInstances, boolean regression) {
+ validateValues(attrs, values);
+
+ int nbattrs = countAttributes(attrs);
+
+ // the label values are set apart
+ attributes = new Attribute[nbattrs];
+ this.values = new String[nbattrs][];
+ ignored = new int[attrs.length - nbattrs]; // nbignored = total - nbattrs
+
+ labelId = -1;
+ int ignoredId = 0;
+ int ind = 0;
+ for (int attr = 0; attr < attrs.length; attr++) {
+ if (attrs[attr].isIgnored()) {
+ ignored[ignoredId++] = attr;
+ continue;
+ }
+
+ if (attrs[attr].isLabel()) {
+ if (labelId != -1) {
+ throw new IllegalStateException("Label found more than once");
+ }
+ labelId = ind;
+ if (regression) {
+ attrs[attr] = Attribute.NUMERICAL;
+ } else {
+ attrs[attr] = Attribute.CATEGORICAL;
+ }
+ }
+
+ if (attrs[attr].isCategorical() || (!regression && attrs[attr].isLabel())) {
+ this.values[ind] = new String[values[attr].size()];
+ values[attr].toArray(this.values[ind]);
+ }
+
+ attributes[ind++] = attrs[attr];
+ }
+
+ if (labelId == -1) {
+ throw new IllegalStateException("Label not found");
+ }
+
+ this.nbInstances = nbInstances;
+ }
+
+ public int nbValues(int attr) {
+ return values[attr].length;
+ }
+
+ public String[] labels() {
+ return Arrays.copyOf(values[labelId], nblabels());
+ }
+
+ public int nblabels() {
+ return values[labelId].length;
+ }
+
+ public int getLabelId() {
+ return labelId;
+ }
+
+ public double getLabel(Instance instance) {
+ return instance.get(getLabelId());
+ }
+
+ public Attribute getAttribute(int attr) {
+ return attributes[attr];
+ }
+
+ /**
+ * Returns the code used to represent the label value in the data
+ *
+ * @param label label's value to code
+ * @return label's code
+ */
+ public int labelCode(String label) {
+ return ArrayUtils.indexOf(values[labelId], label);
+ }
+
+ /**
+ * Returns the label value in the data
+ * This method can be used when the criterion variable is the categorical attribute.
+ *
+ * @param code label's code
+ * @return label's value
+ */
+ public String getLabelString(double code) {
+ // handle the case (prediction is NaN)
+ if (Double.isNaN(code)) {
+ return "unknown";
+ }
+ return values[labelId][(int) code];
+ }
+
+ @Override
+ public String toString() {
+ return "attributes=" + Arrays.toString(attributes);
+ }
+
+ /**
+ * Converts a token to its corresponding integer code for a given attribute
+ *
+ * @param attr attribute index
+ */
+ public int valueOf(int attr, String token) {
+ Preconditions.checkArgument(!isNumerical(attr), "Only for CATEGORICAL attributes");
+ Preconditions.checkArgument(values != null, "Values not found (equals null)");
+ return ArrayUtils.indexOf(values[attr], token);
+ }
+
+ public int[] getIgnored() {
+ return ignored;
+ }
+
+ /**
+ * @return number of attributes that are not IGNORED
+ */
+ private static int countAttributes(Attribute[] attrs) {
+ int nbattrs = 0;
+ for (Attribute attr : attrs) {
+ if (!attr.isIgnored()) {
+ nbattrs++;
+ }
+ }
+ return nbattrs;
+ }
+
+ private static void validateValues(Attribute[] attrs, List<String>[] values) {
+ Preconditions.checkArgument(attrs.length == values.length, "attrs.length != values.length");
+ for (int attr = 0; attr < attrs.length; attr++) {
+ Preconditions.checkArgument(!attrs[attr].isCategorical() || values[attr] != null,
+ "values not found for attribute " + attr);
+ }
+ }
+
+ /**
+ * @return number of attributes
+ */
+ public int nbAttributes() {
+ return attributes.length;
+ }
+
+ /**
+ * Is this a numerical attribute ?
+ *
+ * @param attr index of the attribute to check
+ * @return true if the attribute is numerical
+ */
+ public boolean isNumerical(int attr) {
+ return attributes[attr].isNumerical();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof Dataset)) {
+ return false;
+ }
+
+ Dataset dataset = (Dataset) obj;
+
+ if (!Arrays.equals(attributes, dataset.attributes)) {
+ return false;
+ }
+
+ for (int attr = 0; attr < nbAttributes(); attr++) {
+ if (!Arrays.equals(values[attr], dataset.values[attr])) {
+ return false;
+ }
+ }
+
+ return labelId == dataset.labelId && nbInstances == dataset.nbInstances;
+ }
+
+ @Override
+ public int hashCode() {
+ int hashCode = labelId + 31 * nbInstances;
+ for (Attribute attr : attributes) {
+ hashCode = 31 * hashCode + attr.hashCode();
+ }
+ for (String[] valueRow : values) {
+ if (valueRow == null) {
+ continue;
+ }
+ for (String value : valueRow) {
+ hashCode = 31 * hashCode + value.hashCode();
+ }
+ }
+ return hashCode;
+ }
+
+ /**
+ * Loads the dataset from a file
+ *
+ * @throws java.io.IOException
+ */
+ public static Dataset load(Configuration conf, Path path) throws IOException {
+ FileSystem fs = path.getFileSystem(conf);
+ long bytesToRead = fs.getFileStatus(path).getLen();
+ byte[] buff = new byte[Long.valueOf(bytesToRead).intValue()];
+ FSDataInputStream input = fs.open(path);
+ try {
+ input.readFully(buff);
+ } finally {
+ Closeables.close(input, true);
+ }
+ String json = new String(buff, Charset.defaultCharset());
+ return fromJSON(json);
+ }
+
+
+ /**
+ * Serialize this instance to JSON
+ * @return some JSON
+ */
+ public String toJSON() {
+ List<Map<String, Object>> toWrite = Lists.newLinkedList();
+ // attributes does not include ignored columns and it does include the class label
+ int ignoredCount = 0;
+ for (int i = 0; i < attributes.length + ignored.length; i++) {
+ Map<String, Object> attribute;
+ int attributesIndex = i - ignoredCount;
+ if (ignoredCount < ignored.length && i == ignored[ignoredCount]) {
+ // fill in ignored atttribute
+ attribute = getMap(Attribute.IGNORED, null, false);
+ ignoredCount++;
+ } else if (attributesIndex == labelId) {
+ // fill in the label
+ attribute = getMap(attributes[attributesIndex], values[attributesIndex], true);
+ } else {
+ // normal attribute
+ attribute = getMap(attributes[attributesIndex], values[attributesIndex], false);
+ }
+ toWrite.add(attribute);
+ }
+ try {
+ return OBJECT_MAPPER.writeValueAsString(toWrite);
+ } catch (Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+
+ /**
+ * De-serialize an instance from a string
+ * @param json From which an instance is created
+ * @return A shiny new Dataset
+ */
+ public static Dataset fromJSON(String json) {
+ List<Map<String, Object>> fromJSON;
+ try {
+ fromJSON = OBJECT_MAPPER.readValue(json, new TypeReference<List<Map<String, Object>>>() {});
+ } catch (Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ List<Attribute> attributes = Lists.newLinkedList();
+ List<Integer> ignored = Lists.newLinkedList();
+ String[][] nominalValues = new String[fromJSON.size()][];
+ Dataset dataset = new Dataset();
+ for (int i = 0; i < fromJSON.size(); i++) {
+ Map<String, Object> attribute = fromJSON.get(i);
+ if (Attribute.fromString((String) attribute.get(TYPE)) == Attribute.IGNORED) {
+ ignored.add(i);
+ } else {
+ Attribute asAttribute = Attribute.fromString((String) attribute.get(TYPE));
+ attributes.add(asAttribute);
+ if ((Boolean) attribute.get(LABEL)) {
+ dataset.labelId = i - ignored.size();
+ }
+ if (attribute.get(VALUES) != null) {
+ List<String> get = (List<String>) attribute.get(VALUES);
+ String[] array = get.toArray(new String[get.size()]);
+ nominalValues[i - ignored.size()] = array;
+ }
+ }
+ }
+ dataset.attributes = attributes.toArray(new Attribute[attributes.size()]);
+ dataset.ignored = new int[ignored.size()];
+ dataset.values = nominalValues;
+ for (int i = 0; i < dataset.ignored.length; i++) {
+ dataset.ignored[i] = ignored.get(i);
+ }
+ return dataset;
+ }
+
+ /**
+ * Generate a map to describe an attribute
+ * @param type The type
+ * @param values - values
+ * @param isLabel - is a label
+ * @return map of (AttributeTypes, Values)
+ */
+ private Map<String, Object> getMap(Attribute type, String[] values, boolean isLabel) {
+ Map<String, Object> attribute = Maps.newHashMap();
+ attribute.put(TYPE, type.toString().toLowerCase(Locale.getDefault()));
+ attribute.put(VALUES, values);
+ attribute.put(LABEL, isLabel);
+ return attribute;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java
new file mode 100644
index 0000000..f4419f0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorException.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+/**
+ * Exception thrown when parsing a descriptor
+ */
+public class DescriptorException extends Exception {
+ public DescriptorException(String msg) {
+ super(msg);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java
new file mode 100644
index 0000000..a2198b1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/DescriptorUtils.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+
+import java.util.List;
+import java.util.Locale;
+
+/**
+ * Contains various methods that deal with descriptor strings
+ */
+public final class DescriptorUtils {
+
+ private static final Splitter SPACE = Splitter.on(' ').omitEmptyStrings();
+
+ private DescriptorUtils() { }
+
+ /**
+ * Parses a descriptor string and generates the corresponding array of Attributes
+ *
+ * @throws DescriptorException
+ * if a bad token is encountered
+ */
+ public static Attribute[] parseDescriptor(CharSequence descriptor) throws DescriptorException {
+ List<Attribute> attributes = Lists.newArrayList();
+ for (String token : SPACE.split(descriptor)) {
+ token = token.toUpperCase(Locale.ENGLISH);
+ if ("I".equals(token)) {
+ attributes.add(Attribute.IGNORED);
+ } else if ("N".equals(token)) {
+ attributes.add(Attribute.NUMERICAL);
+ } else if ("C".equals(token)) {
+ attributes.add(Attribute.CATEGORICAL);
+ } else if ("L".equals(token)) {
+ attributes.add(Attribute.LABEL);
+ } else {
+ throw new DescriptorException("Bad Token : " + token);
+ }
+ }
+ return attributes.toArray(new Attribute[attributes.size()]);
+ }
+
+ /**
+ * Generates a valid descriptor string from a user-friendly representation.<br>
+ * for example "3 N I N N 2 C L 5 I" generates "N N N I N N C C L I I I I I".<br>
+ * this useful when describing datasets with a large number of attributes
+ * @throws DescriptorException
+ */
+ public static String generateDescriptor(CharSequence description) throws DescriptorException {
+ return generateDescriptor(SPACE.split(description));
+ }
+
+ /**
+ * Generates a valid descriptor string from a list of tokens
+ * @throws DescriptorException
+ */
+ public static String generateDescriptor(Iterable<String> tokens) throws DescriptorException {
+ StringBuilder descriptor = new StringBuilder();
+
+ int multiplicator = 0;
+
+ for (String token : tokens) {
+ try {
+ // try to parse an integer
+ int number = Integer.parseInt(token);
+
+ if (number <= 0) {
+ throw new DescriptorException("Multiplicator (" + number + ") must be > 0");
+ }
+ if (multiplicator > 0) {
+ throw new DescriptorException("A multiplicator cannot be followed by another multiplicator");
+ }
+
+ multiplicator = number;
+ } catch (NumberFormatException e) {
+ // token is not a number
+ if (multiplicator == 0) {
+ multiplicator = 1;
+ }
+
+ for (int index = 0; index < multiplicator; index++) {
+ descriptor.append(token).append(' ');
+ }
+
+ multiplicator = 0;
+ }
+ }
+
+ return descriptor.toString().trim();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java
new file mode 100644
index 0000000..3abf124
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/Instance.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * Represents one data instance.
+ */
+public class Instance {
+
+ /** attributes, except LABEL and IGNORED */
+ private final Vector attrs;
+
+ public Instance(Vector attrs) {
+ this.attrs = attrs;
+ }
+
+ /**
+ * Return the attribute at the specified position
+ *
+ * @param index
+ * position of the attribute to retrieve
+ * @return value of the attribute
+ */
+ public double get(int index) {
+ return attrs.getQuick(index);
+ }
+
+ /**
+ * Set the value at the given index
+ *
+ * @param value
+ * a double value to set
+ */
+ public void set(int index, double value) {
+ attrs.set(index, value);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof Instance)) {
+ return false;
+ }
+
+ Instance instance = (Instance) obj;
+
+ return /*id == instance.id &&*/ attrs.equals(instance.attrs);
+
+ }
+
+ @Override
+ public int hashCode() {
+ return /*id +*/ attrs.hashCode();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java
new file mode 100644
index 0000000..b199834
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Condition.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data.conditions;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+/**
+ * Condition on Instance
+ */
+public abstract class Condition {
+
+ /**
+ * Returns true is the checked instance matches the condition
+ *
+ * @param instance
+ * checked instance
+ * @return true is the checked instance matches the condition
+ */
+ public abstract boolean isTrueFor(Instance instance);
+
+ /**
+ * Condition that checks if the given attribute has a value "equal" to the given value
+ */
+ public static Condition equals(int attr, double value) {
+ return new Equals(attr, value);
+ }
+
+ /**
+ * Condition that checks if the given attribute has a value "lesser" than the given value
+ */
+ public static Condition lesser(int attr, double value) {
+ return new Lesser(attr, value);
+ }
+
+ /**
+ * Condition that checks if the given attribute has a value "greater or equal" than the given value
+ */
+ public static Condition greaterOrEquals(int attr, double value) {
+ return new GreaterOrEquals(attr, value);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java
new file mode 100644
index 0000000..73f4ef6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Equals.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data.conditions;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+/**
+ * True if a given attribute has a given value
+ */
+public class Equals extends Condition {
+
+ private final int attr;
+
+ private final double value;
+
+ public Equals(int attr, double value) {
+ this.attr = attr;
+ this.value = value;
+ }
+
+ @Override
+ public boolean isTrueFor(Instance instance) {
+ return instance.get(attr) == value;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java
new file mode 100644
index 0000000..2db3f2e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/GreaterOrEquals.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data.conditions;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+/**
+ * True if a given attribute has a value "greater or equal" than a given value
+ */
+public class GreaterOrEquals extends Condition {
+
+ private final int attr;
+
+ private final double value;
+
+ public GreaterOrEquals(int attr, double value) {
+ this.attr = attr;
+ this.value = value;
+ }
+
+ @Override
+ public boolean isTrueFor(Instance v) {
+ return v.get(attr) >= value;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java
new file mode 100644
index 0000000..4e49eb7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/data/conditions/Lesser.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data.conditions;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+/**
+ * True if a given attribute has a value "lesser" than a given value
+ */
+public class Lesser extends Condition {
+
+ private final int attr;
+
+ private final double value;
+
+ public Lesser(int attr, double value) {
+ this.attr = attr;
+ this.value = value;
+ }
+
+ @Override
+ public boolean isTrueFor(Instance instance) {
+ return instance.get(attr) < value;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java
new file mode 100644
index 0000000..da2448f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Builder.java
@@ -0,0 +1,332 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.StringUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+
+/**
+ * Base class for Mapred DecisionForest builders. Takes care of storing the parameters common to the mapred
+ * implementations.<br>
+ * The child classes must implement at least :
+ * <ul>
+ * <li>void configureJob(Job) : to further configure the job before its launch; and</li>
+ * <li>DecisionForest parseOutput(Job, PredictionCallback) : in order to convert the job outputs into a
+ * DecisionForest and its corresponding oob predictions</li>
+ * </ul>
+ *
+ */
+public abstract class Builder {
+
+ private static final Logger log = LoggerFactory.getLogger(Builder.class);
+
+ private final TreeBuilder treeBuilder;
+ private final Path dataPath;
+ private final Path datasetPath;
+ private final Long seed;
+ private final Configuration conf;
+ private String outputDirName = "output";
+
+ protected Builder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed, Configuration conf) {
+ this.treeBuilder = treeBuilder;
+ this.dataPath = dataPath;
+ this.datasetPath = datasetPath;
+ this.seed = seed;
+ this.conf = new Configuration(conf);
+ }
+
+ protected Path getDataPath() {
+ return dataPath;
+ }
+
+ /**
+ * Return the value of "mapred.map.tasks".
+ *
+ * @param conf
+ * configuration
+ * @return number of map tasks
+ */
+ public static int getNumMaps(Configuration conf) {
+ return conf.getInt("mapred.map.tasks", -1);
+ }
+
+ /**
+ * Used only for DEBUG purposes. if false, the mappers doesn't output anything, so the builder has nothing
+ * to process
+ *
+ * @param conf
+ * configuration
+ * @return true if the builder has to return output. false otherwise
+ */
+ protected static boolean isOutput(Configuration conf) {
+ return conf.getBoolean("debug.mahout.rf.output", true);
+ }
+
+ /**
+ * Returns the random seed
+ *
+ * @param conf
+ * configuration
+ * @return null if no seed is available
+ */
+ public static Long getRandomSeed(Configuration conf) {
+ String seed = conf.get("mahout.rf.random.seed");
+ if (seed == null) {
+ return null;
+ }
+
+ return Long.valueOf(seed);
+ }
+
+ /**
+ * Sets the random seed value
+ *
+ * @param conf
+ * configuration
+ * @param seed
+ * random seed
+ */
+ private static void setRandomSeed(Configuration conf, long seed) {
+ conf.setLong("mahout.rf.random.seed", seed);
+ }
+
+ public static TreeBuilder getTreeBuilder(Configuration conf) {
+ String string = conf.get("mahout.rf.treebuilder");
+ if (string == null) {
+ return null;
+ }
+
+ return StringUtils.fromString(string);
+ }
+
+ private static void setTreeBuilder(Configuration conf, TreeBuilder treeBuilder) {
+ conf.set("mahout.rf.treebuilder", StringUtils.toString(treeBuilder));
+ }
+
+ /**
+ * Get the number of trees for the map-reduce job.
+ *
+ * @param conf
+ * configuration
+ * @return number of trees to build
+ */
+ public static int getNbTrees(Configuration conf) {
+ return conf.getInt("mahout.rf.nbtrees", -1);
+ }
+
+ /**
+ * Set the number of trees to grow for the map-reduce job
+ *
+ * @param conf
+ * configuration
+ * @param nbTrees
+ * number of trees to build
+ * @throws IllegalArgumentException
+ * if (nbTrees <= 0)
+ */
+ public static void setNbTrees(Configuration conf, int nbTrees) {
+ Preconditions.checkArgument(nbTrees > 0, "nbTrees should be greater than 0");
+
+ conf.setInt("mahout.rf.nbtrees", nbTrees);
+ }
+
+ /**
+ * Sets the Output directory name, will be creating in the working directory
+ *
+ * @param name
+ * output dir. name
+ */
+ public void setOutputDirName(String name) {
+ outputDirName = name;
+ }
+
+ /**
+ * Output Directory name
+ *
+ * @param conf
+ * configuration
+ * @return output dir. path (%WORKING_DIRECTORY%/OUTPUT_DIR_NAME%)
+ * @throws IOException
+ * if we cannot get the default FileSystem
+ */
+ protected Path getOutputPath(Configuration conf) throws IOException {
+ // the output directory is accessed only by this class, so use the default
+ // file system
+ FileSystem fs = FileSystem.get(conf);
+ return new Path(fs.getWorkingDirectory(), outputDirName);
+ }
+
+ /**
+ * Helper method. Get a path from the DistributedCache
+ *
+ * @param conf
+ * configuration
+ * @param index
+ * index of the path in the DistributedCache files
+ * @return path from the DistributedCache
+ * @throws IOException
+ * if no path is found
+ */
+ public static Path getDistributedCacheFile(Configuration conf, int index) throws IOException {
+ Path[] files = HadoopUtil.getCachedFiles(conf);
+
+ if (files.length <= index) {
+ throw new IOException("path not found in the DistributedCache");
+ }
+
+ return files[index];
+ }
+
+ /**
+ * Helper method. Load a Dataset stored in the DistributedCache
+ *
+ * @param conf
+ * configuration
+ * @return loaded Dataset
+ * @throws IOException
+ * if we cannot retrieve the Dataset path from the DistributedCache, or the Dataset could not be
+ * loaded
+ */
+ public static Dataset loadDataset(Configuration conf) throws IOException {
+ Path datasetPath = getDistributedCacheFile(conf, 0);
+
+ return Dataset.load(conf, datasetPath);
+ }
+
+ /**
+ * Used by the inheriting classes to configure the job
+ *
+ *
+ * @param job
+ * Hadoop's Job
+ * @throws IOException
+ * if anything goes wrong while configuring the job
+ */
+ protected abstract void configureJob(Job job) throws IOException;
+
+ /**
+ * Sequential implementation should override this method to simulate the job execution
+ *
+ * @param job
+ * Hadoop's job
+ * @return true is the job succeeded
+ */
+ protected boolean runJob(Job job) throws ClassNotFoundException, IOException, InterruptedException {
+ return job.waitForCompletion(true);
+ }
+
+ /**
+ * Parse the output files to extract the trees and pass the predictions to the callback
+ *
+ * @param job
+ * Hadoop's job
+ * @return Built DecisionForest
+ * @throws IOException
+ * if anything goes wrong while parsing the output
+ */
+ protected abstract DecisionForest parseOutput(Job job) throws IOException;
+
+ public DecisionForest build(int nbTrees)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ // int numTrees = getNbTrees(conf);
+
+ Path outputPath = getOutputPath(conf);
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ // check the output
+ if (fs.exists(outputPath)) {
+ throw new IOException("Output path already exists : " + outputPath);
+ }
+
+ if (seed != null) {
+ setRandomSeed(conf, seed);
+ }
+ setNbTrees(conf, nbTrees);
+ setTreeBuilder(conf, treeBuilder);
+
+ // put the dataset into the DistributedCache
+ DistributedCache.addCacheFile(datasetPath.toUri(), conf);
+
+ Job job = new Job(conf, "decision forest builder");
+
+ log.debug("Configuring the job...");
+ configureJob(job);
+
+ log.debug("Running the job...");
+ if (!runJob(job)) {
+ log.error("Job failed!");
+ return null;
+ }
+
+ if (isOutput(conf)) {
+ log.debug("Parsing the output...");
+ DecisionForest forest = parseOutput(job);
+ HadoopUtil.delete(conf, outputPath);
+ return forest;
+ }
+
+ return null;
+ }
+
+ /**
+ * sort the splits into order based on size, so that the biggest go first.<br>
+ * This is the same code used by Hadoop's JobClient.
+ *
+ * @param splits
+ * input splits
+ */
+ public static void sortSplits(InputSplit[] splits) {
+ Arrays.sort(splits, new Comparator<InputSplit>() {
+ @Override
+ public int compare(InputSplit a, InputSplit b) {
+ try {
+ long left = a.getLength();
+ long right = b.getLength();
+ if (left == right) {
+ return 0;
+ } else if (left < right) {
+ return 1;
+ } else {
+ return -1;
+ }
+ } catch (IOException ie) {
+ throw new IllegalStateException("Problem getting input split size", ie);
+ } catch (InterruptedException ie) {
+ throw new IllegalStateException("Problem getting input split size", ie);
+ }
+ }
+ });
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java
new file mode 100644
index 0000000..b8e5c2d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/Classifier.java
@@ -0,0 +1,237 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.FileSplit;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Mapreduce implementation that classifies the Input data using a previousely built decision forest
+ */
+public class Classifier {
+
+ private static final Logger log = LoggerFactory.getLogger(Classifier.class);
+
+ private final Path forestPath;
+ private final Path inputPath;
+ private final Path datasetPath;
+ private final Configuration conf;
+ private final Path outputPath; // path that will containt the final output of the classifier
+ private final Path mappersOutputPath; // mappers will output here
+ private double[][] results;
+
+ public double[][] getResults() {
+ return results;
+ }
+
+ public Classifier(Path forestPath,
+ Path inputPath,
+ Path datasetPath,
+ Path outputPath,
+ Configuration conf) {
+ this.forestPath = forestPath;
+ this.inputPath = inputPath;
+ this.datasetPath = datasetPath;
+ this.outputPath = outputPath;
+ this.conf = conf;
+
+ mappersOutputPath = new Path(outputPath, "mappers");
+ }
+
+ private void configureJob(Job job) throws IOException {
+
+ job.setJarByClass(Classifier.class);
+
+ FileInputFormat.setInputPaths(job, inputPath);
+ FileOutputFormat.setOutputPath(job, mappersOutputPath);
+
+ job.setOutputKeyClass(DoubleWritable.class);
+ job.setOutputValueClass(Text.class);
+
+ job.setMapperClass(CMapper.class);
+ job.setNumReduceTasks(0); // no reducers
+
+ job.setInputFormatClass(CTextInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ }
+
+ public void run() throws IOException, ClassNotFoundException, InterruptedException {
+ FileSystem fs = FileSystem.get(conf);
+
+ // check the output
+ if (fs.exists(outputPath)) {
+ throw new IOException("Output path already exists : " + outputPath);
+ }
+
+ log.info("Adding the dataset to the DistributedCache");
+ // put the dataset into the DistributedCache
+ DistributedCache.addCacheFile(datasetPath.toUri(), conf);
+
+ log.info("Adding the decision forest to the DistributedCache");
+ DistributedCache.addCacheFile(forestPath.toUri(), conf);
+
+ Job job = new Job(conf, "decision forest classifier");
+
+ log.info("Configuring the job...");
+ configureJob(job);
+
+ log.info("Running the job...");
+ if (!job.waitForCompletion(true)) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ parseOutput(job);
+
+ HadoopUtil.delete(conf, mappersOutputPath);
+ }
+
+ /**
+ * Extract the prediction for each mapper and write them in the corresponding output file.
+ * The name of the output file is based on the name of the corresponding input file.
+ * Will compute the ConfusionMatrix if necessary.
+ */
+ private void parseOutput(JobContext job) throws IOException {
+ Configuration conf = job.getConfiguration();
+ FileSystem fs = mappersOutputPath.getFileSystem(conf);
+
+ Path[] outfiles = DFUtils.listOutputFiles(fs, mappersOutputPath);
+
+ // read all the output
+ List<double[]> resList = Lists.newArrayList();
+ for (Path path : outfiles) {
+ FSDataOutputStream ofile = null;
+ try {
+ for (Pair<DoubleWritable,Text> record : new SequenceFileIterable<DoubleWritable,Text>(path, true, conf)) {
+ double key = record.getFirst().get();
+ String value = record.getSecond().toString();
+ if (ofile == null) {
+ // this is the first value, it contains the name of the input file
+ ofile = fs.create(new Path(outputPath, value).suffix(".out"));
+ } else {
+ // The key contains the correct label of the data. The value contains a prediction
+ ofile.writeChars(value); // write the prediction
+ ofile.writeChar('\n');
+
+ resList.add(new double[]{key, Double.valueOf(value)});
+ }
+ }
+ } finally {
+ Closeables.close(ofile, false);
+ }
+ }
+ results = new double[resList.size()][2];
+ resList.toArray(results);
+ }
+
+ /**
+ * TextInputFormat that does not split the input files. This ensures that each input file is processed by one single
+ * mapper.
+ */
+ private static class CTextInputFormat extends TextInputFormat {
+ @Override
+ protected boolean isSplitable(JobContext jobContext, Path path) {
+ return false;
+ }
+ }
+
+ public static class CMapper extends Mapper<LongWritable, Text, DoubleWritable, Text> {
+
+ /** used to convert input values to data instances */
+ private DataConverter converter;
+ private DecisionForest forest;
+ private final Random rng = RandomUtils.getRandom();
+ private boolean first = true;
+ private final Text lvalue = new Text();
+ private Dataset dataset;
+ private final DoubleWritable lkey = new DoubleWritable();
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context); //To change body of overridden methods use File | Settings | File Templates.
+
+ Configuration conf = context.getConfiguration();
+
+ Path[] files = HadoopUtil.getCachedFiles(conf);
+
+ if (files.length < 2) {
+ throw new IOException("not enough paths in the DistributedCache");
+ }
+ dataset = Dataset.load(conf, files[0]);
+ converter = new DataConverter(dataset);
+
+ forest = DecisionForest.load(conf, files[1]);
+ if (forest == null) {
+ throw new InterruptedException("DecisionForest not found!");
+ }
+ }
+
+ @Override
+ protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+ if (first) {
+ FileSplit split = (FileSplit) context.getInputSplit();
+ Path path = split.getPath(); // current split path
+ lvalue.set(path.getName());
+ lkey.set(key.get());
+ context.write(lkey, lvalue);
+
+ first = false;
+ }
+
+ String line = value.toString();
+ if (!line.isEmpty()) {
+ Instance instance = converter.convert(line);
+ double prediction = forest.classify(dataset, rng, instance);
+ lkey.set(dataset.getLabel(instance));
+ lvalue.set(Double.toString(prediction));
+ context.write(lkey, lvalue);
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java
new file mode 100644
index 0000000..cfd93cd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredMapper.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Dataset;
+
+import java.io.IOException;
+
+/**
+ * Base class for Mapred mappers. Loads common parameters from the job
+ */
+public class MapredMapper<KEYIN,VALUEIN,KEYOUT,VALUEOUT> extends Mapper<KEYIN,VALUEIN,KEYOUT,VALUEOUT> {
+
+ private boolean noOutput;
+
+ private TreeBuilder treeBuilder;
+
+ private Dataset dataset;
+
+ /**
+ *
+ * @return whether the mapper does estimate and output predictions
+ */
+ protected boolean isOutput() {
+ return !noOutput;
+ }
+
+ protected TreeBuilder getTreeBuilder() {
+ return treeBuilder;
+ }
+
+ protected Dataset getDataset() {
+ return dataset;
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+
+ configure(!Builder.isOutput(conf), Builder.getTreeBuilder(conf), Builder
+ .loadDataset(conf));
+ }
+
+ /**
+ * Useful for testing
+ */
+ protected void configure(boolean noOutput, TreeBuilder treeBuilder, Dataset dataset) {
+ Preconditions.checkArgument(treeBuilder != null, "TreeBuilder not found in the Job parameters");
+ this.noOutput = noOutput;
+ this.treeBuilder = treeBuilder;
+ this.dataset = dataset;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java
new file mode 100644
index 0000000..b177ce5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/MapredOutput.java
@@ -0,0 +1,119 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.node.Node;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Used by various implementation to return the results of a build.<br>
+ * Contains a grown tree and and its oob predictions.
+ */
+public class MapredOutput implements Writable, Cloneable {
+
+ private Node tree;
+
+ private int[] predictions;
+
+ public MapredOutput() {
+ }
+
+ public MapredOutput(Node tree, int[] predictions) {
+ this.tree = tree;
+ this.predictions = predictions;
+ }
+
+ public MapredOutput(Node tree) {
+ this(tree, null);
+ }
+
+ public Node getTree() {
+ return tree;
+ }
+
+ int[] getPredictions() {
+ return predictions;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ boolean readTree = in.readBoolean();
+ if (readTree) {
+ tree = Node.read(in);
+ }
+
+ boolean readPredictions = in.readBoolean();
+ if (readPredictions) {
+ predictions = DFUtils.readIntArray(in);
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeBoolean(tree != null);
+ if (tree != null) {
+ tree.write(out);
+ }
+
+ out.writeBoolean(predictions != null);
+ if (predictions != null) {
+ DFUtils.writeArray(out, predictions);
+ }
+ }
+
+ @Override
+ public MapredOutput clone() {
+ return new MapredOutput(tree, predictions);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof MapredOutput)) {
+ return false;
+ }
+
+ MapredOutput mo = (MapredOutput) obj;
+
+ return ((tree == null && mo.getTree() == null) || (tree != null && tree.equals(mo.getTree())))
+ && Arrays.equals(predictions, mo.getPredictions());
+ }
+
+ @Override
+ public int hashCode() {
+ int hashCode = tree == null ? 1 : tree.hashCode();
+ for (int prediction : predictions) {
+ hashCode = 31 * hashCode + prediction;
+ }
+ return hashCode;
+ }
+
+ @Override
+ public String toString() {
+ return "{" + tree + " | " + Arrays.toString(predictions) + '}';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java
new file mode 100644
index 0000000..573a1e0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemBuilder.java
@@ -0,0 +1,113 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * MapReduce implementation where each mapper loads a full copy of the data in-memory. The forest trees are
+ * splitted across all the mappers
+ */
+public class InMemBuilder extends Builder {
+
+ public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed, Configuration conf) {
+ super(treeBuilder, dataPath, datasetPath, seed, conf);
+ }
+
+ public InMemBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath) {
+ this(treeBuilder, dataPath, datasetPath, null, new Configuration());
+ }
+
+ @Override
+ protected void configureJob(Job job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ job.setJarByClass(InMemBuilder.class);
+
+ FileOutputFormat.setOutputPath(job, getOutputPath(conf));
+
+ // put the data in the DistributedCache
+ DistributedCache.addCacheFile(getDataPath().toUri(), conf);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(MapredOutput.class);
+
+ job.setMapperClass(InMemMapper.class);
+ job.setNumReduceTasks(0); // no reducers
+
+ job.setInputFormatClass(InMemInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ }
+
+ @Override
+ protected DecisionForest parseOutput(Job job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ Map<Integer,MapredOutput> output = Maps.newHashMap();
+
+ Path outputPath = getOutputPath(conf);
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+
+ // import the InMemOutputs
+ for (Path path : outfiles) {
+ for (Pair<IntWritable,MapredOutput> record : new SequenceFileIterable<IntWritable,MapredOutput>(path, conf)) {
+ output.put(record.getFirst().get(), record.getSecond());
+ }
+ }
+
+ return processOutput(output);
+ }
+
+ /**
+ * Process the output, extracting the trees
+ */
+ private static DecisionForest processOutput(Map<Integer,MapredOutput> output) {
+ List<Node> trees = Lists.newArrayList();
+
+ for (Map.Entry<Integer,MapredOutput> entry : output.entrySet()) {
+ MapredOutput value = entry.getValue();
+ trees.add(value.getTree());
+ }
+
+ return new DecisionForest(trees);
+ }
+}
[36/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
new file mode 100644
index 0000000..a39218e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormat.java
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.InputFormat;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.RecordReader;
+import org.apache.hadoop.mapreduce.TaskAttemptContext;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.Random;
+
+/**
+ * Custom InputFormat that generates InputSplits given the desired number of trees.<br>
+ * each input split contains a subset of the trees.<br>
+ * The number of splits is equal to the number of requested splits
+ */
+public class InMemInputFormat extends InputFormat<IntWritable,NullWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(InMemInputSplit.class);
+
+ private Random rng;
+
+ private Long seed;
+
+ private boolean isSingleSeed;
+
+ /**
+ * Used for DEBUG purposes only. if true and a seed is available, all the mappers use the same seed, thus
+ * all the mapper should take the same time to build their trees.
+ */
+ private static boolean isSingleSeed(Configuration conf) {
+ return conf.getBoolean("debug.mahout.rf.single.seed", false);
+ }
+
+ @Override
+ public RecordReader<IntWritable,NullWritable> createRecordReader(InputSplit split, TaskAttemptContext context)
+ throws IOException, InterruptedException {
+ Preconditions.checkArgument(split instanceof InMemInputSplit);
+ return new InMemRecordReader((InMemInputSplit) split);
+ }
+
+ @Override
+ public List<InputSplit> getSplits(JobContext context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ int numSplits = conf.getInt("mapred.map.tasks", -1);
+
+ return getSplits(conf, numSplits);
+ }
+
+ public List<InputSplit> getSplits(Configuration conf, int numSplits) {
+ int nbTrees = Builder.getNbTrees(conf);
+ int splitSize = nbTrees / numSplits;
+
+ seed = Builder.getRandomSeed(conf);
+ isSingleSeed = isSingleSeed(conf);
+
+ if (rng != null && seed != null) {
+ log.warn("getSplits() was called more than once and the 'seed' is set, "
+ + "this can lead to no-repeatable behavior");
+ }
+
+ rng = seed == null || isSingleSeed ? null : RandomUtils.getRandom(seed);
+
+ int id = 0;
+
+ List<InputSplit> splits = Lists.newArrayListWithCapacity(numSplits);
+
+ for (int index = 0; index < numSplits - 1; index++) {
+ splits.add(new InMemInputSplit(id, splitSize, nextSeed()));
+ id += splitSize;
+ }
+
+ // take care of the remainder
+ splits.add(new InMemInputSplit(id, nbTrees - id, nextSeed()));
+
+ return splits;
+ }
+
+ /**
+ * @return the seed for the next InputSplit
+ */
+ private Long nextSeed() {
+ if (seed == null) {
+ return null;
+ } else if (isSingleSeed) {
+ return seed;
+ } else {
+ return rng.nextLong();
+ }
+ }
+
+ public static class InMemRecordReader extends RecordReader<IntWritable,NullWritable> {
+
+ private final InMemInputSplit split;
+ private int pos;
+ private IntWritable key;
+ private NullWritable value;
+
+ public InMemRecordReader(InMemInputSplit split) {
+ this.split = split;
+ }
+
+ @Override
+ public float getProgress() throws IOException {
+ return pos == 0 ? 0.0f : (float) (pos - 1) / split.nbTrees;
+ }
+
+ @Override
+ public IntWritable getCurrentKey() throws IOException, InterruptedException {
+ return key;
+ }
+
+ @Override
+ public NullWritable getCurrentValue() throws IOException, InterruptedException {
+ return value;
+ }
+
+ @Override
+ public void initialize(InputSplit arg0, TaskAttemptContext arg1) throws IOException, InterruptedException {
+ key = new IntWritable();
+ value = NullWritable.get();
+ }
+
+ @Override
+ public boolean nextKeyValue() throws IOException, InterruptedException {
+ if (pos < split.nbTrees) {
+ key.set(split.firstId + pos);
+ pos++;
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ }
+
+ }
+
+ /**
+ * Custom InputSplit that indicates how many trees are built by each mapper
+ */
+ public static class InMemInputSplit extends InputSplit implements Writable {
+
+ private static final String[] NO_LOCATIONS = new String[0];
+
+ /** Id of the first tree of this split */
+ private int firstId;
+
+ private int nbTrees;
+
+ private Long seed;
+
+ public InMemInputSplit() { }
+
+ public InMemInputSplit(int firstId, int nbTrees, Long seed) {
+ this.firstId = firstId;
+ this.nbTrees = nbTrees;
+ this.seed = seed;
+ }
+
+ /**
+ * @return the Id of the first tree of this split
+ */
+ public int getFirstId() {
+ return firstId;
+ }
+
+ /**
+ * @return the number of trees
+ */
+ public int getNbTrees() {
+ return nbTrees;
+ }
+
+ /**
+ * @return the random seed or null if no seed is available
+ */
+ public Long getSeed() {
+ return seed;
+ }
+
+ @Override
+ public long getLength() throws IOException {
+ return nbTrees;
+ }
+
+ @Override
+ public String[] getLocations() throws IOException {
+ return NO_LOCATIONS;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof InMemInputSplit)) {
+ return false;
+ }
+
+ InMemInputSplit split = (InMemInputSplit) obj;
+
+ if (firstId != split.firstId || nbTrees != split.nbTrees) {
+ return false;
+ }
+ if (seed == null) {
+ return split.seed == null;
+ } else {
+ return seed.equals(split.seed);
+ }
+
+ }
+
+ @Override
+ public int hashCode() {
+ return firstId + nbTrees + (seed == null ? 0 : seed.intValue());
+ }
+
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "[firstId:%d, nbTrees:%d, seed:%d]", firstId, nbTrees, seed);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ firstId = in.readInt();
+ nbTrees = in.readInt();
+ boolean isSeed = in.readBoolean();
+ seed = isSeed ? in.readLong() : null;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(firstId);
+ out.writeInt(nbTrees);
+ out.writeBoolean(seed != null);
+ if (seed != null) {
+ out.writeLong(seed);
+ }
+ }
+
+ public static InMemInputSplit read(DataInput in) throws IOException {
+ InMemInputSplit split = new InMemInputSplit();
+ split.readFields(in);
+ return split;
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java
new file mode 100644
index 0000000..9e7e176
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemMapper.java
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.mahout.classifier.df.Bagging;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredMapper;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Random;
+
+/**
+ * In-memory mapper that grows the trees using a full copy of the data loaded in-memory. The number of trees
+ * to grow is determined by the current InMemInputSplit.
+ */
+public class InMemMapper extends MapredMapper<IntWritable,NullWritable,IntWritable,MapredOutput> {
+
+ private static final Logger log = LoggerFactory.getLogger(InMemMapper.class);
+
+ private Bagging bagging;
+
+ private Random rng;
+
+ /**
+ * Load the training data
+ */
+ private static Data loadData(Configuration conf, Dataset dataset) throws IOException {
+ Path dataPath = Builder.getDistributedCacheFile(conf, 1);
+ FileSystem fs = FileSystem.get(dataPath.toUri(), conf);
+ return DataLoader.loadData(dataset, fs, dataPath);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+
+ log.info("Loading the data...");
+ Data data = loadData(conf, getDataset());
+ log.info("Data loaded : {} instances", data.size());
+
+ bagging = new Bagging(getTreeBuilder(), data);
+ }
+
+ @Override
+ protected void map(IntWritable key,
+ NullWritable value,
+ Context context) throws IOException, InterruptedException {
+ map(key, context);
+ }
+
+ void map(IntWritable key, Context context) throws IOException, InterruptedException {
+
+ initRandom((InMemInputSplit) context.getInputSplit());
+
+ log.debug("Building...");
+ Node tree = bagging.build(rng);
+
+ if (isOutput()) {
+ log.debug("Outputing...");
+ MapredOutput mrOut = new MapredOutput(tree);
+
+ context.write(key, mrOut);
+ }
+ }
+
+ void initRandom(InMemInputSplit split) {
+ if (rng == null) { // first execution of this mapper
+ Long seed = split.getSeed();
+ log.debug("Initialising rng with seed : {}", seed);
+ rng = seed == null ? RandomUtils.getRandom() : RandomUtils.getRandom(seed);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
new file mode 100644
index 0000000..61e65e8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/inmem/package-info.java
@@ -0,0 +1,22 @@
+/**
+ * <h2>In-memory mapreduce implementation of Random Decision Forests</h2>
+ *
+ * <p>Each mapper is responsible for growing a number of trees with a whole copy of the dataset loaded in memory,
+ * it uses the reference implementation's code to build each tree and estimate the oob error.</p>
+ *
+ * <p>The dataset is distributed to the slave nodes using the {@link org.apache.hadoop.filecache.DistributedCache}.
+ * A custom {@link org.apache.hadoop.mapreduce.InputFormat}
+ * ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat}) is configured with the
+ * desired number of trees and generates a number of {@link org.apache.hadoop.mapreduce.InputSplit}s
+ * equal to the configured number of maps.</p>
+ *
+ * <p>There is no need for reducers, each map outputs (the trees it built and, for each tree, the labels the
+ * tree predicted for each out-of-bag instance. This step has to be done in the mapper because only there we
+ * know which instances are o-o-b.</p>
+ *
+ * <p>The Forest builder ({@link org.apache.mahout.classifier.df.mapreduce.inmem.InMemBuilder}) is responsible
+ * for configuring and launching the job.
+ * At the end of the job it parses the output files and builds the corresponding
+ * {@link org.apache.mahout.classifier.df.DecisionForest}.</p>
+ */
+package org.apache.mahout.classifier.df.mapreduce.inmem;
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java
new file mode 100644
index 0000000..1c9a13b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilder.java
@@ -0,0 +1,157 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Builds a random forest using partial data. Each mapper uses only the data given by its InputSplit
+ */
+public class PartialBuilder extends Builder {
+
+ private static final Logger log = LoggerFactory.getLogger(PartialBuilder.class);
+
+ public PartialBuilder(TreeBuilder treeBuilder, Path dataPath, Path datasetPath, Long seed) {
+ this(treeBuilder, dataPath, datasetPath, seed, new Configuration());
+ }
+
+ public PartialBuilder(TreeBuilder treeBuilder,
+ Path dataPath,
+ Path datasetPath,
+ Long seed,
+ Configuration conf) {
+ super(treeBuilder, dataPath, datasetPath, seed, conf);
+ }
+
+ @Override
+ protected void configureJob(Job job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ job.setJarByClass(PartialBuilder.class);
+
+ FileInputFormat.setInputPaths(job, getDataPath());
+ FileOutputFormat.setOutputPath(job, getOutputPath(conf));
+
+ job.setOutputKeyClass(TreeID.class);
+ job.setOutputValueClass(MapredOutput.class);
+
+ job.setMapperClass(Step1Mapper.class);
+ job.setNumReduceTasks(0); // no reducers
+
+ job.setInputFormatClass(TextInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ // For this implementation to work, mapred.map.tasks needs to be set to the actual
+ // number of mappers Hadoop will use:
+ TextInputFormat inputFormat = new TextInputFormat();
+ List<?> splits = inputFormat.getSplits(job);
+ if (splits == null || splits.isEmpty()) {
+ log.warn("Unable to compute number of splits?");
+ } else {
+ int numSplits = splits.size();
+ log.info("Setting mapred.map.tasks = {}", numSplits);
+ conf.setInt("mapred.map.tasks", numSplits);
+ }
+ }
+
+ @Override
+ protected DecisionForest parseOutput(Job job) throws IOException {
+ Configuration conf = job.getConfiguration();
+
+ int numTrees = Builder.getNbTrees(conf);
+
+ Path outputPath = getOutputPath(conf);
+
+ TreeID[] keys = new TreeID[numTrees];
+ Node[] trees = new Node[numTrees];
+
+ processOutput(job, outputPath, keys, trees);
+
+ return new DecisionForest(Arrays.asList(trees));
+ }
+
+ /**
+ * Processes the output from the output path.<br>
+ *
+ * @param outputPath
+ * directory that contains the output of the job
+ * @param keys
+ * can be null
+ * @param trees
+ * can be null
+ * @throws java.io.IOException
+ */
+ protected static void processOutput(JobContext job,
+ Path outputPath,
+ TreeID[] keys,
+ Node[] trees) throws IOException {
+ Preconditions.checkArgument(keys == null && trees == null || keys != null && trees != null,
+ "if keys is null, trees should also be null");
+ Preconditions.checkArgument(keys == null || keys.length == trees.length, "keys.length != trees.length");
+
+ Configuration conf = job.getConfiguration();
+
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);
+
+ // read all the outputs
+ int index = 0;
+ for (Path path : outfiles) {
+ for (Pair<TreeID,MapredOutput> record : new SequenceFileIterable<TreeID, MapredOutput>(path, conf)) {
+ TreeID key = record.getFirst();
+ MapredOutput value = record.getSecond();
+ if (keys != null) {
+ keys[index] = key;
+ }
+ if (trees != null) {
+ trees[index] = value.getTree();
+ }
+ index++;
+ }
+ }
+
+ // make sure we got all the keys/values
+ if (keys != null && index != keys.length) {
+ throw new IllegalStateException("Some key/values are missing from the output");
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java
new file mode 100644
index 0000000..eaf0b15
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1Mapper.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.df.Bagging;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataConverter;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.MapredMapper;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * First step of the Partial Data Builder. Builds the trees using the data available in the InputSplit.
+ * Predict the oob classes for each tree in its growing partition (input split).
+ */
+public class Step1Mapper extends MapredMapper<LongWritable,Text,TreeID,MapredOutput> {
+
+ private static final Logger log = LoggerFactory.getLogger(Step1Mapper.class);
+
+ /** used to convert input values to data instances */
+ private DataConverter converter;
+
+ private Random rng;
+
+ /** number of trees to be built by this mapper */
+ private int nbTrees;
+
+ /** id of the first tree */
+ private int firstTreeId;
+
+ /** mapper's partition */
+ private int partition;
+
+ /** will contain all instances if this mapper's split */
+ private final List<Instance> instances = Lists.newArrayList();
+
+ public int getFirstTreeId() {
+ return firstTreeId;
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+
+ configure(Builder.getRandomSeed(conf), conf.getInt("mapred.task.partition", -1),
+ Builder.getNumMaps(conf), Builder.getNbTrees(conf));
+ }
+
+ /**
+ * Useful when testing
+ *
+ * @param partition
+ * current mapper inputSplit partition
+ * @param numMapTasks
+ * number of running map tasks
+ * @param numTrees
+ * total number of trees in the forest
+ */
+ protected void configure(Long seed, int partition, int numMapTasks, int numTrees) {
+ converter = new DataConverter(getDataset());
+
+ // prepare random-numders generator
+ log.debug("seed : {}", seed);
+ if (seed == null) {
+ rng = RandomUtils.getRandom();
+ } else {
+ rng = RandomUtils.getRandom(seed);
+ }
+
+ // mapper's partition
+ Preconditions.checkArgument(partition >= 0, "Wrong partition ID: " + partition + ". Partition must be >= 0!");
+ this.partition = partition;
+
+ // compute number of trees to build
+ nbTrees = nbTrees(numMapTasks, numTrees, partition);
+
+ // compute first tree id
+ firstTreeId = 0;
+ for (int p = 0; p < partition; p++) {
+ firstTreeId += nbTrees(numMapTasks, numTrees, p);
+ }
+
+ log.debug("partition : {}", partition);
+ log.debug("nbTrees : {}", nbTrees);
+ log.debug("firstTreeId : {}", firstTreeId);
+ }
+
+ /**
+ * Compute the number of trees for a given partition. The first partitions may be longer
+ * than the rest because of the remainder.
+ *
+ * @param numMaps
+ * total number of maps (partitions)
+ * @param numTrees
+ * total number of trees to build
+ * @param partition
+ * partition to compute the number of trees for
+ */
+ public static int nbTrees(int numMaps, int numTrees, int partition) {
+ int treesPerMapper = numTrees / numMaps;
+ int remainder = numTrees - numMaps * treesPerMapper;
+ return treesPerMapper + (partition < remainder ? 1 : 0);
+ }
+
+ @Override
+ protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+ instances.add(converter.convert(value.toString()));
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException, InterruptedException {
+ // prepare the data
+ log.debug("partition: {} numInstances: {}", partition, instances.size());
+
+ Data data = new Data(getDataset(), instances);
+ Bagging bagging = new Bagging(getTreeBuilder(), data);
+
+ TreeID key = new TreeID();
+
+ log.debug("Building {} trees", nbTrees);
+ for (int treeId = 0; treeId < nbTrees; treeId++) {
+ log.debug("Building tree number : {}", treeId);
+
+ Node tree = bagging.build(rng);
+
+ key.set(partition, firstTreeId + treeId);
+
+ if (isOutput()) {
+ MapredOutput emOut = new MapredOutput(tree);
+ context.write(key, emOut);
+ }
+
+ context.progress();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java
new file mode 100644
index 0000000..d0ed5df
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeID.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.io.LongWritable;
+
+/**
+ * Indicates both the tree and the data partition used to grow the tree
+ */
+public class TreeID extends LongWritable implements Cloneable {
+
+ public static final int MAX_TREEID = 100000;
+
+ public TreeID() { }
+
+ public TreeID(int partition, int treeId) {
+ Preconditions.checkArgument(partition >= 0, "Wrong partition: " + partition + ". Partition must be >= 0!");
+ Preconditions.checkArgument(treeId >= 0, "Wrong treeId: " + treeId + ". TreeId must be >= 0!");
+ set(partition, treeId);
+ }
+
+ public void set(int partition, int treeId) {
+ set((long) partition * MAX_TREEID + treeId);
+ }
+
+ /**
+ * Data partition (InputSplit's index) that was used to grow the tree
+ */
+ public int partition() {
+ return (int) (get() / MAX_TREEID);
+ }
+
+ public int treeId() {
+ return (int) (get() % MAX_TREEID);
+ }
+
+ @Override
+ public TreeID clone() {
+ return new TreeID(partition(), treeId());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
new file mode 100644
index 0000000..e621c91
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/mapreduce/partial/package-info.java
@@ -0,0 +1,16 @@
+/**
+ * <h2>Partial-data mapreduce implementation of Random Decision Forests</h2>
+ *
+ * <p>The builder splits the data, using a FileInputSplit, among the mappers.
+ * Building the forest and estimating the oob error takes two job steps.</p>
+ *
+ * <p>In the first step, each mapper is responsible for growing a number of trees with its partition's,
+ * loading the data instances in its {@code map()} function, then building the trees in the {@code close()} method. It
+ * uses the reference implementation's code to build each tree and estimate the oob error.</p>
+ *
+ * <p>The second step is needed when estimating the oob error. Each mapper loads all the trees that does not
+ * belong to its own partition (were not built using the partition's data) and uses them to classify the
+ * partition's data instances. The data instances are loaded in the {@code map()} method and the classification
+ * is performed in the {@code close()} method.</p>
+ */
+package org.apache.mahout.classifier.df.mapreduce.partial;
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java b/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
new file mode 100644
index 0000000..3484866
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/node/CategoricalNode.java
@@ -0,0 +1,134 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+public class CategoricalNode extends Node {
+
+ private int attr;
+ private double[] values;
+ private Node[] childs;
+
+ public CategoricalNode() {
+ }
+
+ public CategoricalNode(int attr, double[] values, Node[] childs) {
+ this.attr = attr;
+ this.values = values;
+ this.childs = childs;
+ }
+
+ @Override
+ public double classify(Instance instance) {
+ int index = ArrayUtils.indexOf(values, instance.get(attr));
+ if (index == -1) {
+ // value not available, we cannot predict
+ return Double.NaN;
+ }
+ return childs[index].classify(instance);
+ }
+
+ @Override
+ public long maxDepth() {
+ long max = 0;
+
+ for (Node child : childs) {
+ long depth = child.maxDepth();
+ if (depth > max) {
+ max = depth;
+ }
+ }
+
+ return 1 + max;
+ }
+
+ @Override
+ public long nbNodes() {
+ long nbNodes = 1;
+
+ for (Node child : childs) {
+ nbNodes += child.nbNodes();
+ }
+
+ return nbNodes;
+ }
+
+ @Override
+ protected Type getType() {
+ return Type.CATEGORICAL;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof CategoricalNode)) {
+ return false;
+ }
+
+ CategoricalNode node = (CategoricalNode) obj;
+
+ return attr == node.attr && Arrays.equals(values, node.values) && Arrays.equals(childs, node.childs);
+ }
+
+ @Override
+ public int hashCode() {
+ int hashCode = attr;
+ for (double value : values) {
+ hashCode = 31 * hashCode + (int) Double.doubleToLongBits(value);
+ }
+ for (Node node : childs) {
+ hashCode = 31 * hashCode + node.hashCode();
+ }
+ return hashCode;
+ }
+
+ @Override
+ protected String getString() {
+ StringBuilder buffer = new StringBuilder();
+
+ for (Node child : childs) {
+ buffer.append(child).append(',');
+ }
+
+ return buffer.toString();
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ attr = in.readInt();
+ values = DFUtils.readDoubleArray(in);
+ childs = DFUtils.readNodeArray(in);
+ }
+
+ @Override
+ protected void writeNode(DataOutput out) throws IOException {
+ out.writeInt(attr);
+ DFUtils.writeArray(out, values);
+ DFUtils.writeArray(out, childs);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java b/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
new file mode 100644
index 0000000..285a134
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/node/Leaf.java
@@ -0,0 +1,94 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents a Leaf node
+ */
+public class Leaf extends Node {
+ private static final double EPSILON = 1.0e-6;
+
+ private double label;
+
+ Leaf() { }
+
+ public Leaf(double label) {
+ this.label = label;
+ }
+
+ @Override
+ public double classify(Instance instance) {
+ return label;
+ }
+
+ @Override
+ public long maxDepth() {
+ return 1;
+ }
+
+ @Override
+ public long nbNodes() {
+ return 1;
+ }
+
+ @Override
+ protected Type getType() {
+ return Type.LEAF;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof Leaf)) {
+ return false;
+ }
+
+ Leaf leaf = (Leaf) obj;
+
+ return Math.abs(label - leaf.label) < EPSILON;
+ }
+
+ @Override
+ public int hashCode() {
+ long bits = Double.doubleToLongBits(label);
+ return (int)(bits ^ (bits >>> 32));
+ }
+
+ @Override
+ protected String getString() {
+ return "";
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ label = in.readDouble();
+ }
+
+ @Override
+ protected void writeNode(DataOutput out) throws IOException {
+ out.writeDouble(label);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java b/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
new file mode 100644
index 0000000..cb6deb2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/node/Node.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents an abstract node of a decision tree
+ */
+public abstract class Node implements Writable {
+
+ protected enum Type {
+ LEAF,
+ NUMERICAL,
+ CATEGORICAL
+ }
+
+ /**
+ * predicts the label for the instance
+ *
+ * @return -1 if the label cannot be predicted
+ */
+ public abstract double classify(Instance instance);
+
+ /**
+ * @return the total number of nodes of the tree
+ */
+ public abstract long nbNodes();
+
+ /**
+ * @return the maximum depth of the tree
+ */
+ public abstract long maxDepth();
+
+ protected abstract Type getType();
+
+ public static Node read(DataInput in) throws IOException {
+ Type type = Type.values()[in.readInt()];
+ Node node;
+
+ switch (type) {
+ case LEAF:
+ node = new Leaf();
+ break;
+ case NUMERICAL:
+ node = new NumericalNode();
+ break;
+ case CATEGORICAL:
+ node = new CategoricalNode();
+ break;
+ default:
+ throw new IllegalStateException("This implementation is not currently supported");
+ }
+
+ node.readFields(in);
+
+ return node;
+ }
+
+ @Override
+ public final String toString() {
+ return getType() + ":" + getString() + ';';
+ }
+
+ protected abstract String getString();
+
+ @Override
+ public final void write(DataOutput out) throws IOException {
+ out.writeInt(getType().ordinal());
+ writeNode(out);
+ }
+
+ protected abstract void writeNode(DataOutput out) throws IOException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java b/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
new file mode 100644
index 0000000..19b3e57
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/node/NumericalNode.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Represents a node that splits using a numerical attribute
+ */
+public class NumericalNode extends Node {
+ /** numerical attribute to split for */
+ private int attr;
+
+ /** split value */
+ private double split;
+
+ /** child node when attribute's value < split value */
+ private Node loChild;
+
+ /** child node when attribute's value >= split value */
+ private Node hiChild;
+
+ public NumericalNode() { }
+
+ public NumericalNode(int attr, double split, Node loChild, Node hiChild) {
+ this.attr = attr;
+ this.split = split;
+ this.loChild = loChild;
+ this.hiChild = hiChild;
+ }
+
+ @Override
+ public double classify(Instance instance) {
+ if (instance.get(attr) < split) {
+ return loChild.classify(instance);
+ } else {
+ return hiChild.classify(instance);
+ }
+ }
+
+ @Override
+ public long maxDepth() {
+ return 1 + Math.max(loChild.maxDepth(), hiChild.maxDepth());
+ }
+
+ @Override
+ public long nbNodes() {
+ return 1 + loChild.nbNodes() + hiChild.nbNodes();
+ }
+
+ @Override
+ protected Type getType() {
+ return Type.NUMERICAL;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof NumericalNode)) {
+ return false;
+ }
+
+ NumericalNode node = (NumericalNode) obj;
+
+ return attr == node.attr && split == node.split && loChild.equals(node.loChild) && hiChild.equals(node.hiChild);
+ }
+
+ @Override
+ public int hashCode() {
+ return attr + (int) Double.doubleToLongBits(split) + loChild.hashCode() + hiChild.hashCode();
+ }
+
+ @Override
+ protected String getString() {
+ return loChild.toString() + ',' + hiChild.toString();
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ attr = in.readInt();
+ split = in.readDouble();
+ loChild = Node.read(in);
+ hiChild = Node.read(in);
+ }
+
+ @Override
+ protected void writeNode(DataOutput out) throws IOException {
+ out.writeInt(attr);
+ out.writeDouble(split);
+ loChild.write(out);
+ hiChild.write(out);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
new file mode 100644
index 0000000..292b591
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/ref/SequentialBuilder.java
@@ -0,0 +1,77 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.ref;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.classifier.df.Bagging;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.node.Node;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Builds a Random Decision Forest using a given TreeBuilder to grow the trees
+ */
+public class SequentialBuilder {
+
+ private static final Logger log = LoggerFactory.getLogger(SequentialBuilder.class);
+
+ private final Random rng;
+
+ private final Bagging bagging;
+
+ /**
+ * Constructor
+ *
+ * @param rng
+ * random-numbers generator
+ * @param treeBuilder
+ * tree builder
+ * @param data
+ * training data
+ */
+ public SequentialBuilder(Random rng, TreeBuilder treeBuilder, Data data) {
+ this.rng = rng;
+ bagging = new Bagging(treeBuilder, data);
+ }
+
+ public DecisionForest build(int nbTrees) {
+ List<Node> trees = Lists.newArrayList();
+
+ for (int treeId = 0; treeId < nbTrees; treeId++) {
+ trees.add(bagging.build(rng));
+ logProgress(((float) treeId + 1) / nbTrees);
+ }
+
+ return new DecisionForest(trees);
+ }
+
+ private static void logProgress(float progress) {
+ int percent = (int) (progress * 100);
+ if (percent % 10 == 0) {
+ log.info("Building {}%", percent);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java b/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java
new file mode 100644
index 0000000..38d3007
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/split/DefaultIgSplit.java
@@ -0,0 +1,117 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+
+import java.util.Arrays;
+
+/**
+ * Default, not optimized, implementation of IgSplit
+ */
+public class DefaultIgSplit extends IgSplit {
+
+ /** used by entropy() */
+ private int[] counts;
+
+ @Override
+ public Split computeSplit(Data data, int attr) {
+ if (data.getDataset().isNumerical(attr)) {
+ double[] values = data.values(attr);
+ double bestIg = -1;
+ double bestSplit = 0.0;
+
+ for (double value : values) {
+ double ig = numericalIg(data, attr, value);
+ if (ig > bestIg) {
+ bestIg = ig;
+ bestSplit = value;
+ }
+ }
+
+ return new Split(attr, bestIg, bestSplit);
+ } else {
+ double ig = categoricalIg(data, attr);
+
+ return new Split(attr, ig);
+ }
+ }
+
+ /**
+ * Computes the Information Gain for a CATEGORICAL attribute
+ */
+ double categoricalIg(Data data, int attr) {
+ double[] values = data.values(attr);
+ double hy = entropy(data); // H(Y)
+ double hyx = 0.0; // H(Y|X)
+ double invDataSize = 1.0 / data.size();
+
+ for (double value : values) {
+ Data subset = data.subset(Condition.equals(attr, value));
+ hyx += subset.size() * invDataSize * entropy(subset);
+ }
+
+ return hy - hyx;
+ }
+
+ /**
+ * Computes the Information Gain for a NUMERICAL attribute given a splitting value
+ */
+ double numericalIg(Data data, int attr, double split) {
+ double hy = entropy(data);
+ double invDataSize = 1.0 / data.size();
+
+ // LO subset
+ Data subset = data.subset(Condition.lesser(attr, split));
+ hy -= subset.size() * invDataSize * entropy(subset);
+
+ // HI subset
+ subset = data.subset(Condition.greaterOrEquals(attr, split));
+ hy -= subset.size() * invDataSize * entropy(subset);
+
+ return hy;
+ }
+
+ /**
+ * Computes the Entropy
+ */
+ protected double entropy(Data data) {
+ double invDataSize = 1.0 / data.size();
+
+ if (counts == null) {
+ counts = new int[data.getDataset().nblabels()];
+ }
+
+ Arrays.fill(counts, 0);
+ data.countLabels(counts);
+
+ double entropy = 0.0;
+ for (int label = 0; label < data.getDataset().nblabels(); label++) {
+ int count = counts[label];
+ if (count == 0) {
+ continue; // otherwise we get a NaN
+ }
+ double p = count * invDataSize;
+ entropy += -p * Math.log(p) / LOG2;
+ }
+
+ return entropy;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java b/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
new file mode 100644
index 0000000..da37cf3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/split/IgSplit.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+
+/**
+ * Computes the best split using the Information Gain measure
+ */
+public abstract class IgSplit {
+
+ static final double LOG2 = Math.log(2.0);
+
+ /**
+ * Computes the best split for the given attribute
+ */
+ public abstract Split computeSplit(Data data, int attr);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java b/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
new file mode 100644
index 0000000..7b15d2a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/split/OptIgSplit.java
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.TreeSet;
+
+/**
+ * <p>Optimized implementation of IgSplit.
+ * This class can be used when the criterion variable is the categorical attribute.</p>
+ *
+ * <p>This code was changed in MAHOUT-1419 to deal in sampled splits among numeric
+ * features to fix a performance problem. To generate some synthetic data that exercises
+ * the issue, try for example generating 4 features of Normal(0,1) values with a random
+ * boolean 0/1 categorical feature. In Scala:</p>
+ *
+ * {@code
+ * val r = new scala.util.Random()
+ * val pw = new java.io.PrintWriter("random.csv")
+ * (1 to 10000000).foreach(e =>
+ * pw.println(r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * r.nextDouble() + "," +
+ * (if (r.nextBoolean()) 1 else 0))
+ * )
+ * pw.close()
+ * }
+ */
+public class OptIgSplit extends IgSplit {
+
+ private static final int MAX_NUMERIC_SPLITS = 16;
+
+ @Override
+ public Split computeSplit(Data data, int attr) {
+ if (data.getDataset().isNumerical(attr)) {
+ return numericalSplit(data, attr);
+ } else {
+ return categoricalSplit(data, attr);
+ }
+ }
+
+ /**
+ * Computes the split for a CATEGORICAL attribute
+ */
+ private static Split categoricalSplit(Data data, int attr) {
+ double[] values = data.values(attr).clone();
+
+ double[] splitPoints = chooseCategoricalSplitPoints(values);
+
+ int numLabels = data.getDataset().nblabels();
+ int[][] counts = new int[splitPoints.length][numLabels];
+ int[] countAll = new int[numLabels];
+
+ computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+ int size = data.size();
+ double hy = entropy(countAll, size); // H(Y)
+ double hyx = 0.0; // H(Y|X)
+ double invDataSize = 1.0 / size;
+
+ for (int index = 0; index < splitPoints.length; index++) {
+ size = DataUtils.sum(counts[index]);
+ hyx += size * invDataSize * entropy(counts[index], size);
+ }
+
+ double ig = hy - hyx;
+ return new Split(attr, ig);
+ }
+
+ static void computeFrequencies(Data data,
+ int attr,
+ double[] splitPoints,
+ int[][] counts,
+ int[] countAll) {
+ Dataset dataset = data.getDataset();
+
+ for (int index = 0; index < data.size(); index++) {
+ Instance instance = data.get(index);
+ int label = (int) dataset.getLabel(instance);
+ double value = instance.get(attr);
+ int split = 0;
+ while (split < splitPoints.length && value > splitPoints[split]) {
+ split++;
+ }
+ if (split < splitPoints.length) {
+ counts[split][label]++;
+ } // Otherwise it's in the last split, which we don't need to count
+ countAll[label]++;
+ }
+ }
+
+ /**
+ * Computes the best split for a NUMERICAL attribute
+ */
+ static Split numericalSplit(Data data, int attr) {
+ double[] values = data.values(attr).clone();
+ Arrays.sort(values);
+
+ double[] splitPoints = chooseNumericSplitPoints(values);
+
+ int numLabels = data.getDataset().nblabels();
+ int[][] counts = new int[splitPoints.length][numLabels];
+ int[] countAll = new int[numLabels];
+ int[] countLess = new int[numLabels];
+
+ computeFrequencies(data, attr, splitPoints, counts, countAll);
+
+ int size = data.size();
+ double hy = entropy(countAll, size);
+ double invDataSize = 1.0 / size;
+
+ int best = -1;
+ double bestIg = -1.0;
+
+ // try each possible split value
+ for (int index = 0; index < splitPoints.length; index++) {
+ double ig = hy;
+
+ DataUtils.add(countLess, counts[index]);
+ DataUtils.dec(countAll, counts[index]);
+
+ // instance with attribute value < values[index]
+ size = DataUtils.sum(countLess);
+ ig -= size * invDataSize * entropy(countLess, size);
+ // instance with attribute value >= values[index]
+ size = DataUtils.sum(countAll);
+ ig -= size * invDataSize * entropy(countAll, size);
+
+ if (ig > bestIg) {
+ bestIg = ig;
+ best = index;
+ }
+ }
+
+ if (best == -1) {
+ throw new IllegalStateException("no best split found !");
+ }
+ return new Split(attr, bestIg, splitPoints[best]);
+ }
+
+ /**
+ * @return an array of values to split the numeric feature's values on when
+ * building candidate splits. When input size is <= MAX_NUMERIC_SPLITS + 1, it will
+ * return the averages between success values as split points. When larger, it will
+ * return MAX_NUMERIC_SPLITS approximate percentiles through the data.
+ */
+ private static double[] chooseNumericSplitPoints(double[] values) {
+ if (values.length <= 1) {
+ return values;
+ }
+ if (values.length <= MAX_NUMERIC_SPLITS + 1) {
+ double[] splitPoints = new double[values.length - 1];
+ for (int i = 1; i < values.length; i++) {
+ splitPoints[i-1] = (values[i] + values[i-1]) / 2.0;
+ }
+ return splitPoints;
+ }
+ Percentile distribution = new Percentile();
+ distribution.setData(values);
+ double[] percentiles = new double[MAX_NUMERIC_SPLITS];
+ for (int i = 0 ; i < percentiles.length; i++) {
+ double p = 100.0 * ((i + 1.0) / (MAX_NUMERIC_SPLITS + 1.0));
+ percentiles[i] = distribution.evaluate(p);
+ }
+ return percentiles;
+ }
+
+ private static double[] chooseCategoricalSplitPoints(double[] values) {
+ // There is no great reason to believe that categorical value order matters,
+ // but the original code worked this way, and it's not terrible in the absence
+ // of more sophisticated analysis
+ Collection<Double> uniqueOrderedCategories = new TreeSet<Double>();
+ for (double v : values) {
+ uniqueOrderedCategories.add(v);
+ }
+ double[] uniqueValues = new double[uniqueOrderedCategories.size()];
+ Iterator<Double> it = uniqueOrderedCategories.iterator();
+ for (int i = 0; i < uniqueValues.length; i++) {
+ uniqueValues[i] = it.next();
+ }
+ return uniqueValues;
+ }
+
+ /**
+ * Computes the Entropy
+ *
+ * @param counts counts[i] = numInstances with label i
+ * @param dataSize numInstances
+ */
+ private static double entropy(int[] counts, int dataSize) {
+ if (dataSize == 0) {
+ return 0.0;
+ }
+
+ double entropy = 0.0;
+
+ for (int count : counts) {
+ if (count > 0) {
+ double p = count / (double) dataSize;
+ entropy -= p * Math.log(p);
+ }
+ }
+
+ return entropy / LOG2;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java b/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
new file mode 100644
index 0000000..2974bcb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/split/RegressionSplit.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Instance;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Comparator;
+
+/**
+ * Regression problem implementation of IgSplit. This class can be used when the criterion variable is the numerical
+ * attribute.
+ */
+public class RegressionSplit extends IgSplit {
+
+ /**
+ * Comparator for Instance sort
+ */
+ private static class InstanceComparator implements Comparator<Instance>, Serializable {
+ private final int attr;
+
+ InstanceComparator(int attr) {
+ this.attr = attr;
+ }
+
+ @Override
+ public int compare(Instance arg0, Instance arg1) {
+ return Double.compare(arg0.get(attr), arg1.get(attr));
+ }
+ }
+
+ @Override
+ public Split computeSplit(Data data, int attr) {
+ if (data.getDataset().isNumerical(attr)) {
+ return numericalSplit(data, attr);
+ } else {
+ return categoricalSplit(data, attr);
+ }
+ }
+
+ /**
+ * Computes the split for a CATEGORICAL attribute
+ */
+ private static Split categoricalSplit(Data data, int attr) {
+ FullRunningAverage[] ra = new FullRunningAverage[data.getDataset().nbValues(attr)];
+ double[] sk = new double[data.getDataset().nbValues(attr)];
+ for (int i = 0; i < ra.length; i++) {
+ ra[i] = new FullRunningAverage();
+ }
+ FullRunningAverage totalRa = new FullRunningAverage();
+ double totalSk = 0.0;
+
+ for (int i = 0; i < data.size(); i++) {
+ // computes the variance
+ Instance instance = data.get(i);
+ int value = (int) instance.get(attr);
+ double xk = data.getDataset().getLabel(instance);
+ if (ra[value].getCount() == 0) {
+ ra[value].addDatum(xk);
+ sk[value] = 0.0;
+ } else {
+ double mk = ra[value].getAverage();
+ ra[value].addDatum(xk);
+ sk[value] += (xk - mk) * (xk - ra[value].getAverage());
+ }
+
+ // total variance
+ if (i == 0) {
+ totalRa.addDatum(xk);
+ totalSk = 0.0;
+ } else {
+ double mk = totalRa.getAverage();
+ totalRa.addDatum(xk);
+ totalSk += (xk - mk) * (xk - totalRa.getAverage());
+ }
+ }
+
+ // computes the variance gain
+ double ig = totalSk;
+ for (double aSk : sk) {
+ ig -= aSk;
+ }
+
+ return new Split(attr, ig);
+ }
+
+ /**
+ * Computes the best split for a NUMERICAL attribute
+ */
+ private static Split numericalSplit(Data data, int attr) {
+ FullRunningAverage[] ra = new FullRunningAverage[2];
+ for (int i = 0; i < ra.length; i++) {
+ ra[i] = new FullRunningAverage();
+ }
+
+ // Instance sort
+ Instance[] instances = new Instance[data.size()];
+ for (int i = 0; i < data.size(); i++) {
+ instances[i] = data.get(i);
+ }
+ Arrays.sort(instances, new InstanceComparator(attr));
+
+ double[] sk = new double[2];
+ for (Instance instance : instances) {
+ double xk = data.getDataset().getLabel(instance);
+ if (ra[1].getCount() == 0) {
+ ra[1].addDatum(xk);
+ sk[1] = 0.0;
+ } else {
+ double mk = ra[1].getAverage();
+ ra[1].addDatum(xk);
+ sk[1] += (xk - mk) * (xk - ra[1].getAverage());
+ }
+ }
+ double totalSk = sk[1];
+
+ // find the best split point
+ double split = Double.NaN;
+ double preSplit = Double.NaN;
+ double bestVal = Double.MAX_VALUE;
+ double bestSk = 0.0;
+
+ // computes total variance
+ for (Instance instance : instances) {
+ double xk = data.getDataset().getLabel(instance);
+
+ if (instance.get(attr) > preSplit) {
+ double curVal = sk[0] / ra[0].getCount() + sk[1] / ra[1].getCount();
+ if (curVal < bestVal) {
+ bestVal = curVal;
+ bestSk = sk[0] + sk[1];
+ split = (instance.get(attr) + preSplit) / 2.0;
+ }
+ }
+
+ // computes the variance
+ if (ra[0].getCount() == 0) {
+ ra[0].addDatum(xk);
+ sk[0] = 0.0;
+ } else {
+ double mk = ra[0].getAverage();
+ ra[0].addDatum(xk);
+ sk[0] += (xk - mk) * (xk - ra[0].getAverage());
+ }
+
+ double mk = ra[1].getAverage();
+ ra[1].removeDatum(xk);
+ sk[1] -= (xk - mk) * (xk - ra[1].getAverage());
+
+ preSplit = instance.get(attr);
+ }
+
+ // computes the variance gain
+ double ig = totalSk - bestSk;
+
+ return new Split(attr, ig, split);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java b/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
new file mode 100644
index 0000000..bf079de
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/split/Split.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import java.util.Locale;
+
+/**
+ * Contains enough information to identify each split
+ */
+public final class Split {
+
+ private final int attr;
+ private final double ig;
+ private final double split;
+
+ public Split(int attr, double ig, double split) {
+ this.attr = attr;
+ this.ig = ig;
+ this.split = split;
+ }
+
+ public Split(int attr, double ig) {
+ this(attr, ig, Double.NaN);
+ }
+
+ /**
+ * @return attribute to split for
+ */
+ public int getAttr() {
+ return attr;
+ }
+
+ /**
+ * @return Information Gain of the split
+ */
+ public double getIg() {
+ return ig;
+ }
+
+ /**
+ * @return split value for NUMERICAL attributes
+ */
+ public double getSplit() {
+ return split;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "attr: %d, ig: %f, split: %f", attr, ig, split);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
new file mode 100644
index 0000000..58814a8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/Describe.java
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.DFUtils;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.DescriptorUtils;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Generates a file descriptor for a given dataset
+ */
+public final class Describe {
+
+ private static final Logger log = LoggerFactory.getLogger(Describe.class);
+
+ private Describe() {}
+
+ public static void main(String[] args) throws IOException, DescriptorException {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option pathOpt = obuilder.withLongName("path").withShortName("p").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+ Option descriptorOpt = obuilder.withLongName("descriptor").withShortName("d").withRequired(true)
+ .withArgument(abuilder.withName("descriptor").withMinimum(1).create()).withDescription(
+ "data descriptor").create();
+
+ Option descPathOpt = obuilder.withLongName("file").withShortName("f").withRequired(true).withArgument(
+ abuilder.withName("file").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Path to generated descriptor file").create();
+
+ Option regOpt = obuilder.withLongName("regression").withDescription("Regression Problem").withShortName("r")
+ .create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(pathOpt).withOption(descPathOpt).withOption(
+ descriptorOpt).withOption(regOpt).withOption(helpOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ String dataPath = cmdLine.getValue(pathOpt).toString();
+ String descPath = cmdLine.getValue(descPathOpt).toString();
+ List<String> descriptor = convert(cmdLine.getValues(descriptorOpt));
+ boolean regression = cmdLine.hasOption(regOpt);
+
+ log.debug("Data path : {}", dataPath);
+ log.debug("Descriptor path : {}", descPath);
+ log.debug("Descriptor : {}", descriptor);
+ log.debug("Regression : {}", regression);
+
+ runTool(dataPath, descriptor, descPath, regression);
+ } catch (OptionException e) {
+ log.warn(e.toString());
+ CommandLineUtil.printHelp(group);
+ }
+ }
+
+ private static void runTool(String dataPath, Iterable<String> description, String filePath, boolean regression)
+ throws DescriptorException, IOException {
+ log.info("Generating the descriptor...");
+ String descriptor = DescriptorUtils.generateDescriptor(description);
+
+ Path fPath = validateOutput(filePath);
+
+ log.info("generating the dataset...");
+ Dataset dataset = generateDataset(descriptor, dataPath, regression);
+
+ log.info("storing the dataset description");
+ String json = dataset.toJSON();
+ DFUtils.storeString(new Configuration(), fPath, json);
+ }
+
+ private static Dataset generateDataset(String descriptor, String dataPath, boolean regression) throws IOException,
+ DescriptorException {
+ Path path = new Path(dataPath);
+ FileSystem fs = path.getFileSystem(new Configuration());
+
+ return DataLoader.generateDataset(descriptor, regression, fs, path);
+ }
+
+ private static Path validateOutput(String filePath) throws IOException {
+ Path path = new Path(filePath);
+ FileSystem fs = path.getFileSystem(new Configuration());
+ if (fs.exists(path)) {
+ throw new IllegalStateException("Descriptor's file already exists");
+ }
+
+ return path;
+ }
+
+ private static List<String> convert(Collection<?> values) {
+ List<String> list = Lists.newArrayListWithCapacity(values.size());
+ for (Object value : values) {
+ list.add(value.toString());
+ }
+ return list;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
new file mode 100644
index 0000000..3b9d4ee
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/ForestVisualizer.java
@@ -0,0 +1,157 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This tool is to visualize the Decision Forest
+ */
+public final class ForestVisualizer {
+
+ private static final Logger log = LoggerFactory.getLogger(ForestVisualizer.class);
+
+ private ForestVisualizer() {
+ }
+
+ public static String toString(DecisionForest forest, Dataset dataset, String[] attrNames) {
+
+ List<Node> trees;
+ try {
+ Method getTrees = forest.getClass().getDeclaredMethod("getTrees");
+ getTrees.setAccessible(true);
+ trees = (List<Node>) getTrees.invoke(forest);
+ } catch (IllegalAccessException e) {
+ throw new IllegalStateException(e);
+ } catch (InvocationTargetException e) {
+ throw new IllegalStateException(e);
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(e);
+ }
+
+ int cnt = 1;
+ StringBuilder buff = new StringBuilder();
+ for (Node tree : trees) {
+ buff.append("Tree[").append(cnt).append("]:");
+ buff.append(TreeVisualizer.toString(tree, dataset, attrNames));
+ buff.append('\n');
+ cnt++;
+ }
+ return buff.toString();
+ }
+
+ /**
+ * Decision Forest to String
+ * @param forestPath
+ * path to the Decision Forest
+ * @param datasetPath
+ * dataset path
+ * @param attrNames
+ * attribute names
+ */
+ public static String toString(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+ Configuration conf = new Configuration();
+ DecisionForest forest = DecisionForest.load(conf, new Path(forestPath));
+ Dataset dataset = Dataset.load(conf, new Path(datasetPath));
+ return toString(forest, dataset, attrNames);
+ }
+
+ /**
+ * Print Decision Forest
+ * @param forestPath
+ * path to the Decision Forest
+ * @param datasetPath
+ * dataset path
+ * @param attrNames
+ * attribute names
+ */
+ public static void print(String forestPath, String datasetPath, String[] attrNames) throws IOException {
+ System.out.println(toString(forestPath, datasetPath, attrNames));
+ }
+
+ public static void main(String[] args) {
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
+ .withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+ .withDescription("Dataset path").create();
+
+ Option modelOpt = obuilder.withLongName("model").withShortName("m").withRequired(true)
+ .withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+ .withDescription("Path to the Decision Forest").create();
+
+ Option attrNamesOpt = obuilder.withLongName("names").withShortName("n").withRequired(false)
+ .withArgument(abuilder.withName("names").withMinimum(1).create())
+ .withDescription("Optional, Attribute names").create();
+
+ Option helpOpt = obuilder.withLongName("help").withShortName("h")
+ .withDescription("Print out help").create();
+
+ Group group = gbuilder.withName("Options").withOption(datasetOpt).withOption(modelOpt)
+ .withOption(attrNamesOpt).withOption(helpOpt).create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption("help")) {
+ CommandLineUtil.printHelp(group);
+ return;
+ }
+
+ String datasetName = cmdLine.getValue(datasetOpt).toString();
+ String modelName = cmdLine.getValue(modelOpt).toString();
+ String[] attrNames = null;
+ if (cmdLine.hasOption(attrNamesOpt)) {
+ Collection<String> names = (Collection<String>) cmdLine.getValues(attrNamesOpt);
+ if (!names.isEmpty()) {
+ attrNames = new String[names.size()];
+ names.toArray(attrNames);
+ }
+ }
+
+ print(modelName, datasetName, attrNames);
+ } catch (Exception e) {
+ log.error("Exception", e);
+ CommandLineUtil.printHelp(group);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java b/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
new file mode 100644
index 0000000..4586540
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.CommandLineUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * Compute the frequency distribution of the "class label"<br>
+ * This class can be used when the criterion variable is the categorical attribute.
+ */
+public final class Frequencies extends Configured implements Tool {
+
+ private static final Logger log = LoggerFactory.getLogger(Frequencies.class);
+
+ private Frequencies() { }
+
+ @Override
+ public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+
+ Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
+
+ Option datasetOpt = obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
+ abuilder.withName("path").withMinimum(1).create()).withDescription("dataset path").create();
+
+ Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
+ .create();
+
+ Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt).withOption(helpOpt)
+ .create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return 0;
+ }
+
+ String dataPath = cmdLine.getValue(dataOpt).toString();
+ String datasetPath = cmdLine.getValue(datasetOpt).toString();
+
+ log.debug("Data path : {}", dataPath);
+ log.debug("Dataset path : {}", datasetPath);
+
+ runTool(dataPath, datasetPath);
+ } catch (OptionException e) {
+ log.warn(e.toString(), e);
+ CommandLineUtil.printHelp(group);
+ }
+
+ return 0;
+ }
+
+ private void runTool(String data, String dataset) throws IOException,
+ ClassNotFoundException,
+ InterruptedException {
+
+ FileSystem fs = FileSystem.get(getConf());
+ Path workingDir = fs.getWorkingDirectory();
+
+ Path dataPath = new Path(data);
+ Path datasetPath = new Path(dataset);
+
+ log.info("Computing the frequencies...");
+ FrequenciesJob job = new FrequenciesJob(new Path(workingDir, "output"), dataPath, datasetPath);
+
+ int[][] counts = job.run(getConf());
+
+ // outputing the frequencies
+ log.info("counts[partition][class]");
+ for (int[] count : counts) {
+ log.info(Arrays.toString(count));
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new Frequencies(), args);
+ }
+
+}
[42/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefUserBasedRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefUserBasedRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefUserBasedRecommender.java
new file mode 100644
index 0000000..15fcc9f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefUserBasedRecommender.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+/**
+ * A variant on {@link GenericUserBasedRecommender} which is appropriate for use when no notion of preference
+ * value exists in the data.
+ */
+public final class GenericBooleanPrefUserBasedRecommender extends GenericUserBasedRecommender {
+
+ public GenericBooleanPrefUserBasedRecommender(DataModel dataModel,
+ UserNeighborhood neighborhood,
+ UserSimilarity similarity) {
+ super(dataModel, neighborhood, similarity);
+ }
+
+ /**
+ * This computation is in a technical sense, wrong, since in the domain of "boolean preference users" where
+ * all preference values are 1, this method should only ever return 1.0 or NaN. This isn't terribly useful
+ * however since it means results can't be ranked by preference value (all are 1). So instead this returns a
+ * sum of similarities to any other user in the neighborhood who has also rated the item.
+ */
+ @Override
+ protected float doEstimatePreference(long theUserID, long[] theNeighborhood, long itemID) throws TasteException {
+ if (theNeighborhood.length == 0) {
+ return Float.NaN;
+ }
+ DataModel dataModel = getDataModel();
+ UserSimilarity similarity = getSimilarity();
+ float totalSimilarity = 0.0f;
+ boolean foundAPref = false;
+ for (long userID : theNeighborhood) {
+ // See GenericItemBasedRecommender.doEstimatePreference() too
+ if (userID != theUserID && dataModel.getPreferenceValue(userID, itemID) != null) {
+ foundAPref = true;
+ totalSimilarity += (float) similarity.userSimilarity(theUserID, userID);
+ }
+ }
+ return foundAPref ? totalSimilarity : Float.NaN;
+ }
+
+ @Override
+ protected FastIDSet getAllOtherItems(long[] theNeighborhood, long theUserID, boolean includeKnownItems)
+ throws TasteException {
+ DataModel dataModel = getDataModel();
+ FastIDSet possibleItemIDs = new FastIDSet();
+ for (long userID : theNeighborhood) {
+ possibleItemIDs.addAll(dataModel.getItemIDsFromUser(userID));
+ }
+ if (!includeKnownItems) {
+ possibleItemIDs.removeAll(dataModel.getItemIDsFromUser(theUserID));
+ }
+ return possibleItemIDs;
+ }
+
+ @Override
+ public String toString() {
+ return "GenericBooleanPrefUserBasedRecommender";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommender.java
new file mode 100644
index 0000000..413db4b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommender.java
@@ -0,0 +1,378 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+import org.apache.mahout.cf.taste.recommender.MostSimilarItemsCandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Rescorer;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.common.LongPair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple {@link org.apache.mahout.cf.taste.recommender.Recommender} which uses a given
+ * {@link org.apache.mahout.cf.taste.model.DataModel} and
+ * {@link org.apache.mahout.cf.taste.similarity.ItemSimilarity} to produce recommendations. This class
+ * represents Taste's support for item-based recommenders.
+ * </p>
+ *
+ * <p>
+ * The {@link org.apache.mahout.cf.taste.similarity.ItemSimilarity} is the most important point to discuss
+ * here. Item-based recommenders are useful because they can take advantage of something to be very fast: they
+ * base their computations on item similarity, not user similarity, and item similarity is relatively static.
+ * It can be precomputed, instead of re-computed in real time.
+ * </p>
+ *
+ * <p>
+ * Thus it's strongly recommended that you use
+ * {@link org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity} with pre-computed similarities if
+ * you're going to use this class. You can use
+ * {@link org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity} too, which computes
+ * similarities in real-time, but will probably find this painfully slow for large amounts of data.
+ * </p>
+ */
+public class GenericItemBasedRecommender extends AbstractRecommender implements ItemBasedRecommender {
+
+ private static final Logger log = LoggerFactory.getLogger(GenericItemBasedRecommender.class);
+
+ private final ItemSimilarity similarity;
+ private final MostSimilarItemsCandidateItemsStrategy mostSimilarItemsCandidateItemsStrategy;
+ private final RefreshHelper refreshHelper;
+ private EstimatedPreferenceCapper capper;
+
+ private static final boolean EXCLUDE_ITEM_IF_NOT_SIMILAR_TO_ALL_BY_DEFAULT = true;
+
+ public GenericItemBasedRecommender(DataModel dataModel,
+ ItemSimilarity similarity,
+ CandidateItemsStrategy candidateItemsStrategy,
+ MostSimilarItemsCandidateItemsStrategy mostSimilarItemsCandidateItemsStrategy) {
+ super(dataModel, candidateItemsStrategy);
+ Preconditions.checkArgument(similarity != null, "similarity is null");
+ this.similarity = similarity;
+ Preconditions.checkArgument(mostSimilarItemsCandidateItemsStrategy != null,
+ "mostSimilarItemsCandidateItemsStrategy is null");
+ this.mostSimilarItemsCandidateItemsStrategy = mostSimilarItemsCandidateItemsStrategy;
+ this.refreshHelper = new RefreshHelper(new Callable<Void>() {
+ @Override
+ public Void call() {
+ capper = buildCapper();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(dataModel);
+ refreshHelper.addDependency(similarity);
+ refreshHelper.addDependency(candidateItemsStrategy);
+ refreshHelper.addDependency(mostSimilarItemsCandidateItemsStrategy);
+ capper = buildCapper();
+ }
+
+ public GenericItemBasedRecommender(DataModel dataModel, ItemSimilarity similarity) {
+ this(dataModel,
+ similarity,
+ AbstractRecommender.getDefaultCandidateItemsStrategy(),
+ getDefaultMostSimilarItemsCandidateItemsStrategy());
+ }
+
+ protected static MostSimilarItemsCandidateItemsStrategy getDefaultMostSimilarItemsCandidateItemsStrategy() {
+ return new PreferredItemsNeighborhoodCandidateItemsStrategy();
+ }
+
+ public ItemSimilarity getSimilarity() {
+ return similarity;
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException {
+ Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
+ log.debug("Recommending items for user ID '{}'", userID);
+
+ PreferenceArray preferencesFromUser = getDataModel().getPreferencesFromUser(userID);
+ if (preferencesFromUser.length() == 0) {
+ return Collections.emptyList();
+ }
+
+ FastIDSet possibleItemIDs = getAllOtherItems(userID, preferencesFromUser, includeKnownItems);
+
+ TopItems.Estimator<Long> estimator = new Estimator(userID, preferencesFromUser);
+
+ List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
+ estimator);
+
+ log.debug("Recommendations are: {}", topItems);
+ return topItems;
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) throws TasteException {
+ PreferenceArray preferencesFromUser = getDataModel().getPreferencesFromUser(userID);
+ Float actualPref = getPreferenceForItem(preferencesFromUser, itemID);
+ if (actualPref != null) {
+ return actualPref;
+ }
+ return doEstimatePreference(userID, preferencesFromUser, itemID);
+ }
+
+ private static Float getPreferenceForItem(PreferenceArray preferencesFromUser, long itemID) {
+ int size = preferencesFromUser.length();
+ for (int i = 0; i < size; i++) {
+ if (preferencesFromUser.getItemID(i) == itemID) {
+ return preferencesFromUser.getValue(i);
+ }
+ }
+ return null;
+ }
+
+ @Override
+ public List<RecommendedItem> mostSimilarItems(long itemID, int howMany) throws TasteException {
+ return mostSimilarItems(itemID, howMany, null);
+ }
+
+ @Override
+ public List<RecommendedItem> mostSimilarItems(long itemID, int howMany,
+ Rescorer<LongPair> rescorer) throws TasteException {
+ TopItems.Estimator<Long> estimator = new MostSimilarEstimator(itemID, similarity, rescorer);
+ return doMostSimilarItems(new long[] {itemID}, howMany, estimator);
+ }
+
+ @Override
+ public List<RecommendedItem> mostSimilarItems(long[] itemIDs, int howMany) throws TasteException {
+ TopItems.Estimator<Long> estimator = new MultiMostSimilarEstimator(itemIDs, similarity, null,
+ EXCLUDE_ITEM_IF_NOT_SIMILAR_TO_ALL_BY_DEFAULT);
+ return doMostSimilarItems(itemIDs, howMany, estimator);
+ }
+
+ @Override
+ public List<RecommendedItem> mostSimilarItems(long[] itemIDs, int howMany,
+ Rescorer<LongPair> rescorer) throws TasteException {
+ TopItems.Estimator<Long> estimator = new MultiMostSimilarEstimator(itemIDs, similarity, rescorer,
+ EXCLUDE_ITEM_IF_NOT_SIMILAR_TO_ALL_BY_DEFAULT);
+ return doMostSimilarItems(itemIDs, howMany, estimator);
+ }
+
+ @Override
+ public List<RecommendedItem> mostSimilarItems(long[] itemIDs,
+ int howMany,
+ boolean excludeItemIfNotSimilarToAll) throws TasteException {
+ TopItems.Estimator<Long> estimator = new MultiMostSimilarEstimator(itemIDs, similarity, null,
+ excludeItemIfNotSimilarToAll);
+ return doMostSimilarItems(itemIDs, howMany, estimator);
+ }
+
+ @Override
+ public List<RecommendedItem> mostSimilarItems(long[] itemIDs, int howMany,
+ Rescorer<LongPair> rescorer,
+ boolean excludeItemIfNotSimilarToAll) throws TasteException {
+ TopItems.Estimator<Long> estimator = new MultiMostSimilarEstimator(itemIDs, similarity, rescorer,
+ excludeItemIfNotSimilarToAll);
+ return doMostSimilarItems(itemIDs, howMany, estimator);
+ }
+
+ @Override
+ public List<RecommendedItem> recommendedBecause(long userID, long itemID, int howMany) throws TasteException {
+ Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
+
+ DataModel model = getDataModel();
+ TopItems.Estimator<Long> estimator = new RecommendedBecauseEstimator(userID, itemID);
+
+ PreferenceArray prefs = model.getPreferencesFromUser(userID);
+ int size = prefs.length();
+ FastIDSet allUserItems = new FastIDSet(size);
+ for (int i = 0; i < size; i++) {
+ allUserItems.add(prefs.getItemID(i));
+ }
+ allUserItems.remove(itemID);
+
+ return TopItems.getTopItems(howMany, allUserItems.iterator(), null, estimator);
+ }
+
+ private List<RecommendedItem> doMostSimilarItems(long[] itemIDs,
+ int howMany,
+ TopItems.Estimator<Long> estimator) throws TasteException {
+ FastIDSet possibleItemIDs = mostSimilarItemsCandidateItemsStrategy.getCandidateItems(itemIDs, getDataModel());
+ return TopItems.getTopItems(howMany, possibleItemIDs.iterator(), null, estimator);
+ }
+
+ protected float doEstimatePreference(long userID, PreferenceArray preferencesFromUser, long itemID)
+ throws TasteException {
+ double preference = 0.0;
+ double totalSimilarity = 0.0;
+ int count = 0;
+ double[] similarities = similarity.itemSimilarities(itemID, preferencesFromUser.getIDs());
+ for (int i = 0; i < similarities.length; i++) {
+ double theSimilarity = similarities[i];
+ if (!Double.isNaN(theSimilarity)) {
+ // Weights can be negative!
+ preference += theSimilarity * preferencesFromUser.getValue(i);
+ totalSimilarity += theSimilarity;
+ count++;
+ }
+ }
+ // Throw out the estimate if it was based on no data points, of course, but also if based on
+ // just one. This is a bit of a band-aid on the 'stock' item-based algorithm for the moment.
+ // The reason is that in this case the estimate is, simply, the user's rating for one item
+ // that happened to have a defined similarity. The similarity score doesn't matter, and that
+ // seems like a bad situation.
+ if (count <= 1) {
+ return Float.NaN;
+ }
+ float estimate = (float) (preference / totalSimilarity);
+ if (capper != null) {
+ estimate = capper.capEstimate(estimate);
+ }
+ return estimate;
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ @Override
+ public String toString() {
+ return "GenericItemBasedRecommender[similarity:" + similarity + ']';
+ }
+
+ private EstimatedPreferenceCapper buildCapper() {
+ DataModel dataModel = getDataModel();
+ if (Float.isNaN(dataModel.getMinPreference()) && Float.isNaN(dataModel.getMaxPreference())) {
+ return null;
+ } else {
+ return new EstimatedPreferenceCapper(dataModel);
+ }
+ }
+
+ public static class MostSimilarEstimator implements TopItems.Estimator<Long> {
+
+ private final long toItemID;
+ private final ItemSimilarity similarity;
+ private final Rescorer<LongPair> rescorer;
+
+ public MostSimilarEstimator(long toItemID, ItemSimilarity similarity, Rescorer<LongPair> rescorer) {
+ this.toItemID = toItemID;
+ this.similarity = similarity;
+ this.rescorer = rescorer;
+ }
+
+ @Override
+ public double estimate(Long itemID) throws TasteException {
+ LongPair pair = new LongPair(toItemID, itemID);
+ if (rescorer != null && rescorer.isFiltered(pair)) {
+ return Double.NaN;
+ }
+ double originalEstimate = similarity.itemSimilarity(toItemID, itemID);
+ return rescorer == null ? originalEstimate : rescorer.rescore(pair, originalEstimate);
+ }
+ }
+
+ private final class Estimator implements TopItems.Estimator<Long> {
+
+ private final long userID;
+ private final PreferenceArray preferencesFromUser;
+
+ private Estimator(long userID, PreferenceArray preferencesFromUser) {
+ this.userID = userID;
+ this.preferencesFromUser = preferencesFromUser;
+ }
+
+ @Override
+ public double estimate(Long itemID) throws TasteException {
+ return doEstimatePreference(userID, preferencesFromUser, itemID);
+ }
+ }
+
+ private static final class MultiMostSimilarEstimator implements TopItems.Estimator<Long> {
+
+ private final long[] toItemIDs;
+ private final ItemSimilarity similarity;
+ private final Rescorer<LongPair> rescorer;
+ private final boolean excludeItemIfNotSimilarToAll;
+
+ private MultiMostSimilarEstimator(long[] toItemIDs, ItemSimilarity similarity, Rescorer<LongPair> rescorer,
+ boolean excludeItemIfNotSimilarToAll) {
+ this.toItemIDs = toItemIDs;
+ this.similarity = similarity;
+ this.rescorer = rescorer;
+ this.excludeItemIfNotSimilarToAll = excludeItemIfNotSimilarToAll;
+ }
+
+ @Override
+ public double estimate(Long itemID) throws TasteException {
+ RunningAverage average = new FullRunningAverage();
+ double[] similarities = similarity.itemSimilarities(itemID, toItemIDs);
+ for (int i = 0; i < toItemIDs.length; i++) {
+ long toItemID = toItemIDs[i];
+ LongPair pair = new LongPair(toItemID, itemID);
+ if (rescorer != null && rescorer.isFiltered(pair)) {
+ continue;
+ }
+ double estimate = similarities[i];
+ if (rescorer != null) {
+ estimate = rescorer.rescore(pair, estimate);
+ }
+ if (excludeItemIfNotSimilarToAll || !Double.isNaN(estimate)) {
+ average.addDatum(estimate);
+ }
+ }
+ double averageEstimate = average.getAverage();
+ return averageEstimate == 0 ? Double.NaN : averageEstimate;
+ }
+ }
+
+ private final class RecommendedBecauseEstimator implements TopItems.Estimator<Long> {
+
+ private final long userID;
+ private final long recommendedItemID;
+
+ private RecommendedBecauseEstimator(long userID, long recommendedItemID) {
+ this.userID = userID;
+ this.recommendedItemID = recommendedItemID;
+ }
+
+ @Override
+ public double estimate(Long itemID) throws TasteException {
+ Float pref = getDataModel().getPreferenceValue(userID, itemID);
+ if (pref == null) {
+ return Float.NaN;
+ }
+ double similarityValue = similarity.itemSimilarity(recommendedItemID, itemID);
+ return (1.0 + similarityValue) * pref;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericRecommendedItem.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericRecommendedItem.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericRecommendedItem.java
new file mode 100644
index 0000000..8c8f6ce
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericRecommendedItem.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.io.Serializable;
+
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.RandomUtils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple implementation of {@link RecommendedItem}.
+ * </p>
+ */
+public final class GenericRecommendedItem implements RecommendedItem, Serializable {
+
+ private final long itemID;
+ private final float value;
+
+ /**
+ * @throws IllegalArgumentException
+ * if item is null or value is NaN
+ */
+ public GenericRecommendedItem(long itemID, float value) {
+ Preconditions.checkArgument(!Float.isNaN(value), "value is NaN");
+ this.itemID = itemID;
+ this.value = value;
+ }
+
+ @Override
+ public long getItemID() {
+ return itemID;
+ }
+
+ @Override
+ public float getValue() {
+ return value;
+ }
+
+ @Override
+ public String toString() {
+ return "RecommendedItem[item:" + itemID + ", value:" + value + ']';
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) itemID ^ RandomUtils.hashFloat(value);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof GenericRecommendedItem)) {
+ return false;
+ }
+ RecommendedItem other = (RecommendedItem) o;
+ return itemID == other.getItemID() && value == other.getValue();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommender.java
new file mode 100644
index 0000000..1e2ef73
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommender.java
@@ -0,0 +1,247 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Rescorer;
+import org.apache.mahout.cf.taste.recommender.UserBasedRecommender;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.apache.mahout.common.LongPair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple {@link org.apache.mahout.cf.taste.recommender.Recommender}
+ * which uses a given {@link DataModel} and {@link UserNeighborhood} to produce recommendations.
+ * </p>
+ */
+public class GenericUserBasedRecommender extends AbstractRecommender implements UserBasedRecommender {
+
+ private static final Logger log = LoggerFactory.getLogger(GenericUserBasedRecommender.class);
+
+ private final UserNeighborhood neighborhood;
+ private final UserSimilarity similarity;
+ private final RefreshHelper refreshHelper;
+ private EstimatedPreferenceCapper capper;
+
+ public GenericUserBasedRecommender(DataModel dataModel,
+ UserNeighborhood neighborhood,
+ UserSimilarity similarity) {
+ super(dataModel);
+ Preconditions.checkArgument(neighborhood != null, "neighborhood is null");
+ this.neighborhood = neighborhood;
+ this.similarity = similarity;
+ this.refreshHelper = new RefreshHelper(new Callable<Void>() {
+ @Override
+ public Void call() {
+ capper = buildCapper();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(dataModel);
+ refreshHelper.addDependency(similarity);
+ refreshHelper.addDependency(neighborhood);
+ capper = buildCapper();
+ }
+
+ public UserSimilarity getSimilarity() {
+ return similarity;
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException {
+ Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
+
+ log.debug("Recommending items for user ID '{}'", userID);
+
+ long[] theNeighborhood = neighborhood.getUserNeighborhood(userID);
+
+ if (theNeighborhood.length == 0) {
+ return Collections.emptyList();
+ }
+
+ FastIDSet allItemIDs = getAllOtherItems(theNeighborhood, userID, includeKnownItems);
+
+ TopItems.Estimator<Long> estimator = new Estimator(userID, theNeighborhood);
+
+ List<RecommendedItem> topItems = TopItems
+ .getTopItems(howMany, allItemIDs.iterator(), rescorer, estimator);
+
+ log.debug("Recommendations are: {}", topItems);
+ return topItems;
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) throws TasteException {
+ DataModel model = getDataModel();
+ Float actualPref = model.getPreferenceValue(userID, itemID);
+ if (actualPref != null) {
+ return actualPref;
+ }
+ long[] theNeighborhood = neighborhood.getUserNeighborhood(userID);
+ return doEstimatePreference(userID, theNeighborhood, itemID);
+ }
+
+ @Override
+ public long[] mostSimilarUserIDs(long userID, int howMany) throws TasteException {
+ return mostSimilarUserIDs(userID, howMany, null);
+ }
+
+ @Override
+ public long[] mostSimilarUserIDs(long userID, int howMany, Rescorer<LongPair> rescorer) throws TasteException {
+ TopItems.Estimator<Long> estimator = new MostSimilarEstimator(userID, similarity, rescorer);
+ return doMostSimilarUsers(howMany, estimator);
+ }
+
+ private long[] doMostSimilarUsers(int howMany, TopItems.Estimator<Long> estimator) throws TasteException {
+ DataModel model = getDataModel();
+ return TopItems.getTopUsers(howMany, model.getUserIDs(), null, estimator);
+ }
+
+ protected float doEstimatePreference(long theUserID, long[] theNeighborhood, long itemID) throws TasteException {
+ if (theNeighborhood.length == 0) {
+ return Float.NaN;
+ }
+ DataModel dataModel = getDataModel();
+ double preference = 0.0;
+ double totalSimilarity = 0.0;
+ int count = 0;
+ for (long userID : theNeighborhood) {
+ if (userID != theUserID) {
+ // See GenericItemBasedRecommender.doEstimatePreference() too
+ Float pref = dataModel.getPreferenceValue(userID, itemID);
+ if (pref != null) {
+ double theSimilarity = similarity.userSimilarity(theUserID, userID);
+ if (!Double.isNaN(theSimilarity)) {
+ preference += theSimilarity * pref;
+ totalSimilarity += theSimilarity;
+ count++;
+ }
+ }
+ }
+ }
+ // Throw out the estimate if it was based on no data points, of course, but also if based on
+ // just one. This is a bit of a band-aid on the 'stock' item-based algorithm for the moment.
+ // The reason is that in this case the estimate is, simply, the user's rating for one item
+ // that happened to have a defined similarity. The similarity score doesn't matter, and that
+ // seems like a bad situation.
+ if (count <= 1) {
+ return Float.NaN;
+ }
+ float estimate = (float) (preference / totalSimilarity);
+ if (capper != null) {
+ estimate = capper.capEstimate(estimate);
+ }
+ return estimate;
+ }
+
+ protected FastIDSet getAllOtherItems(long[] theNeighborhood, long theUserID, boolean includeKnownItems)
+ throws TasteException {
+ DataModel dataModel = getDataModel();
+ FastIDSet possibleItemIDs = new FastIDSet();
+ for (long userID : theNeighborhood) {
+ possibleItemIDs.addAll(dataModel.getItemIDsFromUser(userID));
+ }
+ if (!includeKnownItems) {
+ possibleItemIDs.removeAll(dataModel.getItemIDsFromUser(theUserID));
+ }
+ return possibleItemIDs;
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ @Override
+ public String toString() {
+ return "GenericUserBasedRecommender[neighborhood:" + neighborhood + ']';
+ }
+
+ private EstimatedPreferenceCapper buildCapper() {
+ DataModel dataModel = getDataModel();
+ if (Float.isNaN(dataModel.getMinPreference()) && Float.isNaN(dataModel.getMaxPreference())) {
+ return null;
+ } else {
+ return new EstimatedPreferenceCapper(dataModel);
+ }
+ }
+
+ private static final class MostSimilarEstimator implements TopItems.Estimator<Long> {
+
+ private final long toUserID;
+ private final UserSimilarity similarity;
+ private final Rescorer<LongPair> rescorer;
+
+ private MostSimilarEstimator(long toUserID, UserSimilarity similarity, Rescorer<LongPair> rescorer) {
+ this.toUserID = toUserID;
+ this.similarity = similarity;
+ this.rescorer = rescorer;
+ }
+
+ @Override
+ public double estimate(Long userID) throws TasteException {
+ // Don't consider the user itself as a possible most similar user
+ if (userID == toUserID) {
+ return Double.NaN;
+ }
+ if (rescorer == null) {
+ return similarity.userSimilarity(toUserID, userID);
+ } else {
+ LongPair pair = new LongPair(toUserID, userID);
+ if (rescorer.isFiltered(pair)) {
+ return Double.NaN;
+ }
+ double originalEstimate = similarity.userSimilarity(toUserID, userID);
+ return rescorer.rescore(pair, originalEstimate);
+ }
+ }
+ }
+
+ private final class Estimator implements TopItems.Estimator<Long> {
+
+ private final long theUserID;
+ private final long[] theNeighborhood;
+
+ Estimator(long theUserID, long[] theNeighborhood) {
+ this.theUserID = theUserID;
+ this.theNeighborhood = theNeighborhood;
+ }
+
+ @Override
+ public double estimate(Long itemID) throws TasteException {
+ return doEstimatePreference(theUserID, theNeighborhood, itemID);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommender.java
new file mode 100644
index 0000000..618c65f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommender.java
@@ -0,0 +1,199 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple recommender that always estimates preference for an item to be the average of all known preference
+ * values for that item. No information about users is taken into account. This implementation is provided for
+ * experimentation; while simple and fast, it may not produce very good recommendations.
+ * </p>
+ */
+public final class ItemAverageRecommender extends AbstractRecommender {
+
+ private static final Logger log = LoggerFactory.getLogger(ItemAverageRecommender.class);
+
+ private final FastByIDMap<RunningAverage> itemAverages;
+ private final ReadWriteLock buildAveragesLock;
+ private final RefreshHelper refreshHelper;
+
+ public ItemAverageRecommender(DataModel dataModel) throws TasteException {
+ super(dataModel);
+ this.itemAverages = new FastByIDMap<>();
+ this.buildAveragesLock = new ReentrantReadWriteLock();
+ this.refreshHelper = new RefreshHelper(new Callable<Object>() {
+ @Override
+ public Object call() throws TasteException {
+ buildAverageDiffs();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(dataModel);
+ buildAverageDiffs();
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException {
+ Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
+ log.debug("Recommending items for user ID '{}'", userID);
+
+ PreferenceArray preferencesFromUser = getDataModel().getPreferencesFromUser(userID);
+ FastIDSet possibleItemIDs = getAllOtherItems(userID, preferencesFromUser, includeKnownItems);
+
+ TopItems.Estimator<Long> estimator = new Estimator();
+
+ List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
+ estimator);
+
+ log.debug("Recommendations are: {}", topItems);
+ return topItems;
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) throws TasteException {
+ DataModel dataModel = getDataModel();
+ Float actualPref = dataModel.getPreferenceValue(userID, itemID);
+ if (actualPref != null) {
+ return actualPref;
+ }
+ return doEstimatePreference(itemID);
+ }
+
+ private float doEstimatePreference(long itemID) {
+ buildAveragesLock.readLock().lock();
+ try {
+ RunningAverage average = itemAverages.get(itemID);
+ return average == null ? Float.NaN : (float) average.getAverage();
+ } finally {
+ buildAveragesLock.readLock().unlock();
+ }
+ }
+
+ private void buildAverageDiffs() throws TasteException {
+ try {
+ buildAveragesLock.writeLock().lock();
+ DataModel dataModel = getDataModel();
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(it.nextLong());
+ int size = prefs.length();
+ for (int i = 0; i < size; i++) {
+ long itemID = prefs.getItemID(i);
+ RunningAverage average = itemAverages.get(itemID);
+ if (average == null) {
+ average = new FullRunningAverage();
+ itemAverages.put(itemID, average);
+ }
+ average.addDatum(prefs.getValue(i));
+ }
+ }
+ } finally {
+ buildAveragesLock.writeLock().unlock();
+ }
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) throws TasteException {
+ DataModel dataModel = getDataModel();
+ double prefDelta;
+ try {
+ Float oldPref = dataModel.getPreferenceValue(userID, itemID);
+ prefDelta = oldPref == null ? value : value - oldPref;
+ } catch (NoSuchUserException nsee) {
+ prefDelta = value;
+ }
+ super.setPreference(userID, itemID, value);
+ try {
+ buildAveragesLock.writeLock().lock();
+ RunningAverage average = itemAverages.get(itemID);
+ if (average == null) {
+ RunningAverage newAverage = new FullRunningAverage();
+ newAverage.addDatum(prefDelta);
+ itemAverages.put(itemID, newAverage);
+ } else {
+ average.changeDatum(prefDelta);
+ }
+ } finally {
+ buildAveragesLock.writeLock().unlock();
+ }
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) throws TasteException {
+ DataModel dataModel = getDataModel();
+ Float oldPref = dataModel.getPreferenceValue(userID, itemID);
+ super.removePreference(userID, itemID);
+ if (oldPref != null) {
+ try {
+ buildAveragesLock.writeLock().lock();
+ RunningAverage average = itemAverages.get(itemID);
+ if (average == null) {
+ throw new IllegalStateException("No preferences exist for item ID: " + itemID);
+ } else {
+ average.removeDatum(oldPref);
+ }
+ } finally {
+ buildAveragesLock.writeLock().unlock();
+ }
+ }
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ @Override
+ public String toString() {
+ return "ItemAverageRecommender";
+ }
+
+ private final class Estimator implements TopItems.Estimator<Long> {
+
+ @Override
+ public double estimate(Long itemID) {
+ return doEstimatePreference(itemID);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommender.java
new file mode 100644
index 0000000..b2bcd24
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommender.java
@@ -0,0 +1,240 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
+
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * Like {@link ItemAverageRecommender}, except that estimated preferences are adjusted for the users' average
+ * preference value. For example, say user X has not rated item Y. Item Y's average preference value is 3.5.
+ * User X's average preference value is 4.2, and the average over all preference values is 4.0. User X prefers
+ * items 0.2 higher on average, so, the estimated preference for user X, item Y is 3.5 + 0.2 = 3.7.
+ * </p>
+ */
+public final class ItemUserAverageRecommender extends AbstractRecommender {
+
+ private static final Logger log = LoggerFactory.getLogger(ItemUserAverageRecommender.class);
+
+ private final FastByIDMap<RunningAverage> itemAverages;
+ private final FastByIDMap<RunningAverage> userAverages;
+ private final RunningAverage overallAveragePrefValue;
+ private final ReadWriteLock buildAveragesLock;
+ private final RefreshHelper refreshHelper;
+
+ public ItemUserAverageRecommender(DataModel dataModel) throws TasteException {
+ super(dataModel);
+ this.itemAverages = new FastByIDMap<>();
+ this.userAverages = new FastByIDMap<>();
+ this.overallAveragePrefValue = new FullRunningAverage();
+ this.buildAveragesLock = new ReentrantReadWriteLock();
+ this.refreshHelper = new RefreshHelper(new Callable<Object>() {
+ @Override
+ public Object call() throws TasteException {
+ buildAverageDiffs();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(dataModel);
+ buildAverageDiffs();
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException {
+ Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
+ log.debug("Recommending items for user ID '{}'", userID);
+
+ PreferenceArray preferencesFromUser = getDataModel().getPreferencesFromUser(userID);
+ FastIDSet possibleItemIDs = getAllOtherItems(userID, preferencesFromUser, includeKnownItems);
+
+ TopItems.Estimator<Long> estimator = new Estimator(userID);
+
+ List<RecommendedItem> topItems = TopItems.getTopItems(howMany, possibleItemIDs.iterator(), rescorer,
+ estimator);
+
+ log.debug("Recommendations are: {}", topItems);
+ return topItems;
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) throws TasteException {
+ DataModel dataModel = getDataModel();
+ Float actualPref = dataModel.getPreferenceValue(userID, itemID);
+ if (actualPref != null) {
+ return actualPref;
+ }
+ return doEstimatePreference(userID, itemID);
+ }
+
+ private float doEstimatePreference(long userID, long itemID) {
+ buildAveragesLock.readLock().lock();
+ try {
+ RunningAverage itemAverage = itemAverages.get(itemID);
+ if (itemAverage == null) {
+ return Float.NaN;
+ }
+ RunningAverage userAverage = userAverages.get(userID);
+ if (userAverage == null) {
+ return Float.NaN;
+ }
+ double userDiff = userAverage.getAverage() - overallAveragePrefValue.getAverage();
+ return (float) (itemAverage.getAverage() + userDiff);
+ } finally {
+ buildAveragesLock.readLock().unlock();
+ }
+ }
+
+ private void buildAverageDiffs() throws TasteException {
+ try {
+ buildAveragesLock.writeLock().lock();
+ DataModel dataModel = getDataModel();
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ long userID = it.nextLong();
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+ int size = prefs.length();
+ for (int i = 0; i < size; i++) {
+ long itemID = prefs.getItemID(i);
+ float value = prefs.getValue(i);
+ addDatumAndCreateIfNeeded(itemID, value, itemAverages);
+ addDatumAndCreateIfNeeded(userID, value, userAverages);
+ overallAveragePrefValue.addDatum(value);
+ }
+ }
+ } finally {
+ buildAveragesLock.writeLock().unlock();
+ }
+ }
+
+ private static void addDatumAndCreateIfNeeded(long itemID, float value, FastByIDMap<RunningAverage> averages) {
+ RunningAverage itemAverage = averages.get(itemID);
+ if (itemAverage == null) {
+ itemAverage = new FullRunningAverage();
+ averages.put(itemID, itemAverage);
+ }
+ itemAverage.addDatum(value);
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) throws TasteException {
+ DataModel dataModel = getDataModel();
+ double prefDelta;
+ try {
+ Float oldPref = dataModel.getPreferenceValue(userID, itemID);
+ prefDelta = oldPref == null ? value : value - oldPref;
+ } catch (NoSuchUserException nsee) {
+ prefDelta = value;
+ }
+ super.setPreference(userID, itemID, value);
+ try {
+ buildAveragesLock.writeLock().lock();
+ RunningAverage itemAverage = itemAverages.get(itemID);
+ if (itemAverage == null) {
+ RunningAverage newItemAverage = new FullRunningAverage();
+ newItemAverage.addDatum(prefDelta);
+ itemAverages.put(itemID, newItemAverage);
+ } else {
+ itemAverage.changeDatum(prefDelta);
+ }
+ RunningAverage userAverage = userAverages.get(userID);
+ if (userAverage == null) {
+ RunningAverage newUserAveragae = new FullRunningAverage();
+ newUserAveragae.addDatum(prefDelta);
+ userAverages.put(userID, newUserAveragae);
+ } else {
+ userAverage.changeDatum(prefDelta);
+ }
+ overallAveragePrefValue.changeDatum(prefDelta);
+ } finally {
+ buildAveragesLock.writeLock().unlock();
+ }
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) throws TasteException {
+ DataModel dataModel = getDataModel();
+ Float oldPref = dataModel.getPreferenceValue(userID, itemID);
+ super.removePreference(userID, itemID);
+ if (oldPref != null) {
+ try {
+ buildAveragesLock.writeLock().lock();
+ RunningAverage itemAverage = itemAverages.get(itemID);
+ if (itemAverage == null) {
+ throw new IllegalStateException("No preferences exist for item ID: " + itemID);
+ }
+ itemAverage.removeDatum(oldPref);
+ RunningAverage userAverage = userAverages.get(userID);
+ if (userAverage == null) {
+ throw new IllegalStateException("No preferences exist for user ID: " + userID);
+ }
+ userAverage.removeDatum(oldPref);
+ overallAveragePrefValue.removeDatum(oldPref);
+ } finally {
+ buildAveragesLock.writeLock().unlock();
+ }
+ }
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ @Override
+ public String toString() {
+ return "ItemUserAverageRecommender";
+ }
+
+ private final class Estimator implements TopItems.Estimator<Long> {
+
+ private final long userID;
+
+ private Estimator(long userID) {
+ this.userID = userID;
+ }
+
+ @Override
+ public double estimate(Long itemID) {
+ return doEstimatePreference(userID, itemID);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorer.java
new file mode 100644
index 0000000..14e9ec6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorer.java
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.Rescorer;
+import org.apache.mahout.common.LongPair;
+
+/**
+ * <p>
+ * A simple {@link Rescorer} which always returns the original score.
+ * </p>
+ */
+public final class NullRescorer<T> implements Rescorer<T>, IDRescorer {
+
+ private static final IDRescorer USER_OR_ITEM_INSTANCE = new NullRescorer<Long>();
+ private static final Rescorer<LongPair> ITEM_ITEM_PAIR_INSTANCE = new NullRescorer<>();
+ private static final Rescorer<LongPair> USER_USER_PAIR_INSTANCE = new NullRescorer<>();
+
+ private NullRescorer() {
+ }
+
+ public static IDRescorer getItemInstance() {
+ return USER_OR_ITEM_INSTANCE;
+ }
+
+ public static IDRescorer getUserInstance() {
+ return USER_OR_ITEM_INSTANCE;
+ }
+
+ public static Rescorer<LongPair> getItemItemPairInstance() {
+ return ITEM_ITEM_PAIR_INSTANCE;
+ }
+
+ public static Rescorer<LongPair> getUserUserPairInstance() {
+ return USER_USER_PAIR_INSTANCE;
+ }
+
+ /**
+ * @param thing
+ * to rescore
+ * @param originalScore
+ * current score for item
+ * @return same originalScore as new score, always
+ */
+ @Override
+ public double rescore(T thing, double originalScore) {
+ return originalScore;
+ }
+
+ @Override
+ public boolean isFiltered(T thing) {
+ return false;
+ }
+
+ @Override
+ public double rescore(long id, double originalScore) {
+ return originalScore;
+ }
+
+ @Override
+ public boolean isFiltered(long id) {
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return "NullRescorer";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategy.java
new file mode 100644
index 0000000..6297d0b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategy.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+public final class PreferredItemsNeighborhoodCandidateItemsStrategy extends AbstractCandidateItemsStrategy {
+
+ /**
+ * returns all items that have not been rated by the user and that were preferred by another user
+ * that has preferred at least one item that the current user has preferred too
+ */
+ @Override
+ protected FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel, boolean includeKnownItems)
+ throws TasteException {
+ FastIDSet possibleItemsIDs = new FastIDSet();
+ for (long itemID : preferredItemIDs) {
+ PreferenceArray itemPreferences = dataModel.getPreferencesForItem(itemID);
+ int numUsersPreferringItem = itemPreferences.length();
+ for (int index = 0; index < numUsersPreferringItem; index++) {
+ possibleItemsIDs.addAll(dataModel.getItemIDsFromUser(itemPreferences.getUserID(index)));
+ }
+ }
+ if (!includeKnownItems) {
+ possibleItemsIDs.removeAll(preferredItemIDs);
+ }
+ return possibleItemsIDs;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java
new file mode 100644
index 0000000..ef11f0d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommender.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Produces random recommendations and preference estimates. This is likely only useful as a novelty and for
+ * benchmarking.
+ */
+public final class RandomRecommender extends AbstractRecommender {
+
+ private final Random random = RandomUtils.getRandom();
+ private final float minPref;
+ private final float maxPref;
+
+ public RandomRecommender(DataModel dataModel) throws TasteException {
+ super(dataModel);
+ float maxPref = Float.NEGATIVE_INFINITY;
+ float minPref = Float.POSITIVE_INFINITY;
+ LongPrimitiveIterator userIterator = dataModel.getUserIDs();
+ while (userIterator.hasNext()) {
+ long userID = userIterator.next();
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+ for (int i = 0; i < prefs.length(); i++) {
+ float prefValue = prefs.getValue(i);
+ if (prefValue < minPref) {
+ minPref = prefValue;
+ }
+ if (prefValue > maxPref) {
+ maxPref = prefValue;
+ }
+ }
+ }
+ this.minPref = minPref;
+ this.maxPref = maxPref;
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException {
+ DataModel dataModel = getDataModel();
+ int numItems = dataModel.getNumItems();
+ List<RecommendedItem> result = Lists.newArrayListWithCapacity(howMany);
+ while (result.size() < howMany) {
+ LongPrimitiveIterator it = dataModel.getItemIDs();
+ it.skip(random.nextInt(numItems));
+ long itemID = it.next();
+ if (includeKnownItems || dataModel.getPreferenceValue(userID, itemID) == null) {
+ result.add(new GenericRecommendedItem(itemID, randomPref()));
+ }
+ }
+ return result;
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) {
+ return randomPref();
+ }
+
+ private float randomPref() {
+ return minPref + random.nextFloat() * (maxPref - minPref);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ getDataModel().refresh(alreadyRefreshed);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java
new file mode 100644
index 0000000..623a60b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategy.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveArrayIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.FixedSizeSamplingIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Iterator;
+
+/**
+ * <p>Returns all items that have not been rated by the user <em>(3)</em> and that were preferred by another user
+ * <em>(2)</em> that has preferred at least one item <em>(1)</em> that the current user has preferred too.</p>
+ *
+ * <p>This strategy uses sampling to limit the number of items that are considered, by sampling three different
+ * things, noted above:</p>
+ *
+ * <ol>
+ * <li>The items that the user has preferred</li>
+ * <li>The users who also prefer each of those items</li>
+ * <li>The items those users also prefer</li>
+ * </ol>
+ *
+ * <p>There is a maximum associated with each of these three things; if the number of items or users exceeds
+ * that max, it is sampled so that the expected number of items or users actually used in that part of the
+ * computation is equal to the max.</p>
+ *
+ * <p>Three arguments control these three maxima. Each is a "factor" f, which establishes the max at
+ * f * log2(n), where n is the number of users or items in the data. For example if factor #2 is 5,
+ * which controls the number of users sampled per item, then 5 * log2(# users) is the maximum for this
+ * part of the computation.</p>
+ *
+ * <p>Each can be set to not do any limiting with value {@link #NO_LIMIT_FACTOR}.</p>
+ */
+public class SamplingCandidateItemsStrategy extends AbstractCandidateItemsStrategy {
+
+ private static final Logger log = LoggerFactory.getLogger(SamplingCandidateItemsStrategy.class);
+
+ /**
+ * Default factor used if not otherwise specified, for all limits. (30).
+ */
+ public static final int DEFAULT_FACTOR = 30;
+ /**
+ * Specify this value as a factor to mean no limit.
+ */
+ public static final int NO_LIMIT_FACTOR = Integer.MAX_VALUE;
+ private static final int MAX_LIMIT = Integer.MAX_VALUE;
+ private static final double LOG2 = Math.log(2.0);
+
+ private final int maxItems;
+ private final int maxUsersPerItem;
+ private final int maxItemsPerUser;
+
+ /**
+ * Defaults to using no limit ({@link #NO_LIMIT_FACTOR}) for all factors, except
+ * {@code candidatesPerUserFactor} which defaults to {@link #DEFAULT_FACTOR}.
+ *
+ * @see #SamplingCandidateItemsStrategy(int, int, int, int, int)
+ */
+ public SamplingCandidateItemsStrategy(int numUsers, int numItems) {
+ this(DEFAULT_FACTOR, DEFAULT_FACTOR, DEFAULT_FACTOR, numUsers, numItems);
+ }
+
+ /**
+ * @param itemsFactor factor controlling max items considered for a user
+ * @param usersPerItemFactor factor controlling max users considered for each of those items
+ * @param candidatesPerUserFactor factor controlling max candidate items considered from each of those users
+ * @param numUsers number of users currently in the data
+ * @param numItems number of items in the data
+ */
+ public SamplingCandidateItemsStrategy(int itemsFactor,
+ int usersPerItemFactor,
+ int candidatesPerUserFactor,
+ int numUsers,
+ int numItems) {
+ Preconditions.checkArgument(itemsFactor > 0, "itemsFactor must be greater then 0!");
+ Preconditions.checkArgument(usersPerItemFactor > 0, "usersPerItemFactor must be greater then 0!");
+ Preconditions.checkArgument(candidatesPerUserFactor > 0, "candidatesPerUserFactor must be greater then 0!");
+ Preconditions.checkArgument(numUsers > 0, "numUsers must be greater then 0!");
+ Preconditions.checkArgument(numItems > 0, "numItems must be greater then 0!");
+ maxItems = computeMaxFrom(itemsFactor, numItems);
+ maxUsersPerItem = computeMaxFrom(usersPerItemFactor, numUsers);
+ maxItemsPerUser = computeMaxFrom(candidatesPerUserFactor, numItems);
+ log.debug("maxItems {}, maxUsersPerItem {}, maxItemsPerUser {}", maxItems, maxUsersPerItem, maxItemsPerUser);
+ }
+
+ private static int computeMaxFrom(int factor, int numThings) {
+ if (factor == NO_LIMIT_FACTOR) {
+ return MAX_LIMIT;
+ }
+ long max = (long) (factor * (1.0 + Math.log(numThings) / LOG2));
+ return max > MAX_LIMIT ? MAX_LIMIT : (int) max;
+ }
+
+ @Override
+ protected FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel, boolean includeKnownItems)
+ throws TasteException {
+ LongPrimitiveIterator preferredItemIDsIterator = new LongPrimitiveArrayIterator(preferredItemIDs);
+ if (preferredItemIDs.length > maxItems) {
+ double samplingRate = (double) maxItems / preferredItemIDs.length;
+// log.info("preferredItemIDs.length {}, samplingRate {}", preferredItemIDs.length, samplingRate);
+ preferredItemIDsIterator =
+ new SamplingLongPrimitiveIterator(preferredItemIDsIterator, samplingRate);
+ }
+ FastIDSet possibleItemsIDs = new FastIDSet();
+ while (preferredItemIDsIterator.hasNext()) {
+ long itemID = preferredItemIDsIterator.nextLong();
+ PreferenceArray prefs = dataModel.getPreferencesForItem(itemID);
+ int prefsLength = prefs.length();
+ if (prefsLength > maxUsersPerItem) {
+ Iterator<Preference> sampledPrefs =
+ new FixedSizeSamplingIterator<>(maxUsersPerItem, prefs.iterator());
+ while (sampledPrefs.hasNext()) {
+ addSomeOf(possibleItemsIDs, dataModel.getItemIDsFromUser(sampledPrefs.next().getUserID()));
+ }
+ } else {
+ for (int i = 0; i < prefsLength; i++) {
+ addSomeOf(possibleItemsIDs, dataModel.getItemIDsFromUser(prefs.getUserID(i)));
+ }
+ }
+ }
+ if (!includeKnownItems) {
+ possibleItemsIDs.removeAll(preferredItemIDs);
+ }
+ return possibleItemsIDs;
+ }
+
+ private void addSomeOf(FastIDSet possibleItemIDs, FastIDSet itemIDs) {
+ if (itemIDs.size() > maxItemsPerUser) {
+ LongPrimitiveIterator it =
+ new SamplingLongPrimitiveIterator(itemIDs.iterator(), (double) maxItemsPerUser / itemIDs.size());
+ while (it.hasNext()) {
+ possibleItemIDs.add(it.nextLong());
+ }
+ } else {
+ possibleItemIDs.addAll(itemIDs);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java
new file mode 100644
index 0000000..c6d417f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/SimilarUser.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.common.RandomUtils;
+
+/** Simply encapsulates a user and a similarity value. */
+public final class SimilarUser implements Comparable<SimilarUser> {
+
+ private final long userID;
+ private final double similarity;
+
+ public SimilarUser(long userID, double similarity) {
+ this.userID = userID;
+ this.similarity = similarity;
+ }
+
+ long getUserID() {
+ return userID;
+ }
+
+ double getSimilarity() {
+ return similarity;
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) userID ^ RandomUtils.hashDouble(similarity);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof SimilarUser)) {
+ return false;
+ }
+ SimilarUser other = (SimilarUser) o;
+ return userID == other.getUserID() && similarity == other.getSimilarity();
+ }
+
+ @Override
+ public String toString() {
+ return "SimilarUser[user:" + userID + ", similarity:" + similarity + ']';
+ }
+
+ /** Defines an ordering from most similar to least similar. */
+ @Override
+ public int compareTo(SimilarUser other) {
+ double otherSimilarity = other.getSimilarity();
+ if (similarity > otherSimilarity) {
+ return -1;
+ }
+ if (similarity < otherSimilarity) {
+ return 1;
+ }
+ long otherUserID = other.getUserID();
+ if (userID < otherUserID) {
+ return -1;
+ }
+ if (userID > otherUserID) {
+ return 1;
+ }
+ return 0;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java
new file mode 100644
index 0000000..3c27145
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/TopItems.java
@@ -0,0 +1,212 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.Queue;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.GenericUserSimilarity;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple class that refactors the "find top N things" logic that is used in several places.
+ * </p>
+ */
+public final class TopItems {
+
+ private static final long[] NO_IDS = new long[0];
+
+ private TopItems() { }
+
+ public static List<RecommendedItem> getTopItems(int howMany,
+ LongPrimitiveIterator possibleItemIDs,
+ IDRescorer rescorer,
+ Estimator<Long> estimator) throws TasteException {
+ Preconditions.checkArgument(possibleItemIDs != null, "possibleItemIDs is null");
+ Preconditions.checkArgument(estimator != null, "estimator is null");
+
+ Queue<RecommendedItem> topItems = new PriorityQueue<>(howMany + 1,
+ Collections.reverseOrder(ByValueRecommendedItemComparator.getInstance()));
+ boolean full = false;
+ double lowestTopValue = Double.NEGATIVE_INFINITY;
+ while (possibleItemIDs.hasNext()) {
+ long itemID = possibleItemIDs.next();
+ if (rescorer == null || !rescorer.isFiltered(itemID)) {
+ double preference;
+ try {
+ preference = estimator.estimate(itemID);
+ } catch (NoSuchItemException nsie) {
+ continue;
+ }
+ double rescoredPref = rescorer == null ? preference : rescorer.rescore(itemID, preference);
+ if (!Double.isNaN(rescoredPref) && (!full || rescoredPref > lowestTopValue)) {
+ topItems.add(new GenericRecommendedItem(itemID, (float) rescoredPref));
+ if (full) {
+ topItems.poll();
+ } else if (topItems.size() > howMany) {
+ full = true;
+ topItems.poll();
+ }
+ lowestTopValue = topItems.peek().getValue();
+ }
+ }
+ }
+ int size = topItems.size();
+ if (size == 0) {
+ return Collections.emptyList();
+ }
+ List<RecommendedItem> result = Lists.newArrayListWithCapacity(size);
+ result.addAll(topItems);
+ Collections.sort(result, ByValueRecommendedItemComparator.getInstance());
+ return result;
+ }
+
+ public static long[] getTopUsers(int howMany,
+ LongPrimitiveIterator allUserIDs,
+ IDRescorer rescorer,
+ Estimator<Long> estimator) throws TasteException {
+ Queue<SimilarUser> topUsers = new PriorityQueue<>(howMany + 1, Collections.reverseOrder());
+ boolean full = false;
+ double lowestTopValue = Double.NEGATIVE_INFINITY;
+ while (allUserIDs.hasNext()) {
+ long userID = allUserIDs.next();
+ if (rescorer != null && rescorer.isFiltered(userID)) {
+ continue;
+ }
+ double similarity;
+ try {
+ similarity = estimator.estimate(userID);
+ } catch (NoSuchUserException nsue) {
+ continue;
+ }
+ double rescoredSimilarity = rescorer == null ? similarity : rescorer.rescore(userID, similarity);
+ if (!Double.isNaN(rescoredSimilarity) && (!full || rescoredSimilarity > lowestTopValue)) {
+ topUsers.add(new SimilarUser(userID, rescoredSimilarity));
+ if (full) {
+ topUsers.poll();
+ } else if (topUsers.size() > howMany) {
+ full = true;
+ topUsers.poll();
+ }
+ lowestTopValue = topUsers.peek().getSimilarity();
+ }
+ }
+ int size = topUsers.size();
+ if (size == 0) {
+ return NO_IDS;
+ }
+ List<SimilarUser> sorted = Lists.newArrayListWithCapacity(size);
+ sorted.addAll(topUsers);
+ Collections.sort(sorted);
+ long[] result = new long[size];
+ int i = 0;
+ for (SimilarUser similarUser : sorted) {
+ result[i++] = similarUser.getUserID();
+ }
+ return result;
+ }
+
+ /**
+ * <p>
+ * Thanks to tsmorton for suggesting this functionality and writing part of the code.
+ * </p>
+ *
+ * @see GenericItemSimilarity#GenericItemSimilarity(Iterable, int)
+ * @see GenericItemSimilarity#GenericItemSimilarity(org.apache.mahout.cf.taste.similarity.ItemSimilarity,
+ * org.apache.mahout.cf.taste.model.DataModel, int)
+ */
+ public static List<GenericItemSimilarity.ItemItemSimilarity> getTopItemItemSimilarities(
+ int howMany, Iterator<GenericItemSimilarity.ItemItemSimilarity> allSimilarities) {
+
+ Queue<GenericItemSimilarity.ItemItemSimilarity> topSimilarities
+ = new PriorityQueue<>(howMany + 1, Collections.reverseOrder());
+ boolean full = false;
+ double lowestTopValue = Double.NEGATIVE_INFINITY;
+ while (allSimilarities.hasNext()) {
+ GenericItemSimilarity.ItemItemSimilarity similarity = allSimilarities.next();
+ double value = similarity.getValue();
+ if (!Double.isNaN(value) && (!full || value > lowestTopValue)) {
+ topSimilarities.add(similarity);
+ if (full) {
+ topSimilarities.poll();
+ } else if (topSimilarities.size() > howMany) {
+ full = true;
+ topSimilarities.poll();
+ }
+ lowestTopValue = topSimilarities.peek().getValue();
+ }
+ }
+ int size = topSimilarities.size();
+ if (size == 0) {
+ return Collections.emptyList();
+ }
+ List<GenericItemSimilarity.ItemItemSimilarity> result = Lists.newArrayListWithCapacity(size);
+ result.addAll(topSimilarities);
+ Collections.sort(result);
+ return result;
+ }
+
+ public static List<GenericUserSimilarity.UserUserSimilarity> getTopUserUserSimilarities(
+ int howMany, Iterator<GenericUserSimilarity.UserUserSimilarity> allSimilarities) {
+
+ Queue<GenericUserSimilarity.UserUserSimilarity> topSimilarities
+ = new PriorityQueue<>(howMany + 1, Collections.reverseOrder());
+ boolean full = false;
+ double lowestTopValue = Double.NEGATIVE_INFINITY;
+ while (allSimilarities.hasNext()) {
+ GenericUserSimilarity.UserUserSimilarity similarity = allSimilarities.next();
+ double value = similarity.getValue();
+ if (!Double.isNaN(value) && (!full || value > lowestTopValue)) {
+ topSimilarities.add(similarity);
+ if (full) {
+ topSimilarities.poll();
+ } else if (topSimilarities.size() > howMany) {
+ full = true;
+ topSimilarities.poll();
+ }
+ lowestTopValue = topSimilarities.peek().getValue();
+ }
+ }
+ int size = topSimilarities.size();
+ if (size == 0) {
+ return Collections.emptyList();
+ }
+ List<GenericUserSimilarity.UserUserSimilarity> result = Lists.newArrayListWithCapacity(size);
+ result.addAll(topSimilarities);
+ Collections.sort(result);
+ return result;
+ }
+
+ public interface Estimator<T> {
+ double estimate(T thing) throws TasteException;
+ }
+
+}
[34/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
new file mode 100644
index 0000000..0f88a70
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptron.java
@@ -0,0 +1,332 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.BufferedReader;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.Arrays;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+
+/** Train a {@link MultilayerPerceptron}. */
+public final class TrainMultilayerPerceptron {
+
+ private static final Logger log = LoggerFactory.getLogger(TrainMultilayerPerceptron.class);
+
+ /** The parameters used by MLP. */
+ static class Parameters {
+ double learningRate;
+ double momemtumWeight;
+ double regularizationWeight;
+
+ String inputFilePath;
+ boolean skipHeader;
+ Map<String, Integer> labelsIndex = Maps.newHashMap();
+
+ String modelFilePath;
+ boolean updateModel;
+ List<Integer> layerSizeList = Lists.newArrayList();
+ String squashingFunctionName;
+ }
+
+ /*
+ private double learningRate;
+ private double momemtumWeight;
+ private double regularizationWeight;
+
+ private String inputFilePath;
+ private boolean skipHeader;
+ private Map<String, Integer> labelsIndex = Maps.newHashMap();
+
+ private String modelFilePath;
+ private boolean updateModel;
+ private List<Integer> layerSizeList = Lists.newArrayList();
+ private String squashingFunctionName;*/
+
+ public static void main(String[] args) throws Exception {
+ Parameters parameters = new Parameters();
+
+ if (parseArgs(args, parameters)) {
+ log.info("Validate model...");
+ // check whether the model already exists
+ Path modelPath = new Path(parameters.modelFilePath);
+ FileSystem modelFs = modelPath.getFileSystem(new Configuration());
+ MultilayerPerceptron mlp;
+
+ if (modelFs.exists(modelPath) && parameters.updateModel) {
+ // incrementally update existing model
+ log.info("Build model from existing model...");
+ mlp = new MultilayerPerceptron(parameters.modelFilePath);
+ } else {
+ if (modelFs.exists(modelPath)) {
+ modelFs.delete(modelPath, true); // delete the existing file
+ }
+ log.info("Build model from scratch...");
+ mlp = new MultilayerPerceptron();
+ for (int i = 0; i < parameters.layerSizeList.size(); ++i) {
+ if (i != parameters.layerSizeList.size() - 1) {
+ mlp.addLayer(parameters.layerSizeList.get(i), false, parameters.squashingFunctionName);
+ } else {
+ mlp.addLayer(parameters.layerSizeList.get(i), true, parameters.squashingFunctionName);
+ }
+ mlp.setCostFunction("Minus_Squared");
+ mlp.setLearningRate(parameters.learningRate)
+ .setMomentumWeight(parameters.momemtumWeight)
+ .setRegularizationWeight(parameters.regularizationWeight);
+ }
+ mlp.setModelPath(parameters.modelFilePath);
+ }
+
+ // set the parameters
+ mlp.setLearningRate(parameters.learningRate)
+ .setMomentumWeight(parameters.momemtumWeight)
+ .setRegularizationWeight(parameters.regularizationWeight);
+
+ // train by the training data
+ Path trainingDataPath = new Path(parameters.inputFilePath);
+ FileSystem dataFs = trainingDataPath.getFileSystem(new Configuration());
+
+ Preconditions.checkArgument(dataFs.exists(trainingDataPath), "Training dataset %s cannot be found!",
+ parameters.inputFilePath);
+
+ log.info("Read data and train model...");
+ BufferedReader reader = null;
+
+ try {
+ reader = new BufferedReader(new InputStreamReader(dataFs.open(trainingDataPath)));
+ String line;
+
+ // read training data line by line
+ if (parameters.skipHeader) {
+ reader.readLine();
+ }
+
+ int labelDimension = parameters.labelsIndex.size();
+ while ((line = reader.readLine()) != null) {
+ String[] token = line.split(",");
+ String label = token[token.length - 1];
+ int labelIndex = parameters.labelsIndex.get(label);
+
+ double[] instances = new double[token.length - 1 + labelDimension];
+ for (int i = 0; i < token.length - 1; ++i) {
+ instances[i] = Double.parseDouble(token[i]);
+ }
+ for (int i = 0; i < labelDimension; ++i) {
+ instances[token.length - 1 + i] = 0;
+ }
+ // set the corresponding dimension
+ instances[token.length - 1 + labelIndex] = 1;
+
+ Vector instance = new DenseVector(instances).viewPart(0, instances.length);
+ mlp.trainOnline(instance);
+ }
+
+ // write model back
+ log.info("Write trained model to {}", parameters.modelFilePath);
+ mlp.writeModelToFile();
+ mlp.close();
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+ }
+
+ /**
+ * Parse the input arguments.
+ *
+ * @param args The input arguments
+ * @param parameters The parameters parsed.
+ * @return Whether the input arguments are valid.
+ * @throws Exception
+ */
+ private static boolean parseArgs(String[] args, Parameters parameters) throws Exception {
+ // build the options
+ log.info("Validate and parse arguments...");
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ GroupBuilder groupBuilder = new GroupBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ // whether skip the first row of the input file
+ Option skipHeaderOption = optionBuilder.withLongName("skipHeader")
+ .withShortName("sh").create();
+
+ Group skipHeaderGroup = groupBuilder.withOption(skipHeaderOption).create();
+
+ Option inputOption = optionBuilder
+ .withLongName("input")
+ .withShortName("i")
+ .withRequired(true)
+ .withChildren(skipHeaderGroup)
+ .withArgument(argumentBuilder.withName("path").withMinimum(1).withMaximum(1)
+ .create()).withDescription("the file path of training dataset")
+ .create();
+
+ Option labelsOption = optionBuilder
+ .withLongName("labels")
+ .withShortName("labels")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("label-name").withMinimum(2).create())
+ .withDescription("label names").create();
+
+ Option updateOption = optionBuilder
+ .withLongName("update")
+ .withShortName("u")
+ .withDescription("whether to incrementally update model if the model exists")
+ .create();
+
+ Group modelUpdateGroup = groupBuilder.withOption(updateOption).create();
+
+ Option modelOption = optionBuilder
+ .withLongName("model")
+ .withShortName("mo")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("model-path").withMinimum(1).withMaximum(1).create())
+ .withDescription("the path to store the trained model")
+ .withChildren(modelUpdateGroup).create();
+
+ Option layerSizeOption = optionBuilder
+ .withLongName("layerSize")
+ .withShortName("ls")
+ .withRequired(true)
+ .withArgument(argumentBuilder.withName("size of layer").withMinimum(2).withMaximum(5).create())
+ .withDescription("the size of each layer").create();
+
+ Option squashingFunctionOption = optionBuilder
+ .withLongName("squashingFunction")
+ .withShortName("sf")
+ .withArgument(argumentBuilder.withName("squashing function").withMinimum(1).withMaximum(1)
+ .withDefault("Sigmoid").create())
+ .withDescription("the name of squashing function (currently only supports Sigmoid)")
+ .create();
+
+ Option learningRateOption = optionBuilder
+ .withLongName("learningRate")
+ .withShortName("l")
+ .withArgument(argumentBuilder.withName("learning rate").withMaximum(1)
+ .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_LEARNING_RATE).create())
+ .withDescription("learning rate").create();
+
+ Option momemtumOption = optionBuilder
+ .withLongName("momemtumWeight")
+ .withShortName("m")
+ .withArgument(argumentBuilder.withName("momemtum weight").withMaximum(1)
+ .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_MOMENTUM_WEIGHT).create())
+ .withDescription("momemtum weight").create();
+
+ Option regularizationOption = optionBuilder
+ .withLongName("regularizationWeight")
+ .withShortName("r")
+ .withArgument(argumentBuilder.withName("regularization weight").withMaximum(1)
+ .withMinimum(1).withDefault(NeuralNetwork.DEFAULT_REGULARIZATION_WEIGHT).create())
+ .withDescription("regularization weight").create();
+
+ // parse the input
+ Parser parser = new Parser();
+ Group normalOptions = groupBuilder.withOption(inputOption)
+ .withOption(skipHeaderOption).withOption(updateOption)
+ .withOption(labelsOption).withOption(modelOption)
+ .withOption(layerSizeOption).withOption(squashingFunctionOption)
+ .withOption(learningRateOption).withOption(momemtumOption)
+ .withOption(regularizationOption).create();
+
+ parser.setGroup(normalOptions);
+
+ CommandLine commandLine = parser.parseAndHelp(args);
+ if (commandLine == null) {
+ return false;
+ }
+
+ parameters.learningRate = getDouble(commandLine, learningRateOption);
+ parameters.momemtumWeight = getDouble(commandLine, momemtumOption);
+ parameters.regularizationWeight = getDouble(commandLine, regularizationOption);
+
+ parameters.inputFilePath = getString(commandLine, inputOption);
+ parameters.skipHeader = commandLine.hasOption(skipHeaderOption);
+
+ List<String> labelsList = getStringList(commandLine, labelsOption);
+ int currentIndex = 0;
+ for (String label : labelsList) {
+ parameters.labelsIndex.put(label, currentIndex++);
+ }
+
+ parameters.modelFilePath = getString(commandLine, modelOption);
+ parameters.updateModel = commandLine.hasOption(updateOption);
+
+ parameters.layerSizeList = getIntegerList(commandLine, layerSizeOption);
+
+ parameters.squashingFunctionName = getString(commandLine, squashingFunctionOption);
+
+ System.out.printf("Input: %s, Model: %s, Update: %s, Layer size: %s, Squashing function: %s, Learning rate: %f," +
+ " Momemtum weight: %f, Regularization Weight: %f\n", parameters.inputFilePath, parameters.modelFilePath,
+ parameters.updateModel, Arrays.toString(parameters.layerSizeList.toArray()),
+ parameters.squashingFunctionName, parameters.learningRate, parameters.momemtumWeight,
+ parameters.regularizationWeight);
+
+ return true;
+ }
+
+ static Double getDouble(CommandLine commandLine, Option option) {
+ Object val = commandLine.getValue(option);
+ if (val != null) {
+ return Double.parseDouble(val.toString());
+ }
+ return null;
+ }
+
+ static String getString(CommandLine commandLine, Option option) {
+ Object val = commandLine.getValue(option);
+ if (val != null) {
+ return val.toString();
+ }
+ return null;
+ }
+
+ static List<Integer> getIntegerList(CommandLine commandLine, Option option) {
+ List<String> list = commandLine.getValues(option);
+ List<Integer> valList = Lists.newArrayList();
+ for (String str : list) {
+ valList.add(Integer.parseInt(str));
+ }
+ return valList;
+ }
+
+ static List<String> getStringList(CommandLine commandLine, Option option) {
+ return commandLine.getValues(option);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
new file mode 100644
index 0000000..f0794b3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/AbstractNaiveBayesClassifier.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * Class implementing the Naive Bayes Classifier Algorithm. Note that this class
+ * supports {@link #classifyFull}, but not {@code classify} or
+ * {@code classifyScalar}. The reason that these two methods are not
+ * supported is because the scores computed by a NaiveBayesClassifier do not
+ * represent probabilities.
+ */
+public abstract class AbstractNaiveBayesClassifier extends AbstractVectorClassifier {
+
+ private final NaiveBayesModel model;
+
+ protected AbstractNaiveBayesClassifier(NaiveBayesModel model) {
+ this.model = model;
+ }
+
+ protected NaiveBayesModel getModel() {
+ return model;
+ }
+
+ protected abstract double getScoreForLabelFeature(int label, int feature);
+
+ protected double getScoreForLabelInstance(int label, Vector instance) {
+ double result = 0.0;
+ for (Element e : instance.nonZeroes()) {
+ result += e.get() * getScoreForLabelFeature(label, e.index());
+ }
+ return result;
+ }
+
+ @Override
+ public int numCategories() {
+ return model.numLabels();
+ }
+
+ @Override
+ public Vector classifyFull(Vector instance) {
+ return classifyFull(model.createScoringVector(), instance);
+ }
+
+ @Override
+ public Vector classifyFull(Vector r, Vector instance) {
+ for (int label = 0; label < model.numLabels(); label++) {
+ r.setQuick(label, getScoreForLabelInstance(label, instance));
+ }
+ return r;
+ }
+
+ /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+ @Override
+ public double classifyScalar(Vector instance) {
+ throw new UnsupportedOperationException("Not supported in Naive Bayes");
+ }
+
+ /** Unsupported method. This implementation simply throws an {@link UnsupportedOperationException}. */
+ @Override
+ public Vector classify(Vector instance) {
+ throw new UnsupportedOperationException("probabilites not supported in Naive Bayes");
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
new file mode 100644
index 0000000..1e5171c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/BayesUtils.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.naivebayes.training.ThetaMapper;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.google.common.io.Closeables;
+
+public final class BayesUtils {
+
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ private BayesUtils() {}
+
+ public static NaiveBayesModel readModelFromDir(Path base, Configuration conf) {
+
+ float alphaI = conf.getFloat(ThetaMapper.ALPHA_I, 1.0f);
+ boolean isComplementary = conf.getBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, true);
+
+ // read feature sums and label sums
+ Vector scoresPerLabel = null;
+ Vector scoresPerFeature = null;
+ for (Pair<Text,VectorWritable> record : new SequenceFileDirIterable<Text, VectorWritable>(
+ new Path(base, TrainNaiveBayesJob.WEIGHTS), PathType.LIST, PathFilters.partFilter(), conf)) {
+ String key = record.getFirst().toString();
+ VectorWritable value = record.getSecond();
+ if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE)) {
+ scoresPerFeature = value.get();
+ } else if (key.equals(TrainNaiveBayesJob.WEIGHTS_PER_LABEL)) {
+ scoresPerLabel = value.get();
+ }
+ }
+
+ Preconditions.checkNotNull(scoresPerFeature);
+ Preconditions.checkNotNull(scoresPerLabel);
+
+ Matrix scoresPerLabelAndFeature = new SparseMatrix(scoresPerLabel.size(), scoresPerFeature.size());
+ for (Pair<IntWritable,VectorWritable> entry : new SequenceFileDirIterable<IntWritable,VectorWritable>(
+ new Path(base, TrainNaiveBayesJob.SUMMED_OBSERVATIONS), PathType.LIST, PathFilters.partFilter(), conf)) {
+ scoresPerLabelAndFeature.assignRow(entry.getFirst().get(), entry.getSecond().get());
+ }
+
+ // perLabelThetaNormalizer is only used by the complementary model, we do not instantiate it for the standard model
+ Vector perLabelThetaNormalizer = null;
+ if (isComplementary) {
+ perLabelThetaNormalizer=scoresPerLabel.like();
+ for (Pair<Text,VectorWritable> entry : new SequenceFileDirIterable<Text,VectorWritable>(
+ new Path(base, TrainNaiveBayesJob.THETAS), PathType.LIST, PathFilters.partFilter(), conf)) {
+ if (entry.getFirst().toString().equals(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER)) {
+ perLabelThetaNormalizer = entry.getSecond().get();
+ }
+ }
+ Preconditions.checkNotNull(perLabelThetaNormalizer);
+ }
+
+ return new NaiveBayesModel(scoresPerLabelAndFeature, scoresPerFeature, scoresPerLabel, perLabelThetaNormalizer,
+ alphaI, isComplementary);
+ }
+
+ /** Write the list of labels into a map file */
+ public static int writeLabelIndex(Configuration conf, Iterable<String> labels, Path indexPath)
+ throws IOException {
+ FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+ int i = 0;
+ try {
+ for (String label : labels) {
+ writer.append(new Text(label), new IntWritable(i++));
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ return i;
+ }
+
+ public static int writeLabelIndex(Configuration conf, Path indexPath,
+ Iterable<Pair<Text,IntWritable>> labels) throws IOException {
+ FileSystem fs = FileSystem.get(indexPath.toUri(), conf);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, indexPath, Text.class, IntWritable.class);
+ Collection<String> seen = Sets.newHashSet();
+ int i = 0;
+ try {
+ for (Object label : labels) {
+ String theLabel = SLASH.split(((Pair<?, ?>) label).getFirst().toString())[1];
+ if (!seen.contains(theLabel)) {
+ writer.append(new Text(theLabel), new IntWritable(i++));
+ seen.add(theLabel);
+ }
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ return i;
+ }
+
+ public static Map<Integer, String> readLabelIndex(Configuration conf, Path indexPath) {
+ Map<Integer, String> labelMap = new HashMap<>();
+ for (Pair<Text, IntWritable> pair : new SequenceFileIterable<Text, IntWritable>(indexPath, true, conf)) {
+ labelMap.put(pair.getSecond().get(), pair.getFirst().toString());
+ }
+ return labelMap;
+ }
+
+ public static OpenObjectIntHashMap<String> readIndexFromCache(Configuration conf) throws IOException {
+ OpenObjectIntHashMap<String> index = new OpenObjectIntHashMap<>();
+ for (Pair<Writable,IntWritable> entry
+ : new SequenceFileIterable<Writable,IntWritable>(HadoopUtil.getSingleCachedFile(conf), conf)) {
+ index.put(entry.getFirst().toString(), entry.getSecond().get());
+ }
+ return index;
+ }
+
+ public static Map<String,Vector> readScoresFromCache(Configuration conf) throws IOException {
+ Map<String,Vector> sumVectors = Maps.newHashMap();
+ for (Pair<Text,VectorWritable> entry
+ : new SequenceFileDirIterable<Text,VectorWritable>(HadoopUtil.getSingleCachedFile(conf),
+ PathType.LIST, PathFilters.partFilter(), conf)) {
+ sumVectors.put(entry.getFirst().toString(), entry.getSecond().get());
+ }
+ return sumVectors;
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
new file mode 100644
index 0000000..18bd3d6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifier.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/** Implementation of the Naive Bayes Classifier Algorithm */
+public class ComplementaryNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+ public ComplementaryNaiveBayesClassifier(NaiveBayesModel model) {
+ super(model);
+ }
+
+ @Override
+ public double getScoreForLabelFeature(int label, int feature) {
+ NaiveBayesModel model = getModel();
+ double weight = computeWeight(model.featureWeight(feature), model.weight(label, feature),
+ model.totalWeightSum(), model.labelWeight(label), model.alphaI(), model.numFeatures());
+ // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+ return weight / model.thetaNormalizer(label);
+ }
+
+ // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.1, Skewed Data bias
+ public static double computeWeight(double featureWeight, double featureLabelWeight,
+ double totalWeight, double labelWeight, double alphaI, double numFeatures) {
+ double numerator = featureWeight - featureLabelWeight + alphaI;
+ double denominator = totalWeight - labelWeight + alphaI * numFeatures;
+ return -Math.log(numerator / denominator);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
new file mode 100644
index 0000000..f180e8b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModel.java
@@ -0,0 +1,176 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+
+/** NaiveBayesModel holds the weight matrix, the feature and label sums and the weight normalizer vectors.*/
+public class NaiveBayesModel {
+
+ private final Vector weightsPerLabel;
+ private final Vector perlabelThetaNormalizer;
+ private final Vector weightsPerFeature;
+ private final Matrix weightsPerLabelAndFeature;
+ private final float alphaI;
+ private final double numFeatures;
+ private final double totalWeightSum;
+ private final boolean isComplementary;
+
+ public final static String COMPLEMENTARY_MODEL = "COMPLEMENTARY_MODEL";
+
+ public NaiveBayesModel(Matrix weightMatrix, Vector weightsPerFeature, Vector weightsPerLabel, Vector thetaNormalizer,
+ float alphaI, boolean isComplementary) {
+ this.weightsPerLabelAndFeature = weightMatrix;
+ this.weightsPerFeature = weightsPerFeature;
+ this.weightsPerLabel = weightsPerLabel;
+ this.perlabelThetaNormalizer = thetaNormalizer;
+ this.numFeatures = weightsPerFeature.getNumNondefaultElements();
+ this.totalWeightSum = weightsPerLabel.zSum();
+ this.alphaI = alphaI;
+ this.isComplementary=isComplementary;
+ }
+
+ public double labelWeight(int label) {
+ return weightsPerLabel.getQuick(label);
+ }
+
+ public double thetaNormalizer(int label) {
+ return perlabelThetaNormalizer.get(label);
+ }
+
+ public double featureWeight(int feature) {
+ return weightsPerFeature.getQuick(feature);
+ }
+
+ public double weight(int label, int feature) {
+ return weightsPerLabelAndFeature.getQuick(label, feature);
+ }
+
+ public float alphaI() {
+ return alphaI;
+ }
+
+ public double numFeatures() {
+ return numFeatures;
+ }
+
+ public double totalWeightSum() {
+ return totalWeightSum;
+ }
+
+ public int numLabels() {
+ return weightsPerLabel.size();
+ }
+
+ public Vector createScoringVector() {
+ return weightsPerLabel.like();
+ }
+
+ public boolean isComplemtary(){
+ return isComplementary;
+ }
+
+ public static NaiveBayesModel materialize(Path output, Configuration conf) throws IOException {
+ FileSystem fs = output.getFileSystem(conf);
+
+ Vector weightsPerLabel = null;
+ Vector perLabelThetaNormalizer = null;
+ Vector weightsPerFeature = null;
+ Matrix weightsPerLabelAndFeature;
+ float alphaI;
+ boolean isComplementary;
+
+ FSDataInputStream in = fs.open(new Path(output, "naiveBayesModel.bin"));
+ try {
+ alphaI = in.readFloat();
+ isComplementary = in.readBoolean();
+ weightsPerFeature = VectorWritable.readVector(in);
+ weightsPerLabel = new DenseVector(VectorWritable.readVector(in));
+ if (isComplementary){
+ perLabelThetaNormalizer = new DenseVector(VectorWritable.readVector(in));
+ }
+ weightsPerLabelAndFeature = new SparseRowMatrix(weightsPerLabel.size(), weightsPerFeature.size());
+ for (int label = 0; label < weightsPerLabelAndFeature.numRows(); label++) {
+ weightsPerLabelAndFeature.assignRow(label, VectorWritable.readVector(in));
+ }
+ } finally {
+ Closeables.close(in, true);
+ }
+ NaiveBayesModel model = new NaiveBayesModel(weightsPerLabelAndFeature, weightsPerFeature, weightsPerLabel,
+ perLabelThetaNormalizer, alphaI, isComplementary);
+ model.validate();
+ return model;
+ }
+
+ public void serialize(Path output, Configuration conf) throws IOException {
+ FileSystem fs = output.getFileSystem(conf);
+ FSDataOutputStream out = fs.create(new Path(output, "naiveBayesModel.bin"));
+ try {
+ out.writeFloat(alphaI);
+ out.writeBoolean(isComplementary);
+ VectorWritable.writeVector(out, weightsPerFeature);
+ VectorWritable.writeVector(out, weightsPerLabel);
+ if (isComplementary){
+ VectorWritable.writeVector(out, perlabelThetaNormalizer);
+ }
+ for (int row = 0; row < weightsPerLabelAndFeature.numRows(); row++) {
+ VectorWritable.writeVector(out, weightsPerLabelAndFeature.viewRow(row));
+ }
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ public void validate() {
+ Preconditions.checkState(alphaI > 0, "alphaI has to be greater than 0!");
+ Preconditions.checkArgument(numFeatures > 0, "the vocab count has to be greater than 0!");
+ Preconditions.checkArgument(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!");
+ Preconditions.checkNotNull(weightsPerLabel, "the number of labels has to be defined!");
+ Preconditions.checkArgument(weightsPerLabel.getNumNondefaultElements() > 0,
+ "the number of labels has to be greater than 0!");
+ Preconditions.checkNotNull(weightsPerFeature, "the feature sums have to be defined");
+ Preconditions.checkArgument(weightsPerFeature.getNumNondefaultElements() > 0,
+ "the feature sums have to be greater than 0!");
+ if (isComplementary){
+ Preconditions.checkArgument(perlabelThetaNormalizer != null, "the theta normalizers have to be defined");
+ Preconditions.checkArgument(perlabelThetaNormalizer.getNumNondefaultElements() > 0,
+ "the number of theta normalizers has to be greater than 0!");
+ Preconditions.checkArgument(Math.signum(perlabelThetaNormalizer.minValue())
+ == Math.signum(perlabelThetaNormalizer.maxValue()),
+ "Theta normalizers do not all have the same sign");
+ Preconditions.checkArgument(perlabelThetaNormalizer.getNumNonZeroElements()
+ == perlabelThetaNormalizer.size(),
+ "Theta normalizers can not have zero value.");
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
new file mode 100644
index 0000000..e4ce8aa
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifier.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+
+/** Implementation of the Naive Bayes Classifier Algorithm */
+public class StandardNaiveBayesClassifier extends AbstractNaiveBayesClassifier {
+
+ public StandardNaiveBayesClassifier(NaiveBayesModel model) {
+ super(model);
+ }
+
+ @Override
+ public double getScoreForLabelFeature(int label, int feature) {
+ NaiveBayesModel model = getModel();
+ // Standard Naive Bayes does not use weight normalization
+ return computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI(), model.numFeatures());
+ }
+
+ public static double computeWeight(double featureLabelWeight, double labelWeight, double alphaI, double numFeatures) {
+ double numerator = featureLabelWeight + alphaI;
+ double denominator = labelWeight + alphaI * numFeatures;
+ return Math.log(numerator / denominator);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
new file mode 100644
index 0000000..37a3b71
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/BayesTestMapper.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+/**
+ * Run the input through the model and see if it matches.
+ * <p/>
+ * The output value is the generated label, the Pair is the expected label and true if they match:
+ */
+public class BayesTestMapper extends Mapper<Text, VectorWritable, Text, VectorWritable> {
+
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ private AbstractNaiveBayesClassifier classifier;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ Path modelPath = HadoopUtil.getSingleCachedFile(conf);
+ NaiveBayesModel model = NaiveBayesModel.materialize(modelPath, conf);
+ boolean isComplementary = Boolean.parseBoolean(conf.get(TestNaiveBayesDriver.COMPLEMENTARY));
+
+ // ensure that if we are testing in complementary mode, the model has been
+ // trained complementary. a complementarty model will work for standard classification
+ // a standard model will not work for complementary classification
+ if (isComplementary) {
+ Preconditions.checkArgument((model.isComplemtary()),
+ "Complementary mode in model is different than test mode");
+ }
+
+ if (isComplementary) {
+ classifier = new ComplementaryNaiveBayesClassifier(model);
+ } else {
+ classifier = new StandardNaiveBayesClassifier(model);
+ }
+ }
+
+ @Override
+ protected void map(Text key, VectorWritable value, Context context) throws IOException, InterruptedException {
+ Vector result = classifier.classifyFull(value.get());
+ //the key is the expected value
+ context.write(new Text(SLASH.split(key.toString())[1]), new VectorWritable(result));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
new file mode 100644
index 0000000..8fd422f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/test/TestNaiveBayesDriver.java
@@ -0,0 +1,179 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.test;
+
+import com.google.common.base.Preconditions;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.ClassifierResult;
+import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.classifier.naivebayes.AbstractNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.classifier.naivebayes.StandardNaiveBayesClassifier;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Test the (Complementary) Naive Bayes model that was built during training
+ * by running the iterating the test set and comparing it to the model
+ */
+public class TestNaiveBayesDriver extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(TestNaiveBayesDriver.class);
+
+ public static final String COMPLEMENTARY = "class"; //b for bayes, c for complementary
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new TestNaiveBayesDriver(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption(addOption(DefaultOptionCreator.overwriteOption().create()));
+ addOption("model", "m", "The path to the model built during training", true);
+ addOption(buildOption("testComplementary", "c", "test complementary?", false, false, String.valueOf(false)));
+ addOption(buildOption("runSequential", "seq", "run sequential?", false, false, String.valueOf(false)));
+ addOption("labelIndex", "l", "The path to the location of the label index", true);
+ Map<String, List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), getOutputPath());
+ }
+
+ boolean sequential = hasOption("runSequential");
+ boolean succeeded;
+ if (sequential) {
+ runSequential();
+ } else {
+ succeeded = runMapReduce();
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ //load the labels
+ Map<Integer, String> labelMap = BayesUtils.readLabelIndex(getConf(), new Path(getOption("labelIndex")));
+
+ //loop over the results and create the confusion matrix
+ SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+ new SequenceFileDirIterable<>(getOutputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+ ResultAnalyzer analyzer = new ResultAnalyzer(labelMap.values(), "DEFAULT");
+ analyzeResults(labelMap, dirIterable, analyzer);
+
+ log.info("{} Results: {}", hasOption("testComplementary") ? "Complementary" : "Standard NB", analyzer);
+ return 0;
+ }
+
+ private void runSequential() throws IOException {
+ boolean complementary = hasOption("testComplementary");
+ FileSystem fs = FileSystem.get(getConf());
+ NaiveBayesModel model = NaiveBayesModel.materialize(new Path(getOption("model")), getConf());
+
+ // Ensure that if we are testing in complementary mode, the model has been
+ // trained complementary. a complementarty model will work for standard classification
+ // a standard model will not work for complementary classification
+ if (complementary){
+ Preconditions.checkArgument((model.isComplemtary()),
+ "Complementary mode in model is different from test mode");
+ }
+
+ AbstractNaiveBayesClassifier classifier;
+ if (complementary) {
+ classifier = new ComplementaryNaiveBayesClassifier(model);
+ } else {
+ classifier = new StandardNaiveBayesClassifier(model);
+ }
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, getConf(), new Path(getOutputPath(), "part-r-00000"),
+ Text.class, VectorWritable.class);
+
+ try {
+ SequenceFileDirIterable<Text, VectorWritable> dirIterable =
+ new SequenceFileDirIterable<>(getInputPath(), PathType.LIST, PathFilters.partFilter(), getConf());
+ // loop through the part-r-* files in getInputPath() and get classification scores for all entries
+ for (Pair<Text, VectorWritable> pair : dirIterable) {
+ writer.append(new Text(SLASH.split(pair.getFirst().toString())[1]),
+ new VectorWritable(classifier.classifyFull(pair.getSecond().get())));
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ private boolean runMapReduce() throws IOException,
+ InterruptedException, ClassNotFoundException {
+ Path model = new Path(getOption("model"));
+ HadoopUtil.cacheFiles(model, getConf());
+ //the output key is the expected value, the output value are the scores for all the labels
+ Job testJob = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class, BayesTestMapper.class,
+ Text.class, VectorWritable.class, SequenceFileOutputFormat.class);
+ //testJob.getConfiguration().set(LABEL_KEY, getOption("--labels"));
+
+
+ boolean complementary = hasOption("testComplementary");
+ testJob.getConfiguration().set(COMPLEMENTARY, String.valueOf(complementary));
+ return testJob.waitForCompletion(true);
+ }
+
+ private static void analyzeResults(Map<Integer, String> labelMap,
+ SequenceFileDirIterable<Text, VectorWritable> dirIterable,
+ ResultAnalyzer analyzer) {
+ for (Pair<Text, VectorWritable> pair : dirIterable) {
+ int bestIdx = Integer.MIN_VALUE;
+ double bestScore = Long.MIN_VALUE;
+ for (Vector.Element element : pair.getSecond().get().all()) {
+ if (element.get() > bestScore) {
+ bestScore = element.get();
+ bestIdx = element.index();
+ }
+ }
+ if (bestIdx != Integer.MIN_VALUE) {
+ ClassifierResult classifierResult = new ClassifierResult(labelMap.get(bestIdx), bestScore);
+ analyzer.addInstance(pair.getFirst().toString(), classifierResult);
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
new file mode 100644
index 0000000..2b8ee1e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ComplementaryThetaTrainer.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
+import org.apache.mahout.math.Vector;
+
+public class ComplementaryThetaTrainer {
+
+ private final Vector weightsPerFeature;
+ private final Vector weightsPerLabel;
+ private final Vector perLabelThetaNormalizer;
+ private final double alphaI;
+ private final double totalWeightSum;
+ private final double numFeatures;
+
+ public ComplementaryThetaTrainer(Vector weightsPerFeature, Vector weightsPerLabel, double alphaI) {
+ Preconditions.checkNotNull(weightsPerFeature);
+ Preconditions.checkNotNull(weightsPerLabel);
+ this.weightsPerFeature = weightsPerFeature;
+ this.weightsPerLabel = weightsPerLabel;
+ this.alphaI = alphaI;
+ perLabelThetaNormalizer = weightsPerLabel.like();
+ totalWeightSum = weightsPerLabel.zSum();
+ numFeatures = weightsPerFeature.getNumNondefaultElements();
+ }
+
+ public void train(int label, Vector perLabelWeight) {
+ double labelWeight = labelWeight(label);
+ // sum weights for each label including those with zero word counts
+ for(int i = 0; i < perLabelWeight.size(); i++){
+ Vector.Element perLabelWeightElement = perLabelWeight.getElement(i);
+ updatePerLabelThetaNormalizer(label,
+ ComplementaryNaiveBayesClassifier.computeWeight(featureWeight(perLabelWeightElement.index()),
+ perLabelWeightElement.get(), totalWeightSum(), labelWeight, alphaI(), numFeatures()));
+ }
+ }
+
+ protected double alphaI() {
+ return alphaI;
+ }
+
+ protected double numFeatures() {
+ return numFeatures;
+ }
+
+ protected double labelWeight(int label) {
+ return weightsPerLabel.get(label);
+ }
+
+ protected double totalWeightSum() {
+ return totalWeightSum;
+ }
+
+ protected double featureWeight(int feature) {
+ return weightsPerFeature.get(feature);
+ }
+
+ // http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+ protected void updatePerLabelThetaNormalizer(int label, double weight) {
+ perLabelThetaNormalizer.set(label, perLabelThetaNormalizer.get(label) + Math.abs(weight));
+ }
+
+ public Vector retrievePerLabelThetaNormalizer() {
+ return perLabelThetaNormalizer.clone();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
new file mode 100644
index 0000000..40ca2e9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapper.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+public class IndexInstancesMapper extends Mapper<Text, VectorWritable, IntWritable, VectorWritable> {
+
+ private static final Pattern SLASH = Pattern.compile("/");
+
+ public enum Counter { SKIPPED_INSTANCES }
+
+ private OpenObjectIntHashMap<String> labelIndex;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ labelIndex = BayesUtils.readIndexFromCache(ctx.getConfiguration());
+ }
+
+ @Override
+ protected void map(Text labelText, VectorWritable instance, Context ctx) throws IOException, InterruptedException {
+ String label = SLASH.split(labelText.toString())[1];
+ if (labelIndex.containsKey(label)) {
+ ctx.write(new IntWritable(labelIndex.get(label)), instance);
+ } else {
+ ctx.getCounter(Counter.SKIPPED_INSTANCES).increment(1);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
new file mode 100644
index 0000000..ff2ea40
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapper.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public class ThetaMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+ public static final String ALPHA_I = ThetaMapper.class.getName() + ".alphaI";
+ static final String TRAIN_COMPLEMENTARY = ThetaMapper.class.getName() + ".trainComplementary";
+
+ private ComplementaryThetaTrainer trainer;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ Configuration conf = ctx.getConfiguration();
+
+ float alphaI = conf.getFloat(ALPHA_I, 1.0f);
+ Map<String, Vector> scores = BayesUtils.readScoresFromCache(conf);
+
+ trainer = new ComplementaryThetaTrainer(scores.get(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
+ scores.get(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), alphaI);
+ }
+
+ @Override
+ protected void map(IntWritable key, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+ trainer.train(key.get(), value.get());
+ }
+
+ @Override
+ protected void cleanup(Context ctx) throws IOException, InterruptedException {
+ ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER),
+ new VectorWritable(trainer.retrievePerLabelThetaNormalizer()));
+ super.cleanup(ctx);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
new file mode 100644
index 0000000..ac1c4c9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/TrainNaiveBayesJob.java
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.classifier.naivebayes.BayesUtils;
+import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.mapreduce.VectorSumReducer;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Splitter;
+
+/** Trains a Naive Bayes Classifier (parameters for both Naive Bayes and Complementary Naive Bayes) */
+public final class TrainNaiveBayesJob extends AbstractJob {
+ private static final String TRAIN_COMPLEMENTARY = "trainComplementary";
+ private static final String ALPHA_I = "alphaI";
+ private static final String LABEL_INDEX = "labelIndex";
+ private static final String EXTRACT_LABELS = "extractLabels";
+ private static final String LABELS = "labels";
+ public static final String WEIGHTS_PER_FEATURE = "__SPF";
+ public static final String WEIGHTS_PER_LABEL = "__SPL";
+ public static final String LABEL_THETA_NORMALIZER = "_LTN";
+
+ public static final String SUMMED_OBSERVATIONS = "summedObservations";
+ public static final String WEIGHTS = "weights";
+ public static final String THETAS = "thetas";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new TrainNaiveBayesJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(LABELS, "l", "comma-separated list of labels to include in training", false);
+
+ addOption(buildOption(EXTRACT_LABELS, "el", "Extract the labels from the input", false, false, ""));
+ addOption(ALPHA_I, "a", "smoothing parameter", String.valueOf(1.0f));
+ addOption(buildOption(TRAIN_COMPLEMENTARY, "c", "train complementary?", false, false, String.valueOf(false)));
+ addOption(LABEL_INDEX, "li", "The path to store the label index in", false);
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ Map<String, List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), getOutputPath());
+ HadoopUtil.delete(getConf(), getTempPath());
+ }
+ Path labPath;
+ String labPathStr = getOption(LABEL_INDEX);
+ if (labPathStr != null) {
+ labPath = new Path(labPathStr);
+ } else {
+ labPath = getTempPath(LABEL_INDEX);
+ }
+ long labelSize = createLabelIndex(labPath);
+ float alphaI = Float.parseFloat(getOption(ALPHA_I));
+ boolean trainComplementary = hasOption(TRAIN_COMPLEMENTARY);
+
+ HadoopUtil.setSerializations(getConf());
+ HadoopUtil.cacheFiles(labPath, getConf());
+
+ // Add up all the vectors with the same labels, while mapping the labels into our index
+ Job indexInstances = prepareJob(getInputPath(),
+ getTempPath(SUMMED_OBSERVATIONS),
+ SequenceFileInputFormat.class,
+ IndexInstancesMapper.class,
+ IntWritable.class,
+ VectorWritable.class,
+ VectorSumReducer.class,
+ IntWritable.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class);
+ indexInstances.setCombinerClass(VectorSumReducer.class);
+ boolean succeeded = indexInstances.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ // Sum up all the weights from the previous step, per label and per feature
+ Job weightSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
+ getTempPath(WEIGHTS),
+ SequenceFileInputFormat.class,
+ WeightsMapper.class,
+ Text.class,
+ VectorWritable.class,
+ VectorSumReducer.class,
+ Text.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class);
+ weightSummer.getConfiguration().set(WeightsMapper.NUM_LABELS, String.valueOf(labelSize));
+ weightSummer.setCombinerClass(VectorSumReducer.class);
+ succeeded = weightSummer.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ // Put the per label and per feature vectors into the cache
+ HadoopUtil.cacheFiles(getTempPath(WEIGHTS), getConf());
+
+ if (trainComplementary){
+ // Calculate the per label theta normalizers, write out to LABEL_THETA_NORMALIZER vector
+ // see http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - Section 3.2, Weight Magnitude Errors
+ Job thetaSummer = prepareJob(getTempPath(SUMMED_OBSERVATIONS),
+ getTempPath(THETAS),
+ SequenceFileInputFormat.class,
+ ThetaMapper.class,
+ Text.class,
+ VectorWritable.class,
+ VectorSumReducer.class,
+ Text.class,
+ VectorWritable.class,
+ SequenceFileOutputFormat.class);
+ thetaSummer.setCombinerClass(VectorSumReducer.class);
+ thetaSummer.getConfiguration().setFloat(ThetaMapper.ALPHA_I, alphaI);
+ thetaSummer.getConfiguration().setBoolean(ThetaMapper.TRAIN_COMPLEMENTARY, trainComplementary);
+ succeeded = thetaSummer.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ // Put the per label theta normalizers into the cache
+ HadoopUtil.cacheFiles(getTempPath(THETAS), getConf());
+
+ // Validate our model and then write it out to the official output
+ getConf().setFloat(ThetaMapper.ALPHA_I, alphaI);
+ getConf().setBoolean(NaiveBayesModel.COMPLEMENTARY_MODEL, trainComplementary);
+ NaiveBayesModel naiveBayesModel = BayesUtils.readModelFromDir(getTempPath(), getConf());
+ naiveBayesModel.validate();
+ naiveBayesModel.serialize(getOutputPath(), getConf());
+
+ return 0;
+ }
+
+ private long createLabelIndex(Path labPath) throws IOException {
+ long labelSize = 0;
+ if (hasOption(LABELS)) {
+ Iterable<String> labels = Splitter.on(",").split(getOption(LABELS));
+ labelSize = BayesUtils.writeLabelIndex(getConf(), labels, labPath);
+ } else if (hasOption(EXTRACT_LABELS)) {
+ Iterable<Pair<Text,IntWritable>> iterable =
+ new SequenceFileDirIterable<Text, IntWritable>(getInputPath(),
+ PathType.LIST,
+ PathFilters.logsCRCFilter(),
+ getConf());
+ labelSize = BayesUtils.writeLabelIndex(getConf(), labPath, iterable);
+ }
+ return labelSize;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
new file mode 100644
index 0000000..5563057
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapper.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
+public class WeightsMapper extends Mapper<IntWritable, VectorWritable, Text, VectorWritable> {
+
+ static final String NUM_LABELS = WeightsMapper.class.getName() + ".numLabels";
+
+ private Vector weightsPerFeature;
+ private Vector weightsPerLabel;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ super.setup(ctx);
+ int numLabels = Integer.parseInt(ctx.getConfiguration().get(NUM_LABELS));
+ Preconditions.checkArgument(numLabels > 0, "Wrong numLabels: " + numLabels + ". Must be > 0!");
+ weightsPerLabel = new DenseVector(numLabels);
+ }
+
+ @Override
+ protected void map(IntWritable index, VectorWritable value, Context ctx) throws IOException, InterruptedException {
+ Vector instance = value.get();
+ if (weightsPerFeature == null) {
+ weightsPerFeature = new RandomAccessSparseVector(instance.size(), instance.getNumNondefaultElements());
+ }
+
+ int label = index.get();
+ weightsPerFeature.assign(instance, Functions.PLUS);
+ weightsPerLabel.set(label, weightsPerLabel.get(label) + instance.zSum());
+ }
+
+ @Override
+ protected void cleanup(Context ctx) throws IOException, InterruptedException {
+ if (weightsPerFeature != null) {
+ ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), new VectorWritable(weightsPerFeature));
+ ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), new VectorWritable(weightsPerLabel));
+ }
+ super.cleanup(ctx);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
new file mode 100644
index 0000000..942a101
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/BaumWelchTrainer.java
@@ -0,0 +1,165 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.DataOutputStream;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.util.Date;
+import java.util.List;
+import java.util.Scanner;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+
+/**
+ * A class for EM training of HMM from console
+ */
+public final class BaumWelchTrainer {
+
+ private BaumWelchTrainer() {
+ }
+
+ public static void main(String[] args) throws IOException {
+ DefaultOptionBuilder optionBuilder = new DefaultOptionBuilder();
+ ArgumentBuilder argumentBuilder = new ArgumentBuilder();
+
+ Option inputOption = DefaultOptionCreator.inputOption().create();
+
+ Option outputOption = DefaultOptionCreator.outputOption().create();
+
+ Option stateNumberOption = optionBuilder.withLongName("nrOfHiddenStates").
+ withDescription("Number of hidden states").
+ withShortName("nh").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Option observedStateNumberOption = optionBuilder.withLongName("nrOfObservedStates").
+ withDescription("Number of observed states").
+ withShortName("no").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Option epsilonOption = optionBuilder.withLongName("epsilon").
+ withDescription("Convergence threshold").
+ withShortName("e").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Option iterationsOption = optionBuilder.withLongName("max-iterations").
+ withDescription("Maximum iterations number").
+ withShortName("m").withArgument(argumentBuilder.withMaximum(1).withMinimum(1).
+ withName("number").create()).withRequired(true).create();
+
+ Group optionGroup = new GroupBuilder().withOption(inputOption).
+ withOption(outputOption).withOption(stateNumberOption).withOption(observedStateNumberOption).
+ withOption(epsilonOption).withOption(iterationsOption).
+ withName("Options").create();
+
+ try {
+ Parser parser = new Parser();
+ parser.setGroup(optionGroup);
+ CommandLine commandLine = parser.parse(args);
+
+ String input = (String) commandLine.getValue(inputOption);
+ String output = (String) commandLine.getValue(outputOption);
+
+ int nrOfHiddenStates = Integer.parseInt((String) commandLine.getValue(stateNumberOption));
+ int nrOfObservedStates = Integer.parseInt((String) commandLine.getValue(observedStateNumberOption));
+
+ double epsilon = Double.parseDouble((String) commandLine.getValue(epsilonOption));
+ int maxIterations = Integer.parseInt((String) commandLine.getValue(iterationsOption));
+
+ //constructing random-generated HMM
+ HmmModel model = new HmmModel(nrOfHiddenStates, nrOfObservedStates, new Date().getTime());
+ List<Integer> observations = Lists.newArrayList();
+
+ //reading observations
+ try (Scanner scanner = new Scanner(new FileInputStream(input), "UTF-8")) {
+ while (scanner.hasNextInt()) {
+ observations.add(scanner.nextInt());
+ }
+ }
+
+ int[] observationsArray = new int[observations.size()];
+ for (int i = 0; i < observations.size(); ++i) {
+ observationsArray[i] = observations.get(i);
+ }
+
+ //training
+ HmmModel trainedModel = HmmTrainer.trainBaumWelch(model,
+ observationsArray, epsilon, maxIterations, true);
+
+ //serializing trained model
+ DataOutputStream stream = new DataOutputStream(new FileOutputStream(output));
+ try {
+ LossyHmmSerializer.serialize(trainedModel, stream);
+ } finally {
+ Closeables.close(stream, false);
+ }
+
+ //printing tranied model
+ System.out.println("Initial probabilities: ");
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ }
+ System.out.println();
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(trainedModel.getInitialProbabilities().get(i) + " ");
+ }
+ System.out.println();
+
+ System.out.println("Transition matrix:");
+ System.out.print(" ");
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ }
+ System.out.println();
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ for (int j = 0; j < trainedModel.getNrOfHiddenStates(); ++j) {
+ System.out.print(trainedModel.getTransitionMatrix().get(i, j) + " ");
+ }
+ System.out.println();
+ }
+ System.out.println("Emission matrix: ");
+ System.out.print(" ");
+ for (int i = 0; i < trainedModel.getNrOfOutputStates(); ++i) {
+ System.out.print(i + " ");
+ }
+ System.out.println();
+ for (int i = 0; i < trainedModel.getNrOfHiddenStates(); ++i) {
+ System.out.print(i + " ");
+ for (int j = 0; j < trainedModel.getNrOfOutputStates(); ++j) {
+ System.out.print(trainedModel.getEmissionMatrix().get(i, j) + " ");
+ }
+ System.out.println();
+ }
+ } catch (OptionException e) {
+ CommandLineUtil.printHelp(optionGroup);
+ }
+ }
+}
[06/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
new file mode 100644
index 0000000..3104cb1
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+public class HMMEvaluatorTest extends HMMTestBase {
+
+ /**
+ * Test to make sure the computed model likelihood ist valid. Included tests
+ * are: a) forwad == backward likelihood b) model likelihood for test seqeunce
+ * is the expected one from R reference
+ */
+ @Test
+ public void testModelLikelihood() {
+ // compute alpha and beta values
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false);
+ // now test whether forward == backward likelihood
+ double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, false);
+ double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
+ beta, false);
+ assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
+ // also make sure that the likelihood matches the expected one
+ assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
+ }
+
+ /**
+ * Test to make sure the computed model likelihood ist valid. Included tests
+ * are: a) forwad == backward likelihood b) model likelihood for test seqeunce
+ * is the expected one from R reference
+ */
+ @Test
+ public void testScaledModelLikelihood() {
+ // compute alpha and beta values
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true);
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true);
+ // now test whether forward == backward likelihood
+ double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, true);
+ double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(),
+ beta, true);
+ assertEquals(forwardLikelihood, backwardLikelihood, EPSILON);
+ // also make sure that the likelihood matches the expected one
+ assertEquals(1.8425e-4, forwardLikelihood, EPSILON);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
new file mode 100644
index 0000000..3260f51
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.junit.Test;
+
+public class HMMModelTest extends HMMTestBase {
+
+ @Test
+ public void testRandomModelGeneration() {
+ // make sure we generate a valid random model
+ HmmModel model = new HmmModel(10, 20);
+ // check whether the model is valid
+ HmmUtils.validate(model);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
new file mode 100644
index 0000000..90f1cd8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+
+public class HMMTestBase extends MahoutTestCase {
+
+ private HmmModel model;
+ private final int[] sequence = {1, 0, 2, 2, 0, 0, 1};
+
+ /**
+ * We initialize a new HMM model using the following parameters # hidden
+ * states: 4 ("H0","H1","H2","H3") # output states: 3 ("O0","O1","O2") #
+ * transition matrix to: H0 H1 H2 H3 from: H0 0.5 0.1 0.1 0.3 H1 0.4 0.4 0.1
+ * 0.1 H2 0.1 0.0 0.8 0.1 H3 0.1 0.1 0.1 0.7 # output matrix to: O0 O1 O2
+ * from: H0 0.8 0.1 0.1 H1 0.6 0.1 0.3 H2 0.1 0.8 0.1 H3 0.0 0.1 0.9 # initial
+ * probabilities H0 0.2
+ * <p/>
+ * H1 0.1 H2 0.4 H3 0.3
+ * <p/>
+ * We also intialize an observation sequence: "O1" "O0" "O2" "O2" "O0" "O0"
+ * "O1"
+ */
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ // intialize the hidden/output state names
+ String[] hiddenNames = {"H0", "H1", "H2", "H3"};
+ String[] outputNames = {"O0", "O1", "O2"};
+ // initialize the transition matrix
+ double[][] transitionP = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1},
+ {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}};
+ // initialize the emission matrix
+ double[][] emissionP = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3},
+ {0.1, 0.8, 0.1}, {0.0, 0.1, 0.9}};
+ // initialize the initial probability vector
+ double[] initialP = {0.2, 0.1, 0.4, 0.3};
+ // now generate the model
+ model = new HmmModel(new DenseMatrix(transitionP), new DenseMatrix(
+ emissionP), new DenseVector(initialP));
+ model.registerHiddenStateNames(hiddenNames);
+ model.registerOutputStateNames(outputNames);
+ // make sure the model is valid :)
+ HmmUtils.validate(model);
+ }
+
+ protected HmmModel getModel() {
+ return model;
+ }
+
+ protected int[] getSequence() {
+ return sequence;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
new file mode 100644
index 0000000..b8f3186
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
@@ -0,0 +1,163 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMTrainerTest extends HMMTestBase {
+
+ @Test
+ public void testViterbiTraining() {
+ // initialize the expected model parameters (from R)
+ // expected transition matrix
+ double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
+ {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+ {0.5, 0.1, 0.1, 0.3}};
+ // initialize the emission matrix
+ double[][] emissionE = {{0.882353, 0.058824, 0.058824},
+ {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+ {0.111111, 0.111111, 0.777778}};
+
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, false);
+
+ // now check whether the model matches our expectations
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], EPSILON);
+ }
+
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], EPSILON);
+ }
+ }
+
+ }
+
+ @Test
+ public void testScaledViterbiTraining() {
+ // initialize the expected model parameters (from R)
+ // expected transition matrix
+ double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125},
+ {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+ {0.5, 0.1, 0.1, 0.3}};
+ // initialize the emission matrix
+ double[][] emissionE = {{0.882353, 0.058824, 0.058824},
+ {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+ {0.111111, 0.111111, 0.777778}};
+
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10,
+ true);
+
+ // now check whether the model matches our expectations
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
+ EPSILON);
+ }
+
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
+ EPSILON);
+ }
+ }
+
+ }
+
+ @Test
+ public void testBaumWelchTraining() {
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ // expected values from Matlab HMM package / R HMM package
+ double[] initialExpected = {0, 0, 1.0, 0};
+ double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+ {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+ {0.0024, 0.6657, 0, 0.3319}};
+ double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+ {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+ HmmModel trained = HmmTrainer.trainBaumWelch(getModel(), observed, 0.1, 10,
+ false);
+
+ Vector initialProbabilities = trained.getInitialProbabilities();
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ assertEquals(initialProbabilities.get(i), initialExpected[i],
+ 0.0001);
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j),
+ transitionExpected[i][j], 0.0001);
+ }
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j),
+ emissionExpected[i][j], 0.0001);
+ }
+ }
+ }
+
+ @Test
+ public void testScaledBaumWelchTraining() {
+ // train the given network to the following output sequence
+ int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+ // expected values from Matlab HMM package / R HMM package
+ double[] initialExpected = {0, 0, 1.0, 0};
+ double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+ {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+ {0.0024, 0.6657, 0, 0.3319}};
+ double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+ {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+ HmmModel trained = HmmTrainer
+ .trainBaumWelch(getModel(), observed, 0.1, 10, true);
+
+ Vector initialProbabilities = trained.getInitialProbabilities();
+ Matrix emissionMatrix = trained.getEmissionMatrix();
+ Matrix transitionMatrix = trained.getTransitionMatrix();
+
+ for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+ assertEquals(initialProbabilities.get(i), initialExpected[i],
+ 0.0001);
+ for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) {
+ assertEquals(transitionMatrix.getQuick(i, j),
+ transitionExpected[i][j], 0.0001);
+ }
+ for (int j = 0; j < trained.getNrOfOutputStates(); ++j) {
+ assertEquals(emissionMatrix.getQuick(i, j),
+ emissionExpected[i][j], 0.0001);
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
new file mode 100644
index 0000000..6c34718
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
@@ -0,0 +1,161 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMUtilsTest extends HMMTestBase {
+
+ private Matrix legal22;
+ private Matrix legal23;
+ private Matrix legal33;
+ private Vector legal2;
+ private Matrix illegal22;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ legal22 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
+ legal23 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
+ {0.3, 0.3, 0.4}});
+ legal33 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
+ {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
+ legal2 = new DenseVector(new double[]{0.4, 0.6});
+ illegal22 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
+ }
+
+ @Test
+ public void testValidatorLegal() {
+ HmmUtils.validate(new HmmModel(legal22, legal23, legal2));
+ }
+
+ @Test
+ public void testValidatorDimensionError() {
+ try {
+ HmmUtils.validate(new HmmModel(legal33, legal23, legal2));
+ } catch (IllegalArgumentException e) {
+ // success
+ return;
+ }
+ fail();
+ }
+
+ @Test
+ public void testValidatorIllegelMatrixError() {
+ try {
+ HmmUtils.validate(new HmmModel(illegal22, legal23, legal2));
+ } catch (IllegalArgumentException e) {
+ // success
+ return;
+ }
+ fail();
+ }
+
+ @Test
+ public void testEncodeStateSequence() {
+ String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
+ String[] outputSequence = {"O1", "O2", "O4", "O0"};
+ // test encoding the hidden Sequence
+ int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
+ .asList(hiddenSequence), false, -1);
+ int[] outputSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays
+ .asList(outputSequence), true, -1);
+ // expected state sequences
+ int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
+ int[] outputSequenceExp = {1, 2, -1, 0};
+ // compare
+ for (int i = 0; i < hiddenSequenceEnc.length; ++i) {
+ assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
+ }
+ for (int i = 0; i < outputSequenceEnc.length; ++i) {
+ assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+ }
+ }
+
+ @Test
+ public void testDecodeStateSequence() {
+ int[] hiddenSequence = {1, 2, 0, 3, 10};
+ int[] outputSequence = {1, 2, 10, 0};
+ // test encoding the hidden Sequence
+ List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
+ getModel(), hiddenSequence, false, "unknown");
+ List<String> outputSequenceDec = HmmUtils.decodeStateSequence(
+ getModel(), outputSequence, true, "unknown");
+ // expected state sequences
+ String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
+ String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
+ // compare
+ for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+ assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
+ }
+ for (int i = 0; i < outputSequenceExp.length; ++i) {
+ assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+ }
+ }
+
+ @Test
+ public void testNormalizeModel() {
+ DenseVector ip = new DenseVector(new double[]{10, 20});
+ DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
+ DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
+ HmmModel model = new HmmModel(tr, em, ip);
+ HmmUtils.normalizeModel(model);
+ // the model should be valid now
+ HmmUtils.validate(model);
+ }
+
+ @Test
+ public void testTruncateModel() {
+ DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
+ DenseMatrix tr = new DenseMatrix(new double[][]{
+ {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+ {0.0001, 0.0001, 0.9998}});
+ DenseMatrix em = new DenseMatrix(new double[][]{
+ {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+ {0.0001, 0.0001, 0.9998}});
+ HmmModel model = new HmmModel(tr, em, ip);
+ // now truncate the model
+ HmmModel sparseModel = HmmUtils.truncateModel(model, 0.01);
+ // first make sure this is a valid model
+ HmmUtils.validate(sparseModel);
+ // now check whether the values are as expected
+ Vector sparse_ip = sparseModel.getInitialProbabilities();
+ Matrix sparse_tr = sparseModel.getTransitionMatrix();
+ Matrix sparse_em = sparseModel.getEmissionMatrix();
+ for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
+ assertEquals(i == 2 ? 1.0 : 0.0, sparse_ip.getQuick(i), EPSILON);
+ for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
+ if (i == j) {
+ assertEquals(1.0, sparse_tr.getQuick(i, j), EPSILON);
+ assertEquals(1.0, sparse_em.getQuick(i, j), EPSILON);
+ } else {
+ assertEquals(0.0, sparse_tr.getQuick(i, j), EPSILON);
+ assertEquals(0.0, sparse_em.getQuick(i, j), EPSILON);
+ }
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
new file mode 100644
index 0000000..7ea8cb2
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.jet.random.Exponential;
+import org.junit.Test;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+
+import java.util.Random;
+
+public final class AdaptiveLogisticRegressionTest extends MahoutTestCase {
+
+ @ThreadLeakLingering(linger=1000)
+ @Test
+ public void testTrain() {
+
+ Random gen = RandomUtils.getRandom();
+ Exponential exp = new Exponential(0.5, gen);
+ Vector beta = new DenseVector(200);
+ for (Vector.Element element : beta.all()) {
+ int sign = 1;
+ if (gen.nextDouble() < 0.5) {
+ sign = -1;
+ }
+ element.set(sign * exp.nextDouble());
+ }
+
+ AdaptiveLogisticRegression.Wrapper cl = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+ cl.update(new double[]{1.0e-5, 1});
+
+ for (int i = 0; i < 10000; i++) {
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ cl.train(r);
+ if (i % 1000 == 0) {
+ System.out.printf("%10d %10.3f\n", i, cl.getLearner().auc());
+ }
+ }
+ assertEquals(1, cl.getLearner().auc(), 0.1);
+
+ AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(2, 200, new L1());
+ adaptiveLogisticRegression.setInterval(1000);
+
+ for (int i = 0; i < 20000; i++) {
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ adaptiveLogisticRegression.train(r.getKey(), r.getActual(), r.getInstance());
+ if (i % 1000 == 0 && adaptiveLogisticRegression.getBest() != null) {
+ System.out.printf("%10d %10.4f %10.8f %.3f\n",
+ i, adaptiveLogisticRegression.auc(),
+ Math.log10(adaptiveLogisticRegression.getBest().getMappedParams()[0]), adaptiveLogisticRegression.getBest().getMappedParams()[1]);
+ }
+ }
+ assertEquals(1, adaptiveLogisticRegression.auc(), 0.1);
+ adaptiveLogisticRegression.close();
+ }
+
+ private static AdaptiveLogisticRegression.TrainingExample getExample(int i, Random gen, Vector beta) {
+ Vector data = new DenseVector(200);
+
+ for (Vector.Element element : data.all()) {
+ element.set(gen.nextDouble() < 0.3 ? 1 : 0);
+ }
+
+ double p = 1 / (1 + Math.exp(1.5 - data.dot(beta)));
+ int target = 0;
+ if (gen.nextDouble() < p) {
+ target = 1;
+ }
+ return new AdaptiveLogisticRegression.TrainingExample(i, null, target, data);
+ }
+
+ @Test
+ public void copyLearnsAsExpected() {
+ Random gen = RandomUtils.getRandom();
+ Exponential exp = new Exponential(0.5, gen);
+ Vector beta = new DenseVector(200);
+ for (Vector.Element element : beta.all()) {
+ int sign = 1;
+ if (gen.nextDouble() < 0.5) {
+ sign = -1;
+ }
+ element.set(sign * exp.nextDouble());
+ }
+
+ // train one copy of a wrapped learner
+ AdaptiveLogisticRegression.Wrapper w = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+ for (int i = 0; i < 3000; i++) {
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ w.train(r);
+ if (i % 1000 == 0) {
+ System.out.printf("%10d %.3f\n", i, w.getLearner().auc());
+ }
+ }
+ System.out.printf("%10d %.3f\n", 3000, w.getLearner().auc());
+ double auc1 = w.getLearner().auc();
+
+ // then switch to a copy of that learner ... progress should continue
+ AdaptiveLogisticRegression.Wrapper w2 = w.copy();
+
+ for (int i = 0; i < 5000; i++) {
+ if (i % 1000 == 0) {
+ if (i == 0) {
+ assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001);
+ }
+ if (i == 1000) {
+ double auc2 = w2.getLearner().auc();
+ assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1);
+ assertTrue("AUC should improve quickly on copy", auc1 < auc2);
+ }
+ System.out.printf("%10d %.3f\n", i, w2.getLearner().auc());
+ }
+ AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+ w2.train(r);
+ }
+ assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5);
+
+ // this improvement is really quite lenient
+ assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05);
+
+ // make sure that the copy didn't lose anything
+ assertEquals(auc1, w.getLearner().auc(), 0);
+ }
+
+ @Test
+ public void stepSize() {
+ assertEquals(500, AdaptiveLogisticRegression.stepSize(15000, 2));
+ assertEquals(2000, AdaptiveLogisticRegression.stepSize(15000, 2.6));
+ assertEquals(5000, AdaptiveLogisticRegression.stepSize(24000, 2.6));
+ assertEquals(10000, AdaptiveLogisticRegression.stepSize(15000, 3));
+ }
+
+ @Test
+ @ThreadLeakLingering(linger = 1000)
+ public void constantStep() {
+ AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+ lr.setInterval(5000);
+ assertEquals(20000, lr.nextStep(15000));
+ assertEquals(20000, lr.nextStep(15001));
+ assertEquals(20000, lr.nextStep(16500));
+ assertEquals(20000, lr.nextStep(19999));
+ lr.close();
+ }
+
+
+ @Test
+ @ThreadLeakLingering(linger = 1000)
+ public void growingStep() {
+ AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+ lr.setInterval(2000, 10000);
+
+ // start with minimum step size
+ for (int i = 2000; i < 20000; i+=2000) {
+ assertEquals(i + 2000, lr.nextStep(i));
+ }
+
+ // then level up a bit
+ for (int i = 20000; i < 50000; i += 5000) {
+ assertEquals(i + 5000, lr.nextStep(i));
+ }
+
+ // and more, but we top out with this step size
+ for (int i = 50000; i < 500000; i += 10000) {
+ assertEquals(i + 10000, lr.nextStep(i));
+ }
+ lr.close();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
new file mode 100644
index 0000000..6ee0ddf
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Test;
+
+public final class CsvRecordFactoryTest extends MahoutTestCase {
+
+ @Test
+ public void testAddToVector() {
+ RecordFactory csv = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t"));
+ csv.firstLine("z,x1,y,x2,x3,q");
+ csv.maxTargetValue(2);
+
+ Vector v = new DenseVector(2000);
+ int t = csv.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", v);
+ assertEquals(0, t);
+ // should have 9 values set
+ assertEquals(9.0, v.norm(0), 0);
+ // all should be = 1 except for the 3.1
+ assertEquals(3.1, v.maxValue(), 0);
+ v.set(v.maxValueIndex(), 0);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(8.0, v.norm(1), 0);
+ assertEquals(1.0, v.maxValue(), 0);
+
+ v.assign(0);
+ t = csv.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", v);
+ assertEquals(1, t);
+
+ // should have 9 values set
+ assertEquals(9.0, v.norm(0), 0);
+ // all should be = 1 except for the 3.1
+ assertEquals(5.3, v.maxValue(), 0);
+ v.set(v.maxValueIndex(), 0);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+ assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+
+ v.assign(0);
+ t = csv.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", v);
+ assertEquals(1, t);
+
+ // should have 9 values set
+ assertEquals(9.0, v.norm(0), 0);
+ // all should be = 1 except for the 3.1
+ assertEquals(5.3, v.maxValue(), 0);
+ v.set(v.maxValueIndex(), 0);
+ assertEquals(8.0, v.norm(0), 0);
+ assertEquals(10.339850002884626, v.norm(1), 1.0e-6);
+ assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6);
+ }
+
+ @Test
+ public void testDictionaryOrder() {
+ Dictionary dict = new Dictionary();
+
+ dict.intern("a");
+ dict.intern("d");
+ dict.intern("c");
+ dict.intern("b");
+ dict.intern("qrz");
+
+ assertEquals("[a, d, c, b, qrz]", dict.values().toString());
+
+ Dictionary dict2 = Dictionary.fromList(dict.values());
+ assertEquals("[a, d, c, b, qrz]", dict2.values().toString());
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
new file mode 100644
index 0000000..06a876e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Random;
+
+public final class GradientMachineTest extends OnlineBaseTest {
+
+ @Test
+ public void testGradientmachine() throws IOException {
+ Vector target = readStandardData();
+ GradientMachine grad = new GradientMachine(8,4,2).learningRate(0.1).regularization(0.01);
+ Random gen = RandomUtils.getRandom();
+ grad.initWeights(gen);
+ train(getInput(), target, grad);
+ // TODO not sure why the RNG change made this fail. Value is 0.5-1.0 no matter what seed is chosen?
+ test(getInput(), target, grad, 1.0, 1);
+ //test(getInput(), target, grad, 0.05, 1);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
new file mode 100644
index 0000000..2373b9d
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.junit.Test;
+
+public final class ModelSerializerTest extends MahoutTestCase {
+
+ private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException {
+ ByteArrayOutputStream buf = new ByteArrayOutputStream(1000);
+ DataOutputStream dos = new DataOutputStream(buf);
+ try {
+ PolymorphicWritable.write(dos, m);
+ } finally {
+ Closeables.close(dos, false);
+ }
+ return PolymorphicWritable.read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), clazz);
+ }
+
+ @Test
+ public void onlineAucRoundtrip() throws IOException {
+ RandomUtils.useTestSeed();
+ OnlineAuc auc1 = new GlobalOnlineAuc();
+ Random gen = RandomUtils.getRandom();
+ for (int i = 0; i < 10000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+ }
+ assertEquals(0.76, auc1.auc(), 0.01);
+
+ OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class);
+
+ assertEquals(auc1.auc(), auc3.auc(), 0);
+
+ for (int i = 0; i < 1000; i++) {
+ auc1.addSample(0, gen.nextGaussian());
+ auc1.addSample(1, gen.nextGaussian() + 1);
+
+ auc3.addSample(0, gen.nextGaussian());
+ auc3.addSample(1, gen.nextGaussian() + 1);
+ }
+
+ assertEquals(auc1.auc(), auc3.auc(), 0.01);
+ }
+
+ @Test
+ public void onlineLogisticRegressionRoundTrip() throws IOException {
+ OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1());
+ train(olr, 100);
+ OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class);
+ assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+
+ train(olr, 100);
+ train(olr3, 100);
+
+ assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6);
+ olr.close();
+ olr3.close();
+ }
+
+ @Test
+ public void crossFoldLearnerRoundTrip() throws IOException {
+ CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1());
+ train(learner, 100);
+ CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, learner.auc(), 1.0e-6);
+ assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+ train(learner, 100);
+ train(learner, 100);
+ train(olr3, 100);
+
+ assertEquals(learner.auc(), learner.auc(), 0.02);
+ assertEquals(learner.auc(), olr3.auc(), 0.02);
+ double auc2 = learner.auc();
+ assertTrue(auc2 > auc1);
+ learner.close();
+ olr3.close();
+ }
+
+ @ThreadLeakLingering(linger = 1000)
+ @Test
+ public void adaptiveLogisticRegressionRoundTrip() throws IOException {
+ AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1());
+ learner.setInterval(200);
+ train(learner, 400);
+ AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class);
+ double auc1 = learner.auc();
+ assertTrue(auc1 > 0.85);
+ assertEquals(auc1, learner.auc(), 1.0e-6);
+ assertEquals(auc1, olr3.auc(), 1.0e-6);
+
+ train(learner, 1000);
+ train(learner, 1000);
+ train(olr3, 1000);
+
+ assertEquals(learner.auc(), learner.auc(), 0.005);
+ assertEquals(learner.auc(), olr3.auc(), 0.005);
+ double auc2 = learner.auc();
+ assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1);
+ learner.close();
+ olr3.close();
+ }
+
+ private static void train(OnlineLearner olr, int n) {
+ Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5});
+ Random gen = RandomUtils.getRandom();
+ for (int i = 0; i < n; i++) {
+ Vector x = randomVector(gen, 5);
+
+ int target = gen.nextDouble() < beta.dot(x) ? 1 : 0;
+ olr.train(target, x);
+ }
+ }
+
+ private static Vector randomVector(final Random gen, int n) {
+ Vector x = new DenseVector(n);
+ x.assign(new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return gen.nextGaussian();
+ }
+ });
+ return x;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
new file mode 100644
index 0000000..e0a252c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+public abstract class OnlineBaseTest extends MahoutTestCase {
+
+ private Matrix input;
+
+ Matrix getInput() {
+ return input;
+ }
+
+ Vector readStandardData() throws IOException {
+ // 60 test samples. First column is constant. Second and third are normally distributed from
+ // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a
+ // target variable of 0, the last 30 a target of 1. The remaining columns are are random noise.
+ input = readCsv("sgd.csv");
+
+ // regenerate the target variable
+ Vector target = new DenseVector(60);
+ target.assign(0);
+ target.viewPart(30, 30).assign(1);
+ return target;
+ }
+
+ static void train(Matrix input, Vector target, OnlineLearner lr) {
+ RandomUtils.useTestSeed();
+ Random gen = RandomUtils.getRandom();
+
+ // train on samples in random order (but only one pass)
+ for (int row : permute(gen, 60)) {
+ lr.train((int) target.get(row), input.viewRow(row));
+ }
+ lr.close();
+ }
+
+ static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
+ double expected_mean_error, double expected_absolute_error) {
+ // now test the accuracy
+ Matrix tmp = lr.classify(input);
+ // mean(abs(tmp - target))
+ double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;
+
+ // max(abs(tmp - target)
+ double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
+
+ System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+ assertEquals(0, meanAbsoluteError , expected_mean_error);
+ assertEquals(0, maxAbsoluteError, expected_absolute_error);
+
+ // convenience methods should give the same results
+ Vector v = lr.classifyScalar(input);
+ assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5);
+ v = lr.classifyFull(input).viewColumn(1);
+ assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4);
+ }
+
+ /**
+ * Permute the integers from 0 ... max-1
+ *
+ * @param gen The random number generator to use.
+ * @param max The number of integers to permute
+ * @return An array of jumbled integer values
+ */
+ static int[] permute(Random gen, int max) {
+ int[] permutation = new int[max];
+ permutation[0] = 0;
+ for (int i = 1; i < max; i++) {
+ int n = gen.nextInt(i + 1);
+ if (n == i) {
+ permutation[i] = i;
+ } else {
+ permutation[i] = permutation[n];
+ permutation[n] = i;
+ }
+ }
+ return permutation;
+ }
+
+
+ /**
+ * Reads a file containing CSV data. This isn't implemented quite the way you might like for a
+ * real program, but does the job for reading test data. Most notably, it will only read numbers,
+ * not quoted strings.
+ *
+ * @param resourceName Where to get the data.
+ * @return A matrix of the results.
+ * @throws IOException If there is an error reading the data
+ */
+ static Matrix readCsv(String resourceName) throws IOException {
+ Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \""));
+
+ Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
+ List<String> data = CharStreams.readLines(isr);
+ String first = data.get(0);
+ data = data.subList(1, data.size());
+
+ List<String> values = Lists.newArrayList(onCommas.split(first));
+ Matrix r = new DenseMatrix(data.size(), values.size());
+
+ int column = 0;
+ Map<String, Integer> labels = Maps.newHashMap();
+ for (String value : values) {
+ labels.put(value, column);
+ column++;
+ }
+ r.setColumnLabelBindings(labels);
+
+ int row = 0;
+ for (String line : data) {
+ column = 0;
+ values = Lists.newArrayList(onCommas.split(line));
+ for (String value : values) {
+ r.set(row, column, Double.parseDouble(value));
+ column++;
+ }
+ row++;
+ }
+
+ return r;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
new file mode 100644
index 0000000..44b7525
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.io.Resources;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Assert;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+
+public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
+
+ private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class);
+
+ /**
+ * The CrossFoldLearner is probably the best learner to use for new applications.
+ *
+ * @throws IOException If test resources aren't readable.
+ */
+ @Test
+ public void crossValidation() throws IOException {
+ Vector target = readStandardData();
+
+ CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1())
+ .lambda(1 * 1.0e-3)
+ .learningRate(50);
+
+
+ train(getInput(), target, lr);
+
+ System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
+ test(getInput(), target, lr, 0.05, 0.3);
+
+ }
+
+ @Test
+ public void crossValidatedAuc() throws IOException {
+ RandomUtils.useTestSeed();
+ Random gen = RandomUtils.getRandom();
+
+ Matrix data = readCsv("cancer.csv");
+ CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1())
+ .stepOffset(10)
+ .decayExponent(0.7)
+ .lambda(1 * 1.0e-3)
+ .learningRate(5);
+ int k = 0;
+ int[] ordering = permute(gen, data.numRows());
+ for (int epoch = 0; epoch < 100; epoch++) {
+ for (int row : ordering) {
+ lr.train(row, (int) data.get(row, 9), data.viewRow(row));
+ System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc());
+ }
+ assertEquals(1, lr.auc(), 0.2);
+ }
+ assertEquals(1, lr.auc(), 0.1);
+ }
+
+ /**
+ * Verifies that a classifier with known coefficients does the right thing.
+ */
+ @Test
+ public void testClassify() {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1));
+ // set up some internal coefficients as if we had learned them
+ lr.setBeta(0, 0, -1);
+ lr.setBeta(1, 0, -2);
+
+ // zero vector gives no information. All classes are equal.
+ Vector v = lr.classify(new DenseVector(new double[]{0, 0}));
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+
+ v = lr.classifyFull(new DenseVector(new double[]{0, 0}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-8);
+
+ // weights for second vector component are still zero so all classifications are equally likely
+ v = lr.classify(new DenseVector(new double[]{0, 1}));
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+
+ v = lr.classifyFull(new DenseVector(new double[]{0, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / 3.0, v.get(0), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(1), 1.0e-3);
+ assertEquals(1 / 3.0, v.get(2), 1.0e-3);
+
+ // but the weights on the first component are non-zero
+ v = lr.classify(new DenseVector(new double[]{1, 0}));
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 0}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8);
+ assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8);
+
+ lr.setBeta(0, 1, 1);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3);
+ assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3);
+
+ lr.setBeta(1, 1, 3);
+
+ v = lr.classifyFull(new DenseVector(new double[]{1, 1}));
+ assertEquals(1.0, v.zSum(), 1.0e-8);
+ assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8);
+ assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8);
+ assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8);
+ }
+
+ @Test
+ public void iris() throws IOException {
+ // this test trains a 3-way classifier on the famous Iris dataset.
+ // a similar exercise can be accomplished in R using this code:
+ // library(nnet)
+ // correct = rep(0,100)
+ // for (j in 1:100) {
+ // i = order(runif(150))
+ // train = iris[i[1:100],]
+ // test = iris[i[101:150],]
+ // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train)
+ // correct[j] = mean(predict(m, newdata=test) == test$Species)
+ // }
+ // hist(correct)
+ //
+ // Note that depending on the training/test split, performance can be better or worse.
+ // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy
+ // of 100%
+ //
+ // This test uses a deterministic split that is neither outstandingly good nor bad
+
+
+ RandomUtils.useTestSeed();
+ Splitter onComma = Splitter.on(",");
+
+ // read the data
+ List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8);
+
+ // holds features
+ List<Vector> data = Lists.newArrayList();
+
+ // holds target variable
+ List<Integer> target = Lists.newArrayList();
+
+ // for decoding target values
+ Dictionary dict = new Dictionary();
+
+ // for permuting data later
+ List<Integer> order = Lists.newArrayList();
+
+ for (String line : raw.subList(1, raw.size())) {
+ // order gets a list of indexes
+ order.add(order.size());
+
+ // parse the predictor variables
+ Vector v = new DenseVector(5);
+ v.set(0, 1);
+ int i = 1;
+ Iterable<String> values = onComma.split(line);
+ for (String value : Iterables.limit(values, 4)) {
+ v.set(i++, Double.parseDouble(value));
+ }
+ data.add(v);
+
+ // and the target
+ target.add(dict.intern(Iterables.get(values, 4)));
+ }
+
+ // randomize the order ... original data has each species all together
+ // note that this randomization is deterministic
+ Random random = RandomUtils.getRandom();
+ Collections.shuffle(order, random);
+
+ // select training and test data
+ List<Integer> train = order.subList(0, 100);
+ List<Integer> test = order.subList(100, 150);
+ logger.warn("Training set = {}", train);
+ logger.warn("Test set = {}", test);
+
+ // now train many times and collect information on accuracy each time
+ int[] correct = new int[test.size() + 1];
+ for (int run = 0; run < 200; run++) {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1));
+ // 30 training passes should converge to > 95% accuracy nearly always but never to 100%
+ for (int pass = 0; pass < 30; pass++) {
+ Collections.shuffle(train, random);
+ for (int k : train) {
+ lr.train(target.get(k), data.get(k));
+ }
+ }
+
+ // check the accuracy on held out data
+ int x = 0;
+ int[] count = new int[3];
+ for (Integer k : test) {
+ int r = lr.classifyFull(data.get(k)).maxValueIndex();
+ count[r]++;
+ x += r == target.get(k) ? 1 : 0;
+ }
+ correct[x]++;
+ }
+
+ // verify we never saw worse than 95% correct,
+ for (int i = 0; i < Math.floor(0.95 * test.size()); i++) {
+ assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]);
+ }
+ // nor perfect
+ assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size()]);
+ }
+
+ @Test
+ public void testTrain() throws Exception {
+ Vector target = readStandardData();
+
+
+ // lambda here needs to be relatively small to avoid swamping the actual signal, but can be
+ // larger than usual because the data are dense. The learning rate doesn't matter too much
+ // for this example, but should generally be < 1
+ // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias
+ // --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+ .lambda(1 * 1.0e-3)
+ .learningRate(50);
+
+ train(getInput(), target, lr);
+ test(getInput(), target, lr, 0.05, 0.3);
+ }
+
+ /**
+ * Test for Serialization/DeSerialization
+ *
+ */
+ @Test
+ public void testSerializationAndDeSerialization() throws Exception {
+ OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1())
+ .lambda(1 * 1.0e-3)
+ .stepOffset(11)
+ .alpha(0.01)
+ .learningRate(50)
+ .decayExponent(-0.02);
+
+ lr.close();
+
+ byte[] output;
+
+ try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
+ DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) {
+ PolymorphicWritable.write(dataOutputStream, lr);
+ output = byteArrayOutputStream.toByteArray();
+ }
+
+ OnlineLogisticRegression read;
+
+ try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output);
+ DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) {
+ read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class);
+ }
+
+ //lambda
+ Assert.assertEquals((1.0e-3), read.getLambda(), 1.0e-7);
+
+ // Reflection to get private variables
+ //stepOffset
+ Field stepOffset = lr.getClass().getDeclaredField("stepOffset");
+ stepOffset.setAccessible(true);
+ int stepOffsetVal = (Integer) stepOffset.get(lr);
+ Assert.assertEquals(11, stepOffsetVal);
+
+ //decayFactor (alpha)
+ Field decayFactor = lr.getClass().getDeclaredField("decayFactor");
+ decayFactor.setAccessible(true);
+ double decayFactorVal = (Double) decayFactor.get(lr);
+ Assert.assertEquals(0.01, decayFactorVal, 1.0e-7);
+
+ //learning rate (mu0)
+ Field mu0 = lr.getClass().getDeclaredField("mu0");
+ mu0.setAccessible(true);
+ double mu0Val = (Double) mu0.get(lr);
+ Assert.assertEquals(50, mu0Val, 1.0e-7);
+
+ //forgettingExponent (decayExponent)
+ Field forgettingExponent = lr.getClass().getDeclaredField("forgettingExponent");
+ forgettingExponent.setAccessible(true);
+ double forgettingExponentVal = (Double) forgettingExponent.get(lr);
+ Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
new file mode 100644
index 0000000..df97d38
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public final class PassiveAggressiveTest extends OnlineBaseTest {
+
+ @Test
+ public void testPassiveAggressive() throws IOException {
+ Vector target = readStandardData();
+ PassiveAggressive pa = new PassiveAggressive(2,8).learningRate(0.1);
+ train(getInput(), target, pa);
+ test(getInput(), target, pa, 0.11, 0.31);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
new file mode 100644
index 0000000..62e10c6
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
@@ -0,0 +1,152 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import java.io.IOException;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.stats.Sampler;
+
+public final class ClusteringTestUtils {
+
+ private ClusteringTestUtils() {
+ }
+
+ public static void writePointsToFile(Iterable<VectorWritable> points,
+ Path path,
+ FileSystem fs,
+ Configuration conf) throws IOException {
+ writePointsToFile(points, false, path, fs, conf);
+ }
+
+ public static void writePointsToFile(Iterable<VectorWritable> points,
+ boolean intWritable,
+ Path path,
+ FileSystem fs,
+ Configuration conf) throws IOException {
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs,
+ conf,
+ path,
+ intWritable ? IntWritable.class : LongWritable.class,
+ VectorWritable.class);
+ try {
+ int recNum = 0;
+ for (VectorWritable point : points) {
+ writer.append(intWritable ? new IntWritable(recNum++) : new LongWritable(recNum++), point);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ public static Matrix sampledCorpus(Matrix matrix, Random random,
+ int numDocs, int numSamples, int numTopicsPerDoc) {
+ Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols());
+ LDASampler modelSampler = new LDASampler(matrix, random);
+ Vector topicVector = new DenseVector(matrix.numRows());
+ for (int i = 0; i < numTopicsPerDoc; i++) {
+ int topic = random.nextInt(topicVector.size());
+ topicVector.set(topic, topicVector.get(topic) + 1);
+ }
+ for (int docId = 0; docId < numDocs; docId++) {
+ for (int sample : modelSampler.sample(topicVector, numSamples)) {
+ corpus.set(docId, sample, corpus.get(docId, sample) + 1);
+ }
+ }
+ return corpus;
+ }
+
+ public static Matrix randomStructuredModel(int numTopics, int numTerms) {
+ return randomStructuredModel(numTopics, numTerms, new DoubleFunction() {
+ @Override public double apply(double d) {
+ return 1.0 / (1 + Math.abs(d));
+ }
+ });
+ }
+
+ public static Matrix randomStructuredModel(int numTopics, int numTerms, DoubleFunction decay) {
+ Matrix model = new DenseMatrix(numTopics, numTerms);
+ int width = numTerms / numTopics;
+ for (int topic = 0; topic < numTopics; topic++) {
+ int topicCentroid = width * (1+topic);
+ for (int i = 0; i < numTerms; i++) {
+ int distance = Math.abs(topicCentroid - i);
+ if (distance > numTerms / 2) {
+ distance = numTerms - distance;
+ }
+ double v = decay.apply(distance);
+ model.set(topic, i, v);
+ }
+ }
+ return model;
+ }
+
+ /**
+ * Takes in a {@link Matrix} of topic distributions (such as generated by {@link org.apache.mahout.clustering.lda.cvb.CVB0Driver} or
+ * {@link org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0}, and constructs
+ * a set of samplers over this distribution, which may be sampled from by providing a distribution
+ * over topics, and a number of samples desired
+ */
+ static class LDASampler {
+ private final Random random;
+ private final Sampler[] samplers;
+
+ LDASampler(Matrix model, Random random) {
+ this.random = random;
+ samplers = new Sampler[model.numRows()];
+ for (int i = 0; i < samplers.length; i++) {
+ samplers[i] = new Sampler(random, model.viewRow(i));
+ }
+ }
+
+ /**
+ *
+ * @param topicDistribution vector of p(topicId) for all topicId < model.numTopics()
+ * @param numSamples the number of times to sample (with replacement) from the model
+ * @return array of length numSamples, with each entry being a sample from the model. There
+ * may be repeats
+ */
+ public int[] sample(Vector topicDistribution, int numSamples) {
+ Preconditions.checkNotNull(topicDistribution);
+ Preconditions.checkArgument(numSamples > 0, "numSamples must be positive");
+ Preconditions.checkArgument(topicDistribution.size() == samplers.length,
+ "topicDistribution must have same cardinality as the sampling model");
+ int[] samples = new int[numSamples];
+ Sampler topicSampler = new Sampler(random, topicDistribution);
+ for (int i = 0; i < numSamples; i++) {
+ samples[i] = samplers[topicSampler.sample()].sample();
+ }
+ return samples;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
new file mode 100644
index 0000000..1cbfb02
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public final class TestClusterInterface extends MahoutTestCase {
+
+ private static final DistanceMeasure measure = new ManhattanDistanceMeasure();
+
+ @Test
+ public void testClusterAsFormatString() {
+ double[] d = { 1.1, 2.2, 3.3 };
+ Vector m = new DenseVector(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[1.1,2.2,3.3]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringSparse() {
+ double[] d = { 1.1, 0.0, 3.3 };
+ Vector m = new SequentialAccessSparseVector(3);
+ m.assign(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringWithBindings() {
+ double[] d = { 1.1, 2.2, 3.3 };
+ Vector m = new DenseVector(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String[] bindings = { "fee", null, "foo" };
+ String formatString = cluster.asFormatString(bindings);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"fee\":1.1},{\"1\":2.2},{\"foo\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+ @Test
+ public void testClusterAsFormatStringSparseWithBindings() {
+ double[] d = { 1.1, 0.0, 3.3 };
+ Vector m = new SequentialAccessSparseVector(3);
+ m.assign(d);
+ Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure);
+ String formatString = cluster.asFormatString(null);
+ assertTrue(formatString.contains("\"r\":[]"));
+ assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]"));
+ assertTrue(formatString.contains("\"n\":0"));
+ assertTrue(formatString.contains("\"identifier\":\"CL-123\""));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
new file mode 100644
index 0000000..43417fc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
@@ -0,0 +1,186 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.util.Collection;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.SquareRootFunction;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public final class TestGaussianAccumulators extends MahoutTestCase {
+
+ private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class);
+
+ private Collection<VectorWritable> sampleData = Lists.newArrayList();
+ private int sampleN;
+ private Vector sampleMean;
+ private Vector sampleStd;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ sampleData = Lists.newArrayList();
+ generateSamples();
+ sampleN = 0;
+ Vector sum = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ sum.assign(v.get(), Functions.PLUS);
+ sampleN++;
+ }
+ sampleMean = sum.divide(sampleN);
+
+ Vector sampleVar = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ Vector delta = v.get().minus(sampleMean);
+ sampleVar.assign(delta.times(delta), Functions.PLUS);
+ }
+ sampleVar = sampleVar.divide(sampleN - 1);
+ sampleStd = sampleVar.clone();
+ sampleStd.assign(new SquareRootFunction());
+ log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
+ sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1));
+ }
+
+ /**
+ * Generate random samples and add them to the sampleData
+ *
+ * @param num
+ * int number of samples to generate
+ * @param mx
+ * double x-value of the sample mean
+ * @param my
+ * double y-value of the sample mean
+ * @param sdx
+ * double x-value standard deviation of the samples
+ * @param sdy
+ * double y-value standard deviation of the samples
+ */
+ private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) {
+ log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy);
+ for (int i = 0; i < num; i++) {
+ sampleData.add(new VectorWritable(new DenseVector(new double[] { UncommonDistributions.rNorm(mx, sdx),
+ UncommonDistributions.rNorm(my, sdy) })));
+ }
+ }
+
+ private void generateSamples() {
+ generate2dSamples(50000, 1, 2, 3, 4);
+ }
+
+ @Test
+ public void testAccumulatorNoSamples() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+ assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+ }
+
+ @Test
+ public void testAccumulatorOneSample() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ Vector sample = new DenseVector(2);
+ accumulator0.observe(sample, 1.0);
+ accumulator1.observe(sample, 1.0);
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+ assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON);
+ }
+
+ @Test
+ public void testOLAccumulatorResults() {
+ GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]",
+ accumulator.getN(),
+ accumulator.getMean().get(0),
+ accumulator.getMean().get(1),
+ accumulator.getStd().get(0),
+ accumulator.getStd().get(1));
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), EPSILON);
+ }
+
+ @Test
+ public void testRSAccumulatorResults() {
+ GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]",
+ (int) accumulator.getN(),
+ accumulator.getMean().get(0),
+ accumulator.getMean().get(1),
+ accumulator.getStd().get(0),
+ accumulator.getStd().get(1));
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), 0.0001);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator0.observe(vw.get(), 0.5);
+ accumulator1.observe(vw.get(), 0.5);
+ }
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+ assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults2() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator0.observe(vw.get(), 1.5);
+ accumulator1.observe(vw.get(), 1.5);
+ }
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON);
+ assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01);
+ }
+}
[27/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java
new file mode 100644
index 0000000..7b7816c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java
@@ -0,0 +1,513 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configurable;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.DistributedRowMatrixWriter;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.Sampler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Thin wrapper around a {@link Matrix} of counts of occurrences of (topic, term) pairs. Dividing
+ * {code topicTermCount.viewRow(topic).get(term)} by the sum over the values for all terms in that
+ * row yields p(term | topic). Instead dividing it by all topic columns for that term yields
+ * p(topic | term).
+ *
+ * Multithreading is enabled for the {@code update(Matrix)} method: this method is async, and
+ * merely submits the matrix to a work queue. When all work has been submitted,
+ * {@code awaitTermination()} should be called, which will block until updates have been
+ * accumulated.
+ */
+public class TopicModel implements Configurable, Iterable<MatrixSlice> {
+
+ private static final Logger log = LoggerFactory.getLogger(TopicModel.class);
+
+ private final String[] dictionary;
+ private final Matrix topicTermCounts;
+ private final Vector topicSums;
+ private final int numTopics;
+ private final int numTerms;
+ private final double eta;
+ private final double alpha;
+
+ private Configuration conf;
+
+ private final Sampler sampler;
+ private final int numThreads;
+ private ThreadPoolExecutor threadPool;
+ private Updater[] updaters;
+
+ public int getNumTerms() {
+ return numTerms;
+ }
+
+ public int getNumTopics() {
+ return numTopics;
+ }
+
+ public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
+ double modelWeight) {
+ this(numTopics, numTerms, eta, alpha, null, dictionary, 1, modelWeight);
+ }
+
+ public TopicModel(Configuration conf, double eta, double alpha,
+ String[] dictionary, int numThreads, double modelWeight, Path... modelpath) throws IOException {
+ this(loadModel(conf, modelpath), eta, alpha, dictionary, numThreads, modelWeight);
+ }
+
+ public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary,
+ int numThreads, double modelWeight) {
+ this(new DenseMatrix(numTopics, numTerms), new DenseVector(numTopics), eta, alpha, dictionary,
+ numThreads, modelWeight);
+ }
+
+ public TopicModel(int numTopics, int numTerms, double eta, double alpha, Random random,
+ String[] dictionary, int numThreads, double modelWeight) {
+ this(randomMatrix(numTopics, numTerms, random), eta, alpha, dictionary, numThreads, modelWeight);
+ }
+
+ private TopicModel(Pair<Matrix, Vector> model, double eta, double alpha, String[] dict,
+ int numThreads, double modelWeight) {
+ this(model.getFirst(), model.getSecond(), eta, alpha, dict, numThreads, modelWeight);
+ }
+
+ public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
+ String[] dictionary, double modelWeight) {
+ this(topicTermCounts, topicSums, eta, alpha, dictionary, 1, modelWeight);
+ }
+
+ public TopicModel(Matrix topicTermCounts, double eta, double alpha, String[] dictionary,
+ int numThreads, double modelWeight) {
+ this(topicTermCounts, viewRowSums(topicTermCounts),
+ eta, alpha, dictionary, numThreads, modelWeight);
+ }
+
+ public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha,
+ String[] dictionary, int numThreads, double modelWeight) {
+ this.dictionary = dictionary;
+ this.topicTermCounts = topicTermCounts;
+ this.topicSums = topicSums;
+ this.numTopics = topicSums.size();
+ this.numTerms = topicTermCounts.numCols();
+ this.eta = eta;
+ this.alpha = alpha;
+ this.sampler = new Sampler(RandomUtils.getRandom());
+ this.numThreads = numThreads;
+ if (modelWeight != 1) {
+ topicSums.assign(Functions.mult(modelWeight));
+ for (int x = 0; x < numTopics; x++) {
+ topicTermCounts.viewRow(x).assign(Functions.mult(modelWeight));
+ }
+ }
+ initializeThreadPool();
+ }
+
+ private static Vector viewRowSums(Matrix m) {
+ Vector v = new DenseVector(m.numRows());
+ for (MatrixSlice slice : m) {
+ v.set(slice.index(), slice.vector().norm(1));
+ }
+ return v;
+ }
+
+ private synchronized void initializeThreadPool() {
+ if (threadPool != null) {
+ threadPool.shutdown();
+ try {
+ threadPool.awaitTermination(100, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ log.error("Could not terminate all threads for TopicModel in time.", e);
+ }
+ }
+ threadPool = new ThreadPoolExecutor(numThreads, numThreads, 0, TimeUnit.SECONDS,
+ new ArrayBlockingQueue<Runnable>(numThreads * 10));
+ threadPool.allowCoreThreadTimeOut(false);
+ updaters = new Updater[numThreads];
+ for (int i = 0; i < numThreads; i++) {
+ updaters[i] = new Updater();
+ threadPool.submit(updaters[i]);
+ }
+ }
+
+ Matrix topicTermCounts() {
+ return topicTermCounts;
+ }
+
+ @Override
+ public Iterator<MatrixSlice> iterator() {
+ return topicTermCounts.iterateAll();
+ }
+
+ public Vector topicSums() {
+ return topicSums;
+ }
+
+ private static Pair<Matrix,Vector> randomMatrix(int numTopics, int numTerms, Random random) {
+ Matrix topicTermCounts = new DenseMatrix(numTopics, numTerms);
+ Vector topicSums = new DenseVector(numTopics);
+ if (random != null) {
+ for (int x = 0; x < numTopics; x++) {
+ for (int term = 0; term < numTerms; term++) {
+ topicTermCounts.viewRow(x).set(term, random.nextDouble());
+ }
+ }
+ }
+ for (int x = 0; x < numTopics; x++) {
+ topicSums.set(x, random == null ? 1.0 : topicTermCounts.viewRow(x).norm(1));
+ }
+ return Pair.of(topicTermCounts, topicSums);
+ }
+
+ public static Pair<Matrix, Vector> loadModel(Configuration conf, Path... modelPaths)
+ throws IOException {
+ int numTopics = -1;
+ int numTerms = -1;
+ List<Pair<Integer, Vector>> rows = Lists.newArrayList();
+ for (Path modelPath : modelPaths) {
+ for (Pair<IntWritable, VectorWritable> row
+ : new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) {
+ rows.add(Pair.of(row.getFirst().get(), row.getSecond().get()));
+ numTopics = Math.max(numTopics, row.getFirst().get());
+ if (numTerms < 0) {
+ numTerms = row.getSecond().get().size();
+ }
+ }
+ }
+ if (rows.isEmpty()) {
+ throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it");
+ }
+ numTopics++;
+ Matrix model = new DenseMatrix(numTopics, numTerms);
+ Vector topicSums = new DenseVector(numTopics);
+ for (Pair<Integer, Vector> pair : rows) {
+ model.viewRow(pair.getFirst()).assign(pair.getSecond());
+ topicSums.set(pair.getFirst(), pair.getSecond().norm(1));
+ }
+ return Pair.of(model, topicSums);
+ }
+
+ // NOTE: this is purely for debug purposes. It is not performant to "toString()" a real model
+ @Override
+ public String toString() {
+ StringBuilder buf = new StringBuilder();
+ for (int x = 0; x < numTopics; x++) {
+ String v = dictionary != null
+ ? vectorToSortedString(topicTermCounts.viewRow(x).normalize(1), dictionary)
+ : topicTermCounts.viewRow(x).asFormatString();
+ buf.append(v).append('\n');
+ }
+ return buf.toString();
+ }
+
+ public int sampleTerm(Vector topicDistribution) {
+ return sampler.sample(topicTermCounts.viewRow(sampler.sample(topicDistribution)));
+ }
+
+ public int sampleTerm(int topic) {
+ return sampler.sample(topicTermCounts.viewRow(topic));
+ }
+
+ public synchronized void reset() {
+ for (int x = 0; x < numTopics; x++) {
+ topicTermCounts.assignRow(x, new SequentialAccessSparseVector(numTerms));
+ }
+ topicSums.assign(1.0);
+ if (threadPool.isTerminated()) {
+ initializeThreadPool();
+ }
+ }
+
+ public synchronized void stop() {
+ for (Updater updater : updaters) {
+ updater.shutdown();
+ }
+ threadPool.shutdown();
+ try {
+ if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
+ log.warn("Threadpool timed out on await termination - jobs still running!");
+ }
+ } catch (InterruptedException e) {
+ log.error("Interrupted shutting down!", e);
+ }
+ }
+
+ public void renormalize() {
+ for (int x = 0; x < numTopics; x++) {
+ topicTermCounts.assignRow(x, topicTermCounts.viewRow(x).normalize(1));
+ topicSums.assign(1.0);
+ }
+ }
+
+ public void trainDocTopicModel(Vector original, Vector topics, Matrix docTopicModel) {
+ // first calculate p(topic|term,document) for all terms in original, and all topics,
+ // using p(term|topic) and p(topic|doc)
+ pTopicGivenTerm(original, topics, docTopicModel);
+ normalizeByTopic(docTopicModel);
+ // now multiply, term-by-term, by the document, to get the weighted distribution of
+ // term-topic pairs from this document.
+ for (Element e : original.nonZeroes()) {
+ for (int x = 0; x < numTopics; x++) {
+ Vector docTopicModelRow = docTopicModel.viewRow(x);
+ docTopicModelRow.setQuick(e.index(), docTopicModelRow.getQuick(e.index()) * e.get());
+ }
+ }
+ // now recalculate \(p(topic|doc)\) by summing contributions from all of pTopicGivenTerm
+ topics.assign(0.0);
+ for (int x = 0; x < numTopics; x++) {
+ topics.set(x, docTopicModel.viewRow(x).norm(1));
+ }
+ // now renormalize so that \(sum_x(p(x|doc))\) = 1
+ topics.assign(Functions.mult(1 / topics.norm(1)));
+ }
+
+ public Vector infer(Vector original, Vector docTopics) {
+ Vector pTerm = original.like();
+ for (Element e : original.nonZeroes()) {
+ int term = e.index();
+ // p(a) = sum_x (p(a|x) * p(x|i))
+ double pA = 0;
+ for (int x = 0; x < numTopics; x++) {
+ pA += (topicTermCounts.viewRow(x).get(term) / topicSums.get(x)) * docTopics.get(x);
+ }
+ pTerm.set(term, pA);
+ }
+ return pTerm;
+ }
+
+ public void update(Matrix docTopicCounts) {
+ for (int x = 0; x < numTopics; x++) {
+ updaters[x % updaters.length].update(x, docTopicCounts.viewRow(x));
+ }
+ }
+
+ public void updateTopic(int topic, Vector docTopicCounts) {
+ topicTermCounts.viewRow(topic).assign(docTopicCounts, Functions.PLUS);
+ topicSums.set(topic, topicSums.get(topic) + docTopicCounts.norm(1));
+ }
+
+ public void update(int termId, Vector topicCounts) {
+ for (int x = 0; x < numTopics; x++) {
+ Vector v = topicTermCounts.viewRow(x);
+ v.set(termId, v.get(termId) + topicCounts.get(x));
+ }
+ topicSums.assign(topicCounts, Functions.PLUS);
+ }
+
+ public void persist(Path outputDir, boolean overwrite) throws IOException {
+ FileSystem fs = outputDir.getFileSystem(conf);
+ if (overwrite) {
+ fs.delete(outputDir, true); // CHECK second arg
+ }
+ DistributedRowMatrixWriter.write(outputDir, conf, topicTermCounts);
+ }
+
+ /**
+ * Computes {@code \(p(topic x | term a, document i)\)} distributions given input document {@code i}.
+ * {@code \(pTGT[x][a]\)} is the (un-normalized) {@code \(p(x|a,i)\)}, or if docTopics is {@code null},
+ * {@code \(p(a|x)\)} (also un-normalized).
+ *
+ * @param document doc-term vector encoding {@code \(w(term a|document i)\)}.
+ * @param docTopics {@code docTopics[x]} is the overall weight of topic {@code x} in given
+ * document. If {@code null}, a topic weight of {@code 1.0} is used for all topics.
+ * @param termTopicDist storage for output {@code \(p(x|a,i)\)} distributions.
+ */
+ private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopicDist) {
+ // for each topic x
+ for (int x = 0; x < numTopics; x++) {
+ // get p(topic x | document i), or 1.0 if docTopics is null
+ double topicWeight = docTopics == null ? 1.0 : docTopics.get(x);
+ // get w(term a | topic x)
+ Vector topicTermRow = topicTermCounts.viewRow(x);
+ // get \sum_a w(term a | topic x)
+ double topicSum = topicSums.get(x);
+ // get p(topic x | term a) distribution to update
+ Vector termTopicRow = termTopicDist.viewRow(x);
+
+ // for each term a in document i with non-zero weight
+ for (Element e : document.nonZeroes()) {
+ int termIndex = e.index();
+
+ // calc un-normalized p(topic x | term a, document i)
+ double termTopicLikelihood = (topicTermRow.get(termIndex) + eta) * (topicWeight + alpha)
+ / (topicSum + eta * numTerms);
+ termTopicRow.set(termIndex, termTopicLikelihood);
+ }
+ }
+ }
+
+ /**
+ * \(sum_x sum_a (c_ai * log(p(x|i) * p(a|x)))\)
+ */
+ public double perplexity(Vector document, Vector docTopics) {
+ double perplexity = 0;
+ double norm = docTopics.norm(1) + (docTopics.size() * alpha);
+ for (Element e : document.nonZeroes()) {
+ int term = e.index();
+ double prob = 0;
+ for (int x = 0; x < numTopics; x++) {
+ double d = (docTopics.get(x) + alpha) / norm;
+ double p = d * (topicTermCounts.viewRow(x).get(term) + eta)
+ / (topicSums.get(x) + eta * numTerms);
+ prob += p;
+ }
+ perplexity += e.get() * Math.log(prob);
+ }
+ return -perplexity;
+ }
+
+ private void normalizeByTopic(Matrix perTopicSparseDistributions) {
+ // then make sure that each of these is properly normalized by topic: sum_x(p(x|t,d)) = 1
+ for (Element e : perTopicSparseDistributions.viewRow(0).nonZeroes()) {
+ int a = e.index();
+ double sum = 0;
+ for (int x = 0; x < numTopics; x++) {
+ sum += perTopicSparseDistributions.viewRow(x).get(a);
+ }
+ for (int x = 0; x < numTopics; x++) {
+ perTopicSparseDistributions.viewRow(x).set(a,
+ perTopicSparseDistributions.viewRow(x).get(a) / sum);
+ }
+ }
+ }
+
+ public static String vectorToSortedString(Vector vector, String[] dictionary) {
+ List<Pair<String,Double>> vectorValues = Lists.newArrayListWithCapacity(vector.getNumNondefaultElements());
+ for (Element e : vector.nonZeroes()) {
+ vectorValues.add(Pair.of(dictionary != null ? dictionary[e.index()] : String.valueOf(e.index()),
+ e.get()));
+ }
+ Collections.sort(vectorValues, new Comparator<Pair<String, Double>>() {
+ @Override public int compare(Pair<String, Double> x, Pair<String, Double> y) {
+ return y.getSecond().compareTo(x.getSecond());
+ }
+ });
+ Iterator<Pair<String,Double>> listIt = vectorValues.iterator();
+ StringBuilder bldr = new StringBuilder(2048);
+ bldr.append('{');
+ int i = 0;
+ while (listIt.hasNext() && i < 25) {
+ i++;
+ Pair<String,Double> p = listIt.next();
+ bldr.append(p.getFirst());
+ bldr.append(':');
+ bldr.append(p.getSecond());
+ bldr.append(',');
+ }
+ if (bldr.length() > 1) {
+ bldr.setCharAt(bldr.length() - 1, '}');
+ }
+ return bldr.toString();
+ }
+
+ @Override
+ public void setConf(Configuration configuration) {
+ this.conf = configuration;
+ }
+
+ @Override
+ public Configuration getConf() {
+ return conf;
+ }
+
+ private final class Updater implements Runnable {
+ private final ArrayBlockingQueue<Pair<Integer, Vector>> queue =
+ new ArrayBlockingQueue<>(100);
+ private boolean shutdown = false;
+ private boolean shutdownComplete = false;
+
+ public void shutdown() {
+ try {
+ synchronized (this) {
+ while (!shutdownComplete) {
+ shutdown = true;
+ wait(10000L); // Arbitrarily, wait 10 seconds rather than forever for this
+ }
+ }
+ } catch (InterruptedException e) {
+ log.warn("Interrupted waiting to shutdown() : ", e);
+ }
+ }
+
+ public boolean update(int topic, Vector v) {
+ if (shutdown) { // maybe don't do this?
+ throw new IllegalStateException("In SHUTDOWN state: cannot submit tasks");
+ }
+ while (true) { // keep trying if interrupted
+ try {
+ // start async operation by submitting to the queue
+ queue.put(Pair.of(topic, v));
+ // return once you got access to the queue
+ return true;
+ } catch (InterruptedException e) {
+ log.warn("Interrupted trying to queue update:", e);
+ }
+ }
+ }
+
+ @Override
+ public void run() {
+ while (!shutdown) {
+ try {
+ Pair<Integer, Vector> pair = queue.poll(1, TimeUnit.SECONDS);
+ if (pair != null) {
+ updateTopic(pair.getFirst(), pair.getSecond());
+ }
+ } catch (InterruptedException e) {
+ log.warn("Interrupted waiting to poll for update", e);
+ }
+ }
+ // in shutdown mode, finish remaining tasks!
+ for (Pair<Integer, Vector> pair : queue) {
+ updateTopic(pair.getFirst(), pair.getSecond());
+ }
+ synchronized (this) {
+ shutdownComplete = true;
+ notifyAll();
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/package-info.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/package-info.java b/mr/src/main/java/org/apache/mahout/clustering/package-info.java
new file mode 100644
index 0000000..9926b91
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/package-info.java
@@ -0,0 +1,13 @@
+/**
+ * <p></p>This package provides several clustering algorithm implementations. Clustering usually groups a set of
+ * objects into groups of similar items. The definition of similarity usually is up to you - for text documents,
+ * cosine-distance/-similarity is recommended. Mahout also features other types of distance measure like
+ * Euclidean distance.</p>
+ *
+ * <p></p>Input of each clustering algorithm is a set of vectors representing your items. For texts in general these are
+ * <a href="http://en.wikipedia.org/wiki/TFIDF">TFIDF</a> or
+ * <a href="http://en.wikipedia.org/wiki/Bag_of_words">Bag of words</a> representations of the documents.</p>
+ *
+ * <p>Output of each clustering algorithm is either a hard or soft assignment of items to clusters.</p>
+ */
+package org.apache.mahout.clustering;
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java
new file mode 100644
index 0000000..aa12b9e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+
+public final class AffinityMatrixInputJob {
+
+ private AffinityMatrixInputJob() {
+ }
+
+ /**
+ * Initializes and executes the job of reading the documents containing
+ * the data of the affinity matrix in (x_i, x_j, value) format.
+ */
+ public static void runJob(Path input, Path output, int rows, int cols)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Configuration conf = new Configuration();
+ HadoopUtil.delete(conf, output);
+
+ conf.setInt(Keys.AFFINITY_DIMENSIONS, rows);
+ Job job = new Job(conf, "AffinityMatrixInputJob: " + input + " -> M/R -> " + output);
+
+ job.setMapOutputKeyClass(IntWritable.class);
+ job.setMapOutputValueClass(DistributedRowMatrix.MatrixEntryWritable.class);
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(AffinityMatrixInputMapper.class);
+ job.setReducerClass(AffinityMatrixInputReducer.class);
+
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setJarByClass(AffinityMatrixInputJob.class);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ /**
+ * A transparent wrapper for the above method which handles the tedious tasks
+ * of setting and retrieving system Paths. Hands back a fully-populated
+ * and initialized DistributedRowMatrix.
+ */
+ public static DistributedRowMatrix runJob(Path input, Path output, int dimensions)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Path seqFiles = new Path(output, "seqfiles-" + (System.nanoTime() & 0xFF));
+ runJob(input, seqFiles, dimensions, dimensions);
+ DistributedRowMatrix a = new DistributedRowMatrix(seqFiles,
+ new Path(seqFiles, "seqtmp-" + (System.nanoTime() & 0xFF)),
+ dimensions, dimensions);
+ a.setConf(new Configuration());
+ return a;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java
new file mode 100644
index 0000000..30d2404
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>Handles reading the files representing the affinity matrix. Since the affinity
+ * matrix is representative of a graph, each line in all the files should
+ * take the form:</p>
+ *
+ * {@code i,j,value}
+ *
+ * <p>where {@code i} and {@code j} are the {@code i}th and
+ * {@code j} data points in the entire set, and {@code value}
+ * represents some measurement of their relative absolute magnitudes. This
+ * is, simply, a method for representing a graph textually.
+ */
+public class AffinityMatrixInputMapper
+ extends Mapper<LongWritable, Text, IntWritable, DistributedRowMatrix.MatrixEntryWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(AffinityMatrixInputMapper.class);
+
+ private static final Pattern COMMA_PATTERN = Pattern.compile(",");
+
+ @Override
+ protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
+
+ String[] elements = COMMA_PATTERN.split(value.toString());
+ log.debug("(DEBUG - MAP) Key[{}], Value[{}]", key.get(), value);
+
+ // enforce well-formed textual representation of the graph
+ if (elements.length != 3) {
+ throw new IOException("Expected input of length 3, received "
+ + elements.length + ". Please make sure you adhere to "
+ + "the structure of (i,j,value) for representing a graph in text. "
+ + "Input line was: '" + value + "'.");
+ }
+ if (elements[0].isEmpty() || elements[1].isEmpty() || elements[2].isEmpty()) {
+ throw new IOException("Found an element of 0 length. Please be sure you adhere to the structure of "
+ + "(i,j,value) for representing a graph in text.");
+ }
+
+ // parse the line of text into a DistributedRowMatrix entry,
+ // making the row (elements[0]) the key to the Reducer, and
+ // setting the column (elements[1]) in the entry itself
+ DistributedRowMatrix.MatrixEntryWritable toAdd = new DistributedRowMatrix.MatrixEntryWritable();
+ IntWritable row = new IntWritable(Integer.valueOf(elements[0]));
+ toAdd.setRow(-1); // already set as the Reducer's key
+ toAdd.setCol(Integer.valueOf(elements[1]));
+ toAdd.setVal(Double.valueOf(elements[2]));
+ context.write(row, toAdd);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java
new file mode 100644
index 0000000..d892969
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Tasked with taking each DistributedRowMatrix entry and collecting them
+ * into vectors corresponding to rows. The input and output keys are the same,
+ * corresponding to the row in the ensuing matrix. The matrix entries are
+ * entered into a vector according to the column to which they belong, and
+ * the vector is then given the key corresponding to its row.
+ */
+public class AffinityMatrixInputReducer
+ extends Reducer<IntWritable, DistributedRowMatrix.MatrixEntryWritable, IntWritable, VectorWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(AffinityMatrixInputReducer.class);
+
+ @Override
+ protected void reduce(IntWritable row, Iterable<DistributedRowMatrix.MatrixEntryWritable> values, Context context)
+ throws IOException, InterruptedException {
+ int size = context.getConfiguration().getInt(Keys.AFFINITY_DIMENSIONS, Integer.MAX_VALUE);
+ RandomAccessSparseVector out = new RandomAccessSparseVector(size, 100);
+
+ for (DistributedRowMatrix.MatrixEntryWritable element : values) {
+ out.setQuick(element.getCol(), element.getVal());
+ if (log.isDebugEnabled()) {
+ log.debug("(DEBUG - REDUCE) Row[{}], Column[{}], Value[{}]",
+ row.get(), element.getCol(), element.getVal());
+ }
+ }
+ SequentialAccessSparseVector output = new SequentialAccessSparseVector(out);
+ context.write(row, new VectorWritable(output));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java
new file mode 100644
index 0000000..593cc58
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * This class is a Writable implementation of the mahout.common.Pair
+ * generic class. Since the generic types would also themselves have to
+ * implement Writable, it made more sense to create a more specialized
+ * version of the class altogether.
+ *
+ * In essence, this can be treated as a single Vector Element.
+ */
+public class IntDoublePairWritable implements Writable {
+
+ private int key;
+ private double value;
+
+ public IntDoublePairWritable() {
+ }
+
+ public IntDoublePairWritable(int k, double v) {
+ this.key = k;
+ this.value = v;
+ }
+
+ public void setKey(int k) {
+ this.key = k;
+ }
+
+ public void setValue(double v) {
+ this.value = v;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.key = in.readInt();
+ this.value = in.readDouble();
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(key);
+ out.writeDouble(value);
+ }
+
+ public int getKey() {
+ return key;
+ }
+
+ public double getValue() {
+ return value;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java
new file mode 100644
index 0000000..268a365
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+public class Keys {
+
+ /**
+ * Sets the SequenceFile index for the diagonal matrix.
+ */
+ public static final int DIAGONAL_CACHE_INDEX = 1;
+
+ public static final String AFFINITY_DIMENSIONS = "org.apache.mahout.clustering.spectral.common.affinitydimensions";
+
+ private Keys() {}
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java
new file mode 100644
index 0000000..f245f99
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * Given a matrix, this job returns a vector whose i_th element is the
+ * sum of all the elements in the i_th row of the original matrix.
+ */
+public final class MatrixDiagonalizeJob {
+
+ private MatrixDiagonalizeJob() {
+ }
+
+ public static Vector runJob(Path affInput, int dimensions)
+ throws IOException, ClassNotFoundException, InterruptedException {
+
+ // set up all the job tasks
+ Configuration conf = new Configuration();
+ Path diagOutput = new Path(affInput.getParent(), "diagonal");
+ HadoopUtil.delete(conf, diagOutput);
+ conf.setInt(Keys.AFFINITY_DIMENSIONS, dimensions);
+ Job job = new Job(conf, "MatrixDiagonalizeJob");
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setMapOutputKeyClass(NullWritable.class);
+ job.setMapOutputValueClass(IntDoublePairWritable.class);
+ job.setOutputKeyClass(NullWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(MatrixDiagonalizeMapper.class);
+ job.setReducerClass(MatrixDiagonalizeReducer.class);
+
+ FileInputFormat.addInputPath(job, affInput);
+ FileOutputFormat.setOutputPath(job, diagOutput);
+
+ job.setJarByClass(MatrixDiagonalizeJob.class);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ // read the results back from the path
+ return VectorCache.load(conf, new Path(diagOutput, "part-r-00000"));
+ }
+
+ public static class MatrixDiagonalizeMapper
+ extends Mapper<IntWritable, VectorWritable, NullWritable, IntDoublePairWritable> {
+
+ @Override
+ protected void map(IntWritable key, VectorWritable row, Context context)
+ throws IOException, InterruptedException {
+ // store the sum
+ IntDoublePairWritable store = new IntDoublePairWritable(key.get(), row.get().zSum());
+ context.write(NullWritable.get(), store);
+ }
+ }
+
+ public static class MatrixDiagonalizeReducer
+ extends Reducer<NullWritable, IntDoublePairWritable, NullWritable, VectorWritable> {
+
+ @Override
+ protected void reduce(NullWritable key, Iterable<IntDoublePairWritable> values,
+ Context context) throws IOException, InterruptedException {
+ // create the return vector
+ Vector retval = new DenseVector(context.getConfiguration().getInt(Keys.AFFINITY_DIMENSIONS, Integer.MAX_VALUE));
+ // put everything in its correct spot
+ for (IntDoublePairWritable e : values) {
+ retval.setQuick(e.getKey(), e.getValue());
+ }
+ // write it out
+ context.write(key, new VectorWritable(retval));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/UnitVectorizerJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/UnitVectorizerJob.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/UnitVectorizerJob.java
new file mode 100644
index 0000000..56cb237
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/UnitVectorizerJob.java
@@ -0,0 +1,79 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * <p>Given a DistributedRowMatrix, this job normalizes each row to unit
+ * vector length. If the input is a matrix U, and the output is a matrix
+ * W, the job follows:</p>
+ *
+ * <p>{@code v_ij = u_ij / sqrt(sum_j(u_ij * u_ij))}</p>
+ */
+public final class UnitVectorizerJob {
+
+ private UnitVectorizerJob() {
+ }
+
+ public static void runJob(Path input, Path output)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ Configuration conf = new Configuration();
+ Job job = new Job(conf, "UnitVectorizerJob");
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(UnitVectorizerMapper.class);
+ job.setNumReduceTasks(0);
+
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setJarByClass(UnitVectorizerJob.class);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ public static class UnitVectorizerMapper
+ extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ @Override
+ protected void map(IntWritable row, VectorWritable vector, Context context)
+ throws IOException, InterruptedException {
+ context.write(row, new VectorWritable(vector.get().normalize(2)));
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorCache.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorCache.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorCache.java
new file mode 100644
index 0000000..60e0a2e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorCache.java
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Arrays;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * This class handles reading and writing vectors to the Hadoop
+ * distributed cache. Created as a result of Eigencuts' liberal use
+ * of such functionality, but available to any algorithm requiring it.
+ */
+public final class VectorCache {
+
+ private static final Logger log = LoggerFactory.getLogger(VectorCache.class);
+
+ private VectorCache() {
+ }
+
+ /**
+ * @param key SequenceFile key
+ * @param vector Vector to save, to be wrapped as VectorWritable
+ */
+ public static void save(Writable key,
+ Vector vector,
+ Path output,
+ Configuration conf,
+ boolean overwritePath,
+ boolean deleteOnExit) throws IOException {
+
+ FileSystem fs = FileSystem.get(output.toUri(), conf);
+ output = fs.makeQualified(output);
+ if (overwritePath) {
+ HadoopUtil.delete(conf, output);
+ }
+
+ // set the cache
+ DistributedCache.setCacheFiles(new URI[]{output.toUri()}, conf);
+
+ // set up the writer
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, output,
+ IntWritable.class, VectorWritable.class);
+ try {
+ writer.append(key, new VectorWritable(vector));
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ if (deleteOnExit) {
+ fs.deleteOnExit(output);
+ }
+ }
+
+ /**
+ * Calls the save() method, setting the cache to overwrite any previous
+ * Path and to delete the path after exiting
+ */
+ public static void save(Writable key, Vector vector, Path output, Configuration conf) throws IOException {
+ save(key, vector, output, conf, true, true);
+ }
+
+ /**
+ * Loads the vector from {@link DistributedCache}. Returns null if no vector exists.
+ */
+ public static Vector load(Configuration conf) throws IOException {
+ Path[] files = HadoopUtil.getCachedFiles(conf);
+
+ if (files.length != 1) {
+ throw new IOException("Cannot read Frequency list from Distributed Cache (" + files.length + ')');
+ }
+
+ if (log.isInfoEnabled()) {
+ log.info("Files are: {}", Arrays.toString(files));
+ }
+ return load(conf, files[0]);
+ }
+
+ /**
+ * Loads a Vector from the specified path. Returns null if no vector exists.
+ */
+ public static Vector load(Configuration conf, Path input) throws IOException {
+ log.info("Loading vector from: {}", input);
+ SequenceFileValueIterator<VectorWritable> iterator =
+ new SequenceFileValueIterator<>(input, true, conf);
+ try {
+ return iterator.next().get();
+ } finally {
+ Closeables.close(iterator, true);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java
new file mode 100644
index 0000000..c42ab70
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+
+/**
+ * <p>This class handles the three-way multiplication of the digonal matrix
+ * and the Markov transition matrix inherent in the Eigencuts algorithm.
+ * The equation takes the form:</p>
+ *
+ * {@code W = D^(1/2) * M * D^(1/2)}
+ *
+ * <p>Since the diagonal matrix D has only n non-zero elements, it is represented
+ * as a dense vector in this job, rather than a full n-by-n matrix. This job
+ * performs the multiplications and returns the new DRM.
+ */
+public final class VectorMatrixMultiplicationJob {
+
+ private VectorMatrixMultiplicationJob() {
+ }
+
+ /**
+ * Invokes the job.
+ * @param markovPath Path to the markov DRM's sequence files
+ */
+ public static DistributedRowMatrix runJob(Path markovPath, Vector diag, Path outputPath)
+ throws IOException, ClassNotFoundException, InterruptedException {
+
+ return runJob(markovPath, diag, outputPath, new Path(outputPath, "tmp"));
+ }
+
+ public static DistributedRowMatrix runJob(Path markovPath, Vector diag, Path outputPath, Path tmpPath)
+ throws IOException, ClassNotFoundException, InterruptedException {
+
+ // set up the serialization of the diagonal vector
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(markovPath.toUri(), conf);
+ markovPath = fs.makeQualified(markovPath);
+ outputPath = fs.makeQualified(outputPath);
+ Path vectorOutputPath = new Path(outputPath.getParent(), "vector");
+ VectorCache.save(new IntWritable(Keys.DIAGONAL_CACHE_INDEX), diag, vectorOutputPath, conf);
+
+ // set up the job itself
+ Job job = new Job(conf, "VectorMatrixMultiplication");
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(VectorMatrixMultiplicationMapper.class);
+ job.setNumReduceTasks(0);
+
+ FileInputFormat.addInputPath(job, markovPath);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ job.setJarByClass(VectorMatrixMultiplicationJob.class);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ // build the resulting DRM from the results
+ return new DistributedRowMatrix(outputPath, tmpPath,
+ diag.size(), diag.size());
+ }
+
+ public static class VectorMatrixMultiplicationMapper
+ extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ private Vector diagonal;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ // read in the diagonal vector from the distributed cache
+ super.setup(context);
+ Configuration config = context.getConfiguration();
+ diagonal = VectorCache.load(config);
+ if (diagonal == null) {
+ throw new IOException("No vector loaded from cache!");
+ }
+ if (!(diagonal instanceof DenseVector)) {
+ diagonal = new DenseVector(diagonal);
+ }
+ }
+
+ @Override
+ protected void map(IntWritable key, VectorWritable row, Context ctx)
+ throws IOException, InterruptedException {
+
+ for (Vector.Element e : row.get().all()) {
+ double dii = Functions.SQRT.apply(diagonal.get(key.get()));
+ double djj = Functions.SQRT.apply(diagonal.get(e.index()));
+ double mij = e.get();
+ e.set(dii * mij * djj);
+ }
+ ctx.write(key, row);
+ }
+
+ /**
+ * Performs the setup of the Mapper. Used by unit tests.
+ * @param diag
+ */
+ void setup(Vector diag) {
+ this.diagonal = diag;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java
new file mode 100644
index 0000000..0d70cac
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java
@@ -0,0 +1,101 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Represents a vertex within the affinity graph for Eigencuts.
+ */
+public class VertexWritable implements Writable {
+
+ /** the row */
+ private int i;
+
+ /** the column */
+ private int j;
+
+ /** the value at this vertex */
+ private double value;
+
+ /** an extra type delimeter, can probably be null */
+ private String type;
+
+ public VertexWritable() {
+ }
+
+ public VertexWritable(int i, int j, double v, String t) {
+ this.i = i;
+ this.j = j;
+ this.value = v;
+ this.type = t;
+ }
+
+ public int getRow() {
+ return i;
+ }
+
+ public void setRow(int i) {
+ this.i = i;
+ }
+
+ public int getCol() {
+ return j;
+ }
+
+ public void setCol(int j) {
+ this.j = j;
+ }
+
+ public double getValue() {
+ return value;
+ }
+
+ public void setValue(double v) {
+ this.value = v;
+ }
+
+ public String getType() {
+ return type;
+ }
+
+ public void setType(String t) {
+ this.type = t;
+ }
+
+ @Override
+ public void readFields(DataInput arg0) throws IOException {
+ this.i = arg0.readInt();
+ this.j = arg0.readInt();
+ this.value = arg0.readDouble();
+ this.type = arg0.readUTF();
+ }
+
+ @Override
+ public void write(DataOutput arg0) throws IOException {
+ arg0.writeInt(i);
+ arg0.writeInt(j);
+ arg0.writeDouble(value);
+ arg0.writeUTF(type);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java
new file mode 100644
index 0000000..5f9c1a6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java
@@ -0,0 +1,124 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral.kmeans;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.kmeans.Kluster;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+
+/**
+ * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, select k vectors and write them to the
+ * output file as a {@link org.apache.mahout.clustering.kmeans.Kluster} representing the initial centroid to use. The
+ * selection criterion is the rows with max value in that respective column
+ */
+public final class EigenSeedGenerator {
+
+ private static final Logger log = LoggerFactory.getLogger(EigenSeedGenerator.class);
+
+ public static final String K = "k";
+
+ private EigenSeedGenerator() {}
+
+ public static Path buildFromEigens(Configuration conf, Path input, Path output, int k, DistanceMeasure measure)
+ throws IOException {
+ // delete the output directory
+ FileSystem fs = FileSystem.get(output.toUri(), conf);
+ HadoopUtil.delete(conf, output);
+ Path outFile = new Path(output, "part-eigenSeed");
+ boolean newFile = fs.createNewFile(outFile);
+ if (newFile) {
+ Path inputPathPattern;
+
+ if (fs.getFileStatus(input).isDir()) {
+ inputPathPattern = new Path(input, "*");
+ } else {
+ inputPathPattern = input;
+ }
+
+ FileStatus[] inputFiles = fs.globStatus(inputPathPattern, PathFilters.logsCRCFilter());
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outFile, Text.class, ClusterWritable.class);
+ Map<Integer,Double> maxEigens = Maps.newHashMapWithExpectedSize(k); // store
+ // max
+ // value
+ // of
+ // each
+ // column
+ Map<Integer,Text> chosenTexts = Maps.newHashMapWithExpectedSize(k);
+ Map<Integer,ClusterWritable> chosenClusters = Maps.newHashMapWithExpectedSize(k);
+
+ for (FileStatus fileStatus : inputFiles) {
+ if (!fileStatus.isDir()) {
+ for (Pair<Writable,VectorWritable> record : new SequenceFileIterable<Writable,VectorWritable>(
+ fileStatus.getPath(), true, conf)) {
+ Writable key = record.getFirst();
+ VectorWritable value = record.getSecond();
+
+ for (Vector.Element e : value.get().nonZeroes()) {
+ int index = e.index();
+ double v = Math.abs(e.get());
+
+ if (!maxEigens.containsKey(index) || v > maxEigens.get(index)) {
+ maxEigens.put(index, v);
+ Text newText = new Text(key.toString());
+ chosenTexts.put(index, newText);
+ Kluster newCluster = new Kluster(value.get(), index, measure);
+ newCluster.observe(value.get(), 1);
+ ClusterWritable clusterWritable = new ClusterWritable();
+ clusterWritable.setValue(newCluster);
+ chosenClusters.put(index, clusterWritable);
+ }
+ }
+ }
+ }
+ }
+
+ try {
+ for (Integer key : maxEigens.keySet()) {
+ writer.append(chosenTexts.get(key), chosenClusters.get(key));
+ }
+ log.info("EigenSeedGenerator:: Wrote {} Klusters to {}", chosenTexts.size(), outFile);
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ return outFile;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java b/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java
new file mode 100644
index 0000000..427de91
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java
@@ -0,0 +1,243 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ * <p/>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p/>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral.kmeans;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.clustering.spectral.AffinityMatrixInputJob;
+import org.apache.mahout.clustering.spectral.MatrixDiagonalizeJob;
+import org.apache.mahout.clustering.spectral.UnitVectorizerJob;
+import org.apache.mahout.clustering.spectral.VectorMatrixMultiplicationJob;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.stochasticsvd.SSVDSolver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Performs spectral k-means clustering on the top k eigenvectors of the input affinity matrix.
+ */
+public class SpectralKMeansDriver extends AbstractJob {
+ private static final Logger log = LoggerFactory.getLogger(SpectralKMeansDriver.class);
+
+ public static final int REDUCERS = 10;
+ public static final int BLOCKHEIGHT = 30000;
+ public static final int OVERSAMPLING = 15;
+ public static final int POWERITERS = 0;
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new SpectralKMeansDriver(), args);
+ }
+
+ @Override
+ public int run(String[] arg0) throws Exception {
+
+ Configuration conf = getConf();
+ addInputOption();
+ addOutputOption();
+ addOption("dimensions", "d", "Square dimensions of affinity matrix", true);
+ addOption("clusters", "k", "Number of clusters and top eigenvectors", true);
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+ addOption(DefaultOptionCreator.convergenceOption().create());
+ addOption(DefaultOptionCreator.maxIterationsOption().create());
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addFlag("usessvd", "ssvd", "Uses SSVD as the eigensolver. Default is the Lanczos solver.");
+ addOption("reduceTasks", "t", "Number of reducers for SSVD", String.valueOf(REDUCERS));
+ addOption("outerProdBlockHeight", "oh", "Block height of outer products for SSVD", String.valueOf(BLOCKHEIGHT));
+ addOption("oversampling", "p", "Oversampling parameter for SSVD", String.valueOf(OVERSAMPLING));
+ addOption("powerIter", "q", "Additional power iterations for SSVD", String.valueOf(POWERITERS));
+
+ Map<String, List<String>> parsedArgs = parseArguments(arg0);
+ if (parsedArgs == null) {
+ return 0;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(conf, getTempPath());
+ HadoopUtil.delete(conf, getOutputPath());
+ }
+ int numDims = Integer.parseInt(getOption("dimensions"));
+ int clusters = Integer.parseInt(getOption("clusters"));
+ String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
+ double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
+ int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
+
+ Path tempdir = new Path(getOption("tempDir"));
+ int reducers = Integer.parseInt(getOption("reduceTasks"));
+ int blockheight = Integer.parseInt(getOption("outerProdBlockHeight"));
+ int oversampling = Integer.parseInt(getOption("oversampling"));
+ int poweriters = Integer.parseInt(getOption("powerIter"));
+ run(conf, input, output, numDims, clusters, measure, convergenceDelta, maxIterations, tempdir, reducers,
+ blockheight, oversampling, poweriters);
+
+ return 0;
+ }
+
+ public static void run(Configuration conf, Path input, Path output, int numDims, int clusters,
+ DistanceMeasure measure, double convergenceDelta, int maxIterations, Path tempDir)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ run(conf, input, output, numDims, clusters, measure, convergenceDelta, maxIterations, tempDir, REDUCERS,
+ BLOCKHEIGHT, OVERSAMPLING, POWERITERS);
+ }
+
+ /**
+ * Run the Spectral KMeans clustering on the supplied arguments
+ *
+ * @param conf
+ * the Configuration to be used
+ * @param input
+ * the Path to the input tuples directory
+ * @param output
+ * the Path to the output directory
+ * @param numDims
+ * the int number of dimensions of the affinity matrix
+ * @param clusters
+ * the int number of eigenvectors and thus clusters to produce
+ * @param measure
+ * the DistanceMeasure for the k-Means calculations
+ * @param convergenceDelta
+ * the double convergence delta for the k-Means calculations
+ * @param maxIterations
+ * the int maximum number of iterations for the k-Means calculations
+ * @param tempDir
+ * Temporary directory for intermediate calculations
+ * @param numReducers
+ * Number of reducers
+ * @param blockHeight
+ * @param oversampling
+ * @param poweriters
+ */
+ public static void run(Configuration conf, Path input, Path output, int numDims, int clusters,
+ DistanceMeasure measure, double convergenceDelta, int maxIterations, Path tempDir,
+ int numReducers, int blockHeight, int oversampling, int poweriters)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ HadoopUtil.delete(conf, tempDir);
+ Path outputCalc = new Path(tempDir, "calculations");
+ Path outputTmp = new Path(tempDir, "temporary");
+
+ // Take in the raw CSV text file and split it ourselves,
+ // creating our own SequenceFiles for the matrices to read later
+ // (similar to the style of syntheticcontrol.canopy.InputMapper)
+ Path affSeqFiles = new Path(outputCalc, "seqfile");
+ AffinityMatrixInputJob.runJob(input, affSeqFiles, numDims, numDims);
+
+ // Construct the affinity matrix using the newly-created sequence files
+ DistributedRowMatrix A = new DistributedRowMatrix(affSeqFiles, new Path(outputTmp, "afftmp"), numDims, numDims);
+
+ Configuration depConf = new Configuration(conf);
+ A.setConf(depConf);
+
+ // Construct the diagonal matrix D (represented as a vector)
+ Vector D = MatrixDiagonalizeJob.runJob(affSeqFiles, numDims);
+
+ // Calculate the normalized Laplacian of the form: L = D^(-0.5)AD^(-0.5)
+ DistributedRowMatrix L = VectorMatrixMultiplicationJob.runJob(affSeqFiles, D, new Path(outputCalc, "laplacian"),
+ new Path(outputCalc, outputCalc));
+ L.setConf(depConf);
+
+ Path data;
+
+ // SSVD requires an array of Paths to function. So we pass in an array of length one
+ Path[] LPath = new Path[1];
+ LPath[0] = L.getRowPath();
+
+ Path SSVDout = new Path(outputCalc, "SSVD");
+
+ SSVDSolver solveIt = new SSVDSolver(depConf, LPath, SSVDout, blockHeight, clusters, oversampling, numReducers);
+
+ solveIt.setComputeV(false);
+ solveIt.setComputeU(true);
+ solveIt.setOverwrite(true);
+ solveIt.setQ(poweriters);
+ // solveIt.setBroadcast(false);
+ solveIt.run();
+ data = new Path(solveIt.getUPath());
+
+ // Normalize the rows of Wt to unit length
+ // normalize is important because it reduces the occurrence of two unique clusters combining into one
+ Path unitVectors = new Path(outputCalc, "unitvectors");
+
+ UnitVectorizerJob.runJob(data, unitVectors);
+
+ DistributedRowMatrix Wt = new DistributedRowMatrix(unitVectors, new Path(unitVectors, "tmp"), clusters, numDims);
+ Wt.setConf(depConf);
+ data = Wt.getRowPath();
+
+ // Generate initial clusters using EigenSeedGenerator which picks rows as centroids if that row contains max
+ // eigen value in that column
+ Path initialclusters = EigenSeedGenerator.buildFromEigens(conf, data,
+ new Path(output, Cluster.INITIAL_CLUSTERS_DIR), clusters, measure);
+
+ // Run the KMeansDriver
+ Path answer = new Path(output, "kmeans_out");
+ KMeansDriver.run(conf, data, initialclusters, answer, convergenceDelta, maxIterations, true, 0.0, false);
+
+ // Restore name to id mapping and read through the cluster assignments
+ Path mappingPath = new Path(new Path(conf.get("hadoop.tmp.dir")), "generic_input_mapping");
+ List<String> mapping = new ArrayList<>();
+ FileSystem fs = FileSystem.get(mappingPath.toUri(), conf);
+ if (fs.exists(mappingPath)) {
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, mappingPath, conf);
+ Text mappingValue = new Text();
+ IntWritable mappingIndex = new IntWritable();
+ while (reader.next(mappingIndex, mappingValue)) {
+ String s = mappingValue.toString();
+ mapping.add(s);
+ }
+ HadoopUtil.delete(conf, mappingPath);
+ } else {
+ log.warn("generic input mapping file not found!");
+ }
+
+ Path clusteredPointsPath = new Path(answer, "clusteredPoints");
+ Path inputPath = new Path(clusteredPointsPath, "part-m-00000");
+ int id = 0;
+ for (Pair<IntWritable, WeightedVectorWritable> record :
+ new SequenceFileIterable<IntWritable, WeightedVectorWritable>(inputPath, conf)) {
+ if (!mapping.isEmpty()) {
+ log.info("{}: {}", mapping.get(id++), record.getFirst().get());
+ } else {
+ log.info("{}: {}", id++, record.getFirst().get());
+ }
+ }
+ }
+}
[21/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/MatrixUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/MatrixUtils.java b/mr/src/main/java/org/apache/mahout/math/MatrixUtils.java
new file mode 100644
index 0000000..f9ca52e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/MatrixUtils.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+import java.io.IOException;
+import java.util.List;
+
+public final class MatrixUtils {
+
+ private MatrixUtils() {
+ }
+
+ public static void write(Path outputDir, Configuration conf, VectorIterable matrix)
+ throws IOException {
+ FileSystem fs = outputDir.getFileSystem(conf);
+ fs.delete(outputDir, true);
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir,
+ IntWritable.class, VectorWritable.class);
+ IntWritable topic = new IntWritable();
+ VectorWritable vector = new VectorWritable();
+ for (MatrixSlice slice : matrix) {
+ topic.set(slice.index());
+ vector.set(slice.vector());
+ writer.append(topic, vector);
+ }
+ writer.close();
+ }
+
+ public static Matrix read(Configuration conf, Path... modelPaths) throws IOException {
+ int numRows = -1;
+ int numCols = -1;
+ boolean sparse = false;
+ List<Pair<Integer, Vector>> rows = Lists.newArrayList();
+ for (Path modelPath : modelPaths) {
+ for (Pair<IntWritable, VectorWritable> row
+ : new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) {
+ rows.add(Pair.of(row.getFirst().get(), row.getSecond().get()));
+ numRows = Math.max(numRows, row.getFirst().get());
+ sparse = !row.getSecond().get().isDense();
+ if (numCols < 0) {
+ numCols = row.getSecond().get().size();
+ }
+ }
+ }
+ if (rows.isEmpty()) {
+ throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it");
+ }
+ numRows++;
+ Vector[] arrayOfRows = new Vector[numRows];
+ for (Pair<Integer, Vector> pair : rows) {
+ arrayOfRows[pair.getFirst()] = pair.getSecond();
+ }
+ Matrix matrix;
+ if (sparse) {
+ matrix = new SparseRowMatrix(numRows, numCols, arrayOfRows);
+ } else {
+ matrix = new DenseMatrix(numRows, numCols);
+ for (int i = 0; i < numRows; i++) {
+ matrix.assignRow(i, arrayOfRows[i]);
+ }
+ }
+ return matrix;
+ }
+
+ public static OpenObjectIntHashMap<String> readDictionary(Configuration conf, Path... dictPath) {
+ OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<>();
+ for (Path dictionaryFile : dictPath) {
+ for (Pair<Writable, IntWritable> record
+ : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
+ dictionary.put(record.getFirst().toString(), record.getSecond().get());
+ }
+ }
+ return dictionary;
+ }
+
+ public static String[] invertDictionary(OpenObjectIntHashMap<String> termIdMap) {
+ int maxTermId = -1;
+ for (String term : termIdMap.keys()) {
+ maxTermId = Math.max(maxTermId, termIdMap.get(term));
+ }
+ maxTermId++;
+ String[] dictionary = new String[maxTermId];
+ for (String term : termIdMap.keys()) {
+ dictionary[termIdMap.get(term)] = term;
+ }
+ return dictionary;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java b/mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java
new file mode 100644
index 0000000..0c45c9a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/MultiLabelVectorWritable.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.apache.hadoop.io.Writable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Writable to handle serialization of a vector and a variable list of
+ * associated label indexes.
+ */
+public final class MultiLabelVectorWritable implements Writable {
+
+ private final VectorWritable vectorWritable = new VectorWritable();
+ private int[] labels;
+
+ public MultiLabelVectorWritable() {
+ }
+
+ public MultiLabelVectorWritable(Vector vector, int[] labels) {
+ this.vectorWritable.set(vector);
+ this.labels = labels;
+ }
+
+ public Vector getVector() {
+ return vectorWritable.get();
+ }
+
+ public void setVector(Vector vector) {
+ vectorWritable.set(vector);
+ }
+
+ public void setLabels(int[] labels) {
+ this.labels = labels;
+ }
+
+ public int[] getLabels() {
+ return labels;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ vectorWritable.readFields(in);
+ int labelSize = in.readInt();
+ labels = new int[labelSize];
+ for (int i = 0; i < labelSize; i++) {
+ labels[i] = in.readInt();
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ vectorWritable.write(out);
+ out.writeInt(labels.length);
+ for (int label : labels) {
+ out.writeInt(label);
+ }
+ }
+
+ public static MultiLabelVectorWritable read(DataInput in) throws IOException {
+ MultiLabelVectorWritable writable = new MultiLabelVectorWritable();
+ writable.readFields(in);
+ return writable;
+ }
+
+ public static void write(DataOutput out, SequentialAccessSparseVector ssv, int[] labels) throws IOException {
+ new MultiLabelVectorWritable(ssv, labels).write(out);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java b/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
new file mode 100644
index 0000000..1a6ff16
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
@@ -0,0 +1,385 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configurable;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapred.JobClient;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+
+/**
+ * DistributedRowMatrix is a FileSystem-backed VectorIterable in which the vectors live in a
+ * SequenceFile<WritableComparable,VectorWritable>, and distributed operations are executed as M/R passes on
+ * Hadoop. The usage is as follows: <p>
+ * <p>
+ * <pre>
+ * // the path must already contain an already created SequenceFile!
+ * DistributedRowMatrix m = new DistributedRowMatrix("path/to/vector/sequenceFile", "tmp/path", 10000000, 250000);
+ * m.setConf(new Configuration());
+ * // now if we want to multiply a vector by this matrix, it's dimension must equal the row dimension of this
+ * // matrix. If we want to timesSquared() a vector by this matrix, its dimension must equal the column dimension
+ * // of the matrix.
+ * Vector v = new DenseVector(250000);
+ * // now the following operation will be done via a M/R pass via Hadoop.
+ * Vector w = m.timesSquared(v);
+ * </pre>
+ *
+ */
+public class DistributedRowMatrix implements VectorIterable, Configurable {
+ public static final String KEEP_TEMP_FILES = "DistributedMatrix.keep.temp.files";
+
+ private static final Logger log = LoggerFactory.getLogger(DistributedRowMatrix.class);
+
+ private final Path inputPath;
+ private final Path outputTmpPath;
+ private Configuration conf;
+ private Path rowPath;
+ private Path outputTmpBasePath;
+ private final int numRows;
+ private final int numCols;
+ private boolean keepTempFiles;
+
+ public DistributedRowMatrix(Path inputPath,
+ Path outputTmpPath,
+ int numRows,
+ int numCols) {
+ this(inputPath, outputTmpPath, numRows, numCols, false);
+ }
+
+ public DistributedRowMatrix(Path inputPath,
+ Path outputTmpPath,
+ int numRows,
+ int numCols,
+ boolean keepTempFiles) {
+ this.inputPath = inputPath;
+ this.outputTmpPath = outputTmpPath;
+ this.numRows = numRows;
+ this.numCols = numCols;
+ this.keepTempFiles = keepTempFiles;
+ }
+
+ @Override
+ public Configuration getConf() {
+ return conf;
+ }
+
+ @Override
+ public void setConf(Configuration conf) {
+ this.conf = conf;
+ try {
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+ rowPath = fs.makeQualified(inputPath);
+ outputTmpBasePath = fs.makeQualified(outputTmpPath);
+ keepTempFiles = conf.getBoolean(KEEP_TEMP_FILES, false);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+ public Path getRowPath() {
+ return rowPath;
+ }
+
+ public Path getOutputTempPath() {
+ return outputTmpBasePath;
+ }
+
+ public void setOutputTempPathString(String outPathString) {
+ try {
+ outputTmpBasePath = FileSystem.get(conf).makeQualified(new Path(outPathString));
+ } catch (IOException ioe) {
+ log.warn("Unable to set outputBasePath to {}, leaving as {}",
+ outPathString, outputTmpBasePath);
+ }
+ }
+
+ @Override
+ public Iterator<MatrixSlice> iterateAll() {
+ try {
+ Path pathPattern = rowPath;
+ if (FileSystem.get(conf).getFileStatus(rowPath).isDir()) {
+ pathPattern = new Path(rowPath, "*");
+ }
+ return Iterators.transform(
+ new SequenceFileDirIterator<IntWritable,VectorWritable>(pathPattern,
+ PathType.GLOB,
+ PathFilters.logsCRCFilter(),
+ null,
+ true,
+ conf),
+ new Function<Pair<IntWritable,VectorWritable>,MatrixSlice>() {
+ @Override
+ public MatrixSlice apply(Pair<IntWritable, VectorWritable> from) {
+ return new MatrixSlice(from.getSecond().get(), from.getFirst().get());
+ }
+ });
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+ @Override
+ public int numSlices() {
+ return numRows();
+ }
+
+ @Override
+ public int numRows() {
+ return numRows;
+ }
+
+ @Override
+ public int numCols() {
+ return numCols;
+ }
+
+
+ /**
+ * This implements matrix this.transpose().times(other)
+ * @param other a DistributedRowMatrix
+ * @return a DistributedRowMatrix containing the product
+ */
+ public DistributedRowMatrix times(DistributedRowMatrix other) throws IOException {
+ return times(other, new Path(outputTmpBasePath.getParent(), "productWith-" + (System.nanoTime() & 0xFF)));
+ }
+
+ /**
+ * This implements matrix this.transpose().times(other)
+ * @param other a DistributedRowMatrix
+ * @param outPath path to write result to
+ * @return a DistributedRowMatrix containing the product
+ */
+ public DistributedRowMatrix times(DistributedRowMatrix other, Path outPath) throws IOException {
+ if (numRows != other.numRows()) {
+ throw new CardinalityException(numRows, other.numRows());
+ }
+
+ Configuration initialConf = getConf() == null ? new Configuration() : getConf();
+ Configuration conf =
+ MatrixMultiplicationJob.createMatrixMultiplyJobConf(initialConf,
+ rowPath,
+ other.rowPath,
+ outPath,
+ other.numCols);
+ JobClient.runJob(new JobConf(conf));
+ DistributedRowMatrix out = new DistributedRowMatrix(outPath, outputTmpPath, numCols, other.numCols());
+ out.setConf(conf);
+ return out;
+ }
+
+ public Vector columnMeans() throws IOException {
+ return columnMeans("SequentialAccessSparseVector");
+ }
+
+ /**
+ * Returns the column-wise mean of a DistributedRowMatrix
+ *
+ * @param vectorClass
+ * desired class for the column-wise mean vector e.g.
+ * RandomAccessSparseVector, DenseVector
+ * @return Vector containing the column-wise mean of this
+ */
+ public Vector columnMeans(String vectorClass) throws IOException {
+ Path outputVectorTmpPath =
+ new Path(outputTmpBasePath, new Path(Long.toString(System.nanoTime())));
+ Configuration initialConf =
+ getConf() == null ? new Configuration() : getConf();
+ String vectorClassFull = "org.apache.mahout.math." + vectorClass;
+ Vector mean = MatrixColumnMeansJob.run(initialConf, rowPath, outputVectorTmpPath, vectorClassFull);
+ if (!keepTempFiles) {
+ FileSystem fs = outputVectorTmpPath.getFileSystem(conf);
+ fs.delete(outputVectorTmpPath, true);
+ }
+ return mean;
+ }
+
+ public DistributedRowMatrix transpose() throws IOException {
+ Path outputPath = new Path(rowPath.getParent(), "transpose-" + (System.nanoTime() & 0xFF));
+ Configuration initialConf = getConf() == null ? new Configuration() : getConf();
+ Job transposeJob = TransposeJob.buildTransposeJob(initialConf, rowPath, outputPath, numRows);
+
+ try {
+ transposeJob.waitForCompletion(true);
+ } catch (Exception e) {
+ throw new IllegalStateException("transposition failed", e);
+ }
+
+ DistributedRowMatrix m = new DistributedRowMatrix(outputPath, outputTmpPath, numCols, numRows);
+ m.setConf(this.conf);
+ return m;
+ }
+
+ @Override
+ public Vector times(Vector v) {
+ try {
+ Configuration initialConf = getConf() == null ? new Configuration() : getConf();
+ Path outputVectorTmpPath = new Path(outputTmpBasePath, new Path(Long.toString(System.nanoTime())));
+
+ Job job = TimesSquaredJob.createTimesJob(initialConf, v, numRows, rowPath, outputVectorTmpPath);
+
+ try {
+ job.waitForCompletion(true);
+ } catch (Exception e) {
+ throw new IllegalStateException("times failed", e);
+ }
+
+ Vector result = TimesSquaredJob.retrieveTimesSquaredOutputVector(outputVectorTmpPath, conf);
+ if (!keepTempFiles) {
+ FileSystem fs = outputVectorTmpPath.getFileSystem(conf);
+ fs.delete(outputVectorTmpPath, true);
+ }
+ return result;
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+ @Override
+ public Vector timesSquared(Vector v) {
+ try {
+ Configuration initialConf = getConf() == null ? new Configuration() : getConf();
+ Path outputVectorTmpPath = new Path(outputTmpBasePath, new Path(Long.toString(System.nanoTime())));
+
+ Job job = TimesSquaredJob.createTimesSquaredJob(initialConf, v, rowPath, outputVectorTmpPath);
+
+ try {
+ job.waitForCompletion(true);
+ } catch (Exception e) {
+ throw new IllegalStateException("timesSquared failed", e);
+ }
+
+ Vector result = TimesSquaredJob.retrieveTimesSquaredOutputVector(outputVectorTmpPath, conf);
+ if (!keepTempFiles) {
+ FileSystem fs = outputVectorTmpPath.getFileSystem(conf);
+ fs.delete(outputVectorTmpPath, true);
+ }
+ return result;
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+ @Override
+ public Iterator<MatrixSlice> iterator() {
+ return iterateAll();
+ }
+
+ public static class MatrixEntryWritable implements WritableComparable<MatrixEntryWritable> {
+ private int row;
+ private int col;
+ private double val;
+
+ public int getRow() {
+ return row;
+ }
+
+ public void setRow(int row) {
+ this.row = row;
+ }
+
+ public int getCol() {
+ return col;
+ }
+
+ public void setCol(int col) {
+ this.col = col;
+ }
+
+ public double getVal() {
+ return val;
+ }
+
+ public void setVal(double val) {
+ this.val = val;
+ }
+
+ @Override
+ public int compareTo(MatrixEntryWritable o) {
+ if (row > o.row) {
+ return 1;
+ } else if (row < o.row) {
+ return -1;
+ } else {
+ if (col > o.col) {
+ return 1;
+ } else if (col < o.col) {
+ return -1;
+ } else {
+ return 0;
+ }
+ }
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof MatrixEntryWritable)) {
+ return false;
+ }
+ MatrixEntryWritable other = (MatrixEntryWritable) o;
+ return row == other.row && col == other.col;
+ }
+
+ @Override
+ public int hashCode() {
+ return row + 31 * col;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(row);
+ out.writeInt(col);
+ out.writeDouble(val);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ row = in.readInt();
+ col = in.readInt();
+ val = in.readDouble();
+ }
+
+ @Override
+ public String toString() {
+ return "(" + row + ',' + col + "):" + val;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixColumnMeansJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixColumnMeansJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixColumnMeansJob.java
new file mode 100644
index 0000000..b4f459a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixColumnMeansJob.java
@@ -0,0 +1,236 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package org.apache.mahout.math.hadoop;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.io.Closeables;
+
+/**
+ * MatrixColumnMeansJob is a job for calculating the column-wise mean of a
+ * DistributedRowMatrix. This job can be accessed using
+ * DistributedRowMatrix.columnMeans()
+ */
+public final class MatrixColumnMeansJob {
+
+ public static final String VECTOR_CLASS =
+ "DistributedRowMatrix.columnMeans.vector.class";
+
+ private MatrixColumnMeansJob() {
+ }
+
+ public static Vector run(Configuration conf,
+ Path inputPath,
+ Path outputVectorTmpPath) throws IOException {
+ return run(conf, inputPath, outputVectorTmpPath, null);
+ }
+
+ /**
+ * Job for calculating column-wise mean of a DistributedRowMatrix
+ *
+ * @param initialConf
+ * @param inputPath
+ * path to DistributedRowMatrix input
+ * @param outputVectorTmpPath
+ * path for temporary files created during job
+ * @param vectorClass
+ * String of desired class for returned vector e.g. DenseVector,
+ * RandomAccessSparseVector (may be null for {@link DenseVector} )
+ * @return Vector containing column-wise mean of DistributedRowMatrix
+ */
+ public static Vector run(Configuration initialConf,
+ Path inputPath,
+ Path outputVectorTmpPath,
+ String vectorClass) throws IOException {
+
+ try {
+ initialConf.set(VECTOR_CLASS,
+ vectorClass == null ? DenseVector.class.getName()
+ : vectorClass);
+
+ Job job = new Job(initialConf, "MatrixColumnMeansJob");
+ job.setJarByClass(MatrixColumnMeansJob.class);
+
+ FileOutputFormat.setOutputPath(job, outputVectorTmpPath);
+
+ outputVectorTmpPath.getFileSystem(job.getConfiguration())
+ .delete(outputVectorTmpPath, true);
+ job.setNumReduceTasks(1);
+ FileOutputFormat.setOutputPath(job, outputVectorTmpPath);
+ FileInputFormat.addInputPath(job, inputPath);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ FileOutputFormat.setOutputPath(job, outputVectorTmpPath);
+
+ job.setMapperClass(MatrixColumnMeansMapper.class);
+ job.setReducerClass(MatrixColumnMeansReducer.class);
+ job.setMapOutputKeyClass(NullWritable.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.submit();
+ job.waitForCompletion(true);
+
+ Path tmpFile = new Path(outputVectorTmpPath, "part-r-00000");
+ SequenceFileValueIterator<VectorWritable> iterator =
+ new SequenceFileValueIterator<>(tmpFile, true, initialConf);
+ try {
+ if (iterator.hasNext()) {
+ return iterator.next().get();
+ } else {
+ return (Vector) Class.forName(vectorClass).getConstructor(int.class)
+ .newInstance(0);
+ }
+ } finally {
+ Closeables.close(iterator, true);
+ }
+ } catch (IOException ioe) {
+ throw ioe;
+ } catch (Throwable thr) {
+ throw new IOException(thr);
+ }
+ }
+
+ /**
+ * Mapper for calculation of column-wise mean.
+ */
+ public static class MatrixColumnMeansMapper extends
+ Mapper<Writable, VectorWritable, NullWritable, VectorWritable> {
+
+ private Vector runningSum;
+ private String vectorClass;
+
+ @Override
+ public void setup(Context context) {
+ vectorClass = context.getConfiguration().get(VECTOR_CLASS);
+ }
+
+ /**
+ * The mapper computes a running sum of the vectors the task has seen.
+ * Element 0 of the running sum vector contains a count of the number of
+ * vectors that have been seen. The remaining elements contain the
+ * column-wise running sum. Nothing is written at this stage
+ */
+ @Override
+ public void map(Writable r, VectorWritable v, Context context)
+ throws IOException {
+ if (runningSum == null) {
+ /*
+ * If this is the first vector the mapper has seen, instantiate a new
+ * vector using the parameter VECTOR_CLASS
+ */
+ runningSum = ClassUtils.instantiateAs(vectorClass,
+ Vector.class,
+ new Class<?>[] { int.class },
+ new Object[] { v.get().size() + 1 });
+ runningSum.set(0, 1);
+ runningSum.viewPart(1, v.get().size()).assign(v.get());
+ } else {
+ runningSum.set(0, runningSum.get(0) + 1);
+ runningSum.viewPart(1, v.get().size()).assign(v.get(), Functions.PLUS);
+ }
+ }
+
+ /**
+ * The column-wise sum is written at the cleanup stage. A single reducer is
+ * forced so null can be used for the key
+ */
+ @Override
+ public void cleanup(Context context) throws InterruptedException,
+ IOException {
+ if (runningSum != null) {
+ context.write(NullWritable.get(), new VectorWritable(runningSum));
+ }
+ }
+
+ }
+
+ /**
+ * The reducer adds the partial column-wise sums from each of the mappers to
+ * compute the total column-wise sum. The total sum is then divided by the
+ * total count of vectors to determine the column-wise mean.
+ */
+ public static class MatrixColumnMeansReducer extends
+ Reducer<NullWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ private static final IntWritable ONE = new IntWritable(1);
+
+ private String vectorClass;
+ private Vector outputVector;
+ private final VectorWritable outputVectorWritable = new VectorWritable();
+
+ @Override
+ public void setup(Context context) {
+ vectorClass = context.getConfiguration().get(VECTOR_CLASS);
+ }
+
+ @Override
+ public void reduce(NullWritable n,
+ Iterable<VectorWritable> vectors,
+ Context context) throws IOException, InterruptedException {
+
+ /**
+ * Add together partial column-wise sums from mappers
+ */
+ for (VectorWritable v : vectors) {
+ if (outputVector == null) {
+ outputVector = v.get();
+ } else {
+ outputVector.assign(v.get(), Functions.PLUS);
+ }
+ }
+
+ /**
+ * Divide total column-wise sum by count of vectors, which corresponds to
+ * the number of rows in the DistributedRowMatrix
+ */
+ if (outputVector != null) {
+ outputVectorWritable.set(outputVector.viewPart(1,
+ outputVector.size() - 1)
+ .divide(outputVector.get(0)));
+ context.write(ONE, outputVectorWritable);
+ } else {
+ Vector emptyVector = ClassUtils.instantiateAs(vectorClass,
+ Vector.class,
+ new Class<?>[] { int.class },
+ new Object[] { 0 });
+ context.write(ONE, new VectorWritable(emptyVector));
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java
new file mode 100644
index 0000000..48eda08
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java
@@ -0,0 +1,177 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.FileOutputFormat;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapred.join.CompositeInputFormat;
+import org.apache.hadoop.mapred.join.TupleWritable;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This still uses the old MR api and as with all things in Mahout that are MapReduce is now part of 'mahout-mr'.
+ * There is no plan to convert the old MR api used here to the new MR api.
+ * This will be replaced by the new Spark based Linear Algebra bindings.
+ */
+
+public class MatrixMultiplicationJob extends AbstractJob {
+
+ private static final String OUT_CARD = "output.vector.cardinality";
+
+ public static Configuration createMatrixMultiplyJobConf(Path aPath,
+ Path bPath,
+ Path outPath,
+ int outCardinality) {
+ return createMatrixMultiplyJobConf(new Configuration(), aPath, bPath, outPath, outCardinality);
+ }
+
+ public static Configuration createMatrixMultiplyJobConf(Configuration initialConf,
+ Path aPath,
+ Path bPath,
+ Path outPath,
+ int outCardinality) {
+ JobConf conf = new JobConf(initialConf, MatrixMultiplicationJob.class);
+ conf.setInputFormat(CompositeInputFormat.class);
+ conf.set("mapred.join.expr", CompositeInputFormat.compose(
+ "inner", SequenceFileInputFormat.class, aPath, bPath));
+ conf.setInt(OUT_CARD, outCardinality);
+ conf.setOutputFormat(SequenceFileOutputFormat.class);
+ FileOutputFormat.setOutputPath(conf, outPath);
+ conf.setMapperClass(MatrixMultiplyMapper.class);
+ conf.setCombinerClass(MatrixMultiplicationReducer.class);
+ conf.setReducerClass(MatrixMultiplicationReducer.class);
+ conf.setMapOutputKeyClass(IntWritable.class);
+ conf.setMapOutputValueClass(VectorWritable.class);
+ conf.setOutputKeyClass(IntWritable.class);
+ conf.setOutputValueClass(VectorWritable.class);
+ return conf;
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new MatrixMultiplicationJob(), args);
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ addOption("numRowsA", "nra", "Number of rows of the first input matrix", true);
+ addOption("numColsA", "nca", "Number of columns of the first input matrix", true);
+ addOption("numRowsB", "nrb", "Number of rows of the second input matrix", true);
+
+ addOption("numColsB", "ncb", "Number of columns of the second input matrix", true);
+ addOption("inputPathA", "ia", "Path to the first input matrix", true);
+ addOption("inputPathB", "ib", "Path to the second input matrix", true);
+
+ addOption("outputPath", "op", "Path to the output matrix", false);
+
+ Map<String, List<String>> argMap = parseArguments(strings);
+ if (argMap == null) {
+ return -1;
+ }
+
+ DistributedRowMatrix a = new DistributedRowMatrix(new Path(getOption("inputPathA")),
+ new Path(getOption("tempDir")),
+ Integer.parseInt(getOption("numRowsA")),
+ Integer.parseInt(getOption("numColsA")));
+ DistributedRowMatrix b = new DistributedRowMatrix(new Path(getOption("inputPathB")),
+ new Path(getOption("tempDir")),
+ Integer.parseInt(getOption("numRowsB")),
+ Integer.parseInt(getOption("numColsB")));
+
+ a.setConf(new Configuration(getConf()));
+ b.setConf(new Configuration(getConf()));
+
+ if (hasOption("outputPath")) {
+ a.times(b, new Path(getOption("outputPath")));
+ } else {
+ a.times(b);
+ }
+
+ return 0;
+ }
+
+ public static class MatrixMultiplyMapper extends MapReduceBase
+ implements Mapper<IntWritable,TupleWritable,IntWritable,VectorWritable> {
+
+ private int outCardinality;
+ private final IntWritable row = new IntWritable();
+
+ @Override
+ public void configure(JobConf conf) {
+ outCardinality = conf.getInt(OUT_CARD, Integer.MAX_VALUE);
+ }
+
+ @Override
+ public void map(IntWritable index,
+ TupleWritable v,
+ OutputCollector<IntWritable,VectorWritable> out,
+ Reporter reporter) throws IOException {
+ boolean firstIsOutFrag = ((VectorWritable)v.get(0)).get().size() == outCardinality;
+ Vector outFrag = firstIsOutFrag ? ((VectorWritable)v.get(0)).get() : ((VectorWritable)v.get(1)).get();
+ Vector multiplier = firstIsOutFrag ? ((VectorWritable)v.get(1)).get() : ((VectorWritable)v.get(0)).get();
+
+ VectorWritable outVector = new VectorWritable();
+ for (Vector.Element e : multiplier.nonZeroes()) {
+ row.set(e.index());
+ outVector.set(outFrag.times(e.get()));
+ out.collect(row, outVector);
+ }
+ }
+ }
+
+ public static class MatrixMultiplicationReducer extends MapReduceBase
+ implements Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ @Override
+ public void reduce(IntWritable rowNum,
+ Iterator<VectorWritable> it,
+ OutputCollector<IntWritable,VectorWritable> out,
+ Reporter reporter) throws IOException {
+ if (!it.hasNext()) {
+ return;
+ }
+ Vector accumulator = new RandomAccessSparseVector(it.next().get());
+ while (it.hasNext()) {
+ Vector row = it.next().get();
+ accumulator.assign(row, Functions.PLUS);
+ }
+ out.collect(rowNum, new VectorWritable(new SequentialAccessSparseVector(accumulator)));
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java
new file mode 100644
index 0000000..e234eb9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java
@@ -0,0 +1,251 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
+import java.io.IOException;
+import java.net.URI;
+
+public final class TimesSquaredJob {
+
+ public static final String INPUT_VECTOR = "DistributedMatrix.times.inputVector";
+ public static final String IS_SPARSE_OUTPUT = "DistributedMatrix.times.outputVector.sparse";
+ public static final String OUTPUT_VECTOR_DIMENSION = "DistributedMatrix.times.output.dimension";
+
+ public static final String OUTPUT_VECTOR_FILENAME = "DistributedMatrix.times.outputVector";
+
+ private TimesSquaredJob() { }
+
+ public static Job createTimesSquaredJob(Vector v, Path matrixInputPath, Path outputVectorPath)
+ throws IOException {
+ return createTimesSquaredJob(new Configuration(), v, matrixInputPath, outputVectorPath);
+ }
+
+ public static Job createTimesSquaredJob(Configuration initialConf, Vector v, Path matrixInputPath,
+ Path outputVectorPath) throws IOException {
+
+ return createTimesSquaredJob(initialConf, v, matrixInputPath, outputVectorPath, TimesSquaredMapper.class,
+ VectorSummingReducer.class);
+ }
+
+ public static Job createTimesJob(Vector v, int outDim, Path matrixInputPath, Path outputVectorPath)
+ throws IOException {
+
+ return createTimesJob(new Configuration(), v, outDim, matrixInputPath, outputVectorPath);
+ }
+
+ public static Job createTimesJob(Configuration initialConf, Vector v, int outDim, Path matrixInputPath,
+ Path outputVectorPath) throws IOException {
+
+ return createTimesSquaredJob(initialConf, v, outDim, matrixInputPath, outputVectorPath, TimesMapper.class,
+ VectorSummingReducer.class);
+ }
+
+ public static Job createTimesSquaredJob(Vector v, Path matrixInputPath, Path outputVectorPathBase,
+ Class<? extends TimesSquaredMapper> mapClass, Class<? extends VectorSummingReducer> redClass) throws IOException {
+
+ return createTimesSquaredJob(new Configuration(), v, matrixInputPath, outputVectorPathBase, mapClass, redClass);
+ }
+
+ public static Job createTimesSquaredJob(Configuration initialConf, Vector v, Path matrixInputPath,
+ Path outputVectorPathBase, Class<? extends TimesSquaredMapper> mapClass,
+ Class<? extends VectorSummingReducer> redClass) throws IOException {
+
+ return createTimesSquaredJob(initialConf, v, v.size(), matrixInputPath, outputVectorPathBase, mapClass, redClass);
+ }
+
+ public static Job createTimesSquaredJob(Vector v, int outputVectorDim, Path matrixInputPath,
+ Path outputVectorPathBase, Class<? extends TimesSquaredMapper> mapClass,
+ Class<? extends VectorSummingReducer> redClass) throws IOException {
+
+ return createTimesSquaredJob(new Configuration(), v, outputVectorDim, matrixInputPath, outputVectorPathBase,
+ mapClass, redClass);
+ }
+
+ public static Job createTimesSquaredJob(Configuration initialConf, Vector v, int outputVectorDim,
+ Path matrixInputPath, Path outputVectorPathBase, Class<? extends TimesSquaredMapper> mapClass,
+ Class<? extends VectorSummingReducer> redClass) throws IOException {
+
+ FileSystem fs = FileSystem.get(matrixInputPath.toUri(), initialConf);
+ matrixInputPath = fs.makeQualified(matrixInputPath);
+ outputVectorPathBase = fs.makeQualified(outputVectorPathBase);
+
+ long now = System.nanoTime();
+ Path inputVectorPath = new Path(outputVectorPathBase, INPUT_VECTOR + '/' + now);
+
+
+ SequenceFile.Writer inputVectorPathWriter = null;
+
+ try {
+ inputVectorPathWriter = new SequenceFile.Writer(fs, initialConf, inputVectorPath, NullWritable.class,
+ VectorWritable.class);
+ inputVectorPathWriter.append(NullWritable.get(), new VectorWritable(v));
+ } finally {
+ Closeables.close(inputVectorPathWriter, false);
+ }
+
+ URI ivpURI = inputVectorPath.toUri();
+ DistributedCache.setCacheFiles(new URI[] { ivpURI }, initialConf);
+
+ Job job = HadoopUtil.prepareJob(matrixInputPath, new Path(outputVectorPathBase, OUTPUT_VECTOR_FILENAME),
+ SequenceFileInputFormat.class, mapClass, NullWritable.class, VectorWritable.class, redClass,
+ NullWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, initialConf);
+ job.setCombinerClass(redClass);
+ job.setJobName("TimesSquaredJob: " + matrixInputPath);
+
+ Configuration conf = job.getConfiguration();
+ conf.set(INPUT_VECTOR, ivpURI.toString());
+ conf.setBoolean(IS_SPARSE_OUTPUT, !v.isDense());
+ conf.setInt(OUTPUT_VECTOR_DIMENSION, outputVectorDim);
+
+ return job;
+ }
+
+ public static Vector retrieveTimesSquaredOutputVector(Path outputVectorTmpPath, Configuration conf)
+ throws IOException {
+ Path outputFile = new Path(outputVectorTmpPath, OUTPUT_VECTOR_FILENAME + "/part-r-00000");
+ SequenceFileValueIterator<VectorWritable> iterator =
+ new SequenceFileValueIterator<>(outputFile, true, conf);
+ try {
+ return iterator.next().get();
+ } finally {
+ Closeables.close(iterator, true);
+ }
+ }
+
+ public static class TimesSquaredMapper<T extends WritableComparable>
+ extends Mapper<T,VectorWritable, NullWritable,VectorWritable> {
+
+ private Vector outputVector;
+ private Vector inputVector;
+
+ Vector getOutputVector() {
+ return outputVector;
+ }
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ try {
+ Configuration conf = ctx.getConfiguration();
+ Path[] localFiles = DistributedCache.getLocalCacheFiles(conf);
+ Preconditions.checkArgument(localFiles != null && localFiles.length >= 1,
+ "missing paths from the DistributedCache");
+
+ Path inputVectorPath = HadoopUtil.getSingleCachedFile(conf);
+
+ SequenceFileValueIterator<VectorWritable> iterator =
+ new SequenceFileValueIterator<>(inputVectorPath, true, conf);
+ try {
+ inputVector = iterator.next().get();
+ } finally {
+ Closeables.close(iterator, true);
+ }
+
+ int outDim = conf.getInt(OUTPUT_VECTOR_DIMENSION, Integer.MAX_VALUE);
+ outputVector = conf.getBoolean(IS_SPARSE_OUTPUT, false)
+ ? new RandomAccessSparseVector(outDim, 10)
+ : new DenseVector(outDim);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+ @Override
+ protected void map(T key, VectorWritable v, Context context) throws IOException, InterruptedException {
+
+ double d = scale(v);
+ if (d == 1.0) {
+ outputVector.assign(v.get(), Functions.PLUS);
+ } else if (d != 0.0) {
+ outputVector.assign(v.get(), Functions.plusMult(d));
+ }
+ }
+
+ protected double scale(VectorWritable v) {
+ return v.get().dot(inputVector);
+ }
+
+ @Override
+ protected void cleanup(Context ctx) throws IOException, InterruptedException {
+ ctx.write(NullWritable.get(), new VectorWritable(outputVector));
+ }
+
+ }
+
+ public static class TimesMapper extends TimesSquaredMapper<IntWritable> {
+
+
+ @Override
+ protected void map(IntWritable rowNum, VectorWritable v, Context context) throws IOException, InterruptedException {
+ double d = scale(v);
+ if (d != 0.0) {
+ getOutputVector().setQuick(rowNum.get(), d);
+ }
+ }
+ }
+
+ public static class VectorSummingReducer extends Reducer<NullWritable,VectorWritable,NullWritable,VectorWritable> {
+
+ private Vector outputVector;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ Configuration conf = ctx.getConfiguration();
+ int outputDimension = conf.getInt(OUTPUT_VECTOR_DIMENSION, Integer.MAX_VALUE);
+ outputVector = conf.getBoolean(IS_SPARSE_OUTPUT, false)
+ ? new RandomAccessSparseVector(outputDimension, 10)
+ : new DenseVector(outputDimension);
+ }
+
+ @Override
+ protected void reduce(NullWritable key, Iterable<VectorWritable> vectors, Context ctx)
+ throws IOException, InterruptedException {
+
+ for (VectorWritable v : vectors) {
+ if (v != null) {
+ outputVector.assign(v.get(), Functions.PLUS);
+ }
+ }
+ ctx.write(NullWritable.get(), new VectorWritable(outputVector));
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java
new file mode 100644
index 0000000..60066c6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
+import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
+import org.apache.mahout.common.mapreduce.TransposeMapper;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+/** Transpose a matrix */
+public class TransposeJob extends AbstractJob {
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new TransposeJob(), args);
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ addInputOption();
+ addOption("numRows", "nr", "Number of rows of the input matrix");
+ addOption("numCols", "nc", "Number of columns of the input matrix");
+ Map<String, List<String>> parsedArgs = parseArguments(strings);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ int numRows = Integer.parseInt(getOption("numRows"));
+ int numCols = Integer.parseInt(getOption("numCols"));
+
+ DistributedRowMatrix matrix = new DistributedRowMatrix(getInputPath(), getTempPath(), numRows, numCols);
+ matrix.setConf(new Configuration(getConf()));
+ matrix.transpose();
+
+ return 0;
+ }
+
+ public static Job buildTransposeJob(Path matrixInputPath, Path matrixOutputPath, int numInputRows)
+ throws IOException {
+ return buildTransposeJob(new Configuration(), matrixInputPath, matrixOutputPath, numInputRows);
+ }
+
+ public static Job buildTransposeJob(Configuration initialConf, Path matrixInputPath, Path matrixOutputPath,
+ int numInputRows) throws IOException {
+
+ Job job = HadoopUtil.prepareJob(matrixInputPath, matrixOutputPath, SequenceFileInputFormat.class,
+ TransposeMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class, IntWritable.class,
+ VectorWritable.class, SequenceFileOutputFormat.class, initialConf);
+ job.setCombinerClass(MergeVectorsCombiner.class);
+ job.getConfiguration().setInt(TransposeMapper.NEW_NUM_COLS_PARAM, numInputRows);
+
+ job.setJobName("TransposeJob: " + matrixInputPath);
+
+ return job;
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/DistributedLanczosSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/DistributedLanczosSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/DistributedLanczosSolver.java
new file mode 100644
index 0000000..89dddcc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/DistributedLanczosSolver.java
@@ -0,0 +1,298 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.decomposer;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configurable;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.decomposer.lanczos.LanczosSolver;
+import org.apache.mahout.math.decomposer.lanczos.LanczosState;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * See the SSVD code for a better option than using this:
+ *
+ * http://mahout.apache.org/users/dim-reduction/ssvd.html
+ * @see org.apache.mahout.math.hadoop.stochasticsvd.SSVDSolver
+ */
+public class DistributedLanczosSolver extends LanczosSolver implements Tool {
+
+ public static final String RAW_EIGENVECTORS = "rawEigenvectors";
+
+ private static final Logger log = LoggerFactory.getLogger(DistributedLanczosSolver.class);
+
+ private Configuration conf;
+
+ private Map<String, List<String>> parsedArgs;
+
+ /**
+ * For the distributed case, the best guess at a useful initialization state for Lanczos we'll chose to be
+ * uniform over all input dimensions, L_2 normalized.
+ */
+ public static Vector getInitialVector(VectorIterable corpus) {
+ Vector initialVector = new DenseVector(corpus.numCols());
+ initialVector.assign(1.0 / Math.sqrt(corpus.numCols()));
+ return initialVector;
+ }
+
+ public LanczosState runJob(Configuration originalConfig,
+ LanczosState state,
+ int desiredRank,
+ boolean isSymmetric,
+ String outputEigenVectorPathString) throws IOException {
+ ((Configurable) state.getCorpus()).setConf(new Configuration(originalConfig));
+ setConf(originalConfig);
+ solve(state, desiredRank, isSymmetric);
+ serializeOutput(state, new Path(outputEigenVectorPathString));
+ return state;
+ }
+
+ /**
+ * Factored-out LanczosSolver for the purpose of invoking it programmatically
+ */
+ public LanczosState runJob(Configuration originalConfig,
+ Path inputPath,
+ Path outputTmpPath,
+ int numRows,
+ int numCols,
+ boolean isSymmetric,
+ int desiredRank,
+ String outputEigenVectorPathString) throws IOException {
+ DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, outputTmpPath, numRows, numCols);
+ matrix.setConf(new Configuration(originalConfig));
+ LanczosState state = new LanczosState(matrix, desiredRank, getInitialVector(matrix));
+ return runJob(originalConfig, state, desiredRank, isSymmetric, outputEigenVectorPathString);
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ Path inputPath = new Path(AbstractJob.getOption(parsedArgs, "--input"));
+ Path outputPath = new Path(AbstractJob.getOption(parsedArgs, "--output"));
+ Path outputTmpPath = new Path(AbstractJob.getOption(parsedArgs, "--tempDir"));
+ Path workingDirPath = AbstractJob.getOption(parsedArgs, "--workingDir") != null
+ ? new Path(AbstractJob.getOption(parsedArgs, "--workingDir")) : null;
+ int numRows = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numRows"));
+ int numCols = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numCols"));
+ boolean isSymmetric = Boolean.parseBoolean(AbstractJob.getOption(parsedArgs, "--symmetric"));
+ int desiredRank = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--rank"));
+
+ boolean cleansvd = Boolean.parseBoolean(AbstractJob.getOption(parsedArgs, "--cleansvd"));
+ if (cleansvd) {
+ double maxError = Double.parseDouble(AbstractJob.getOption(parsedArgs, "--maxError"));
+ double minEigenvalue = Double.parseDouble(AbstractJob.getOption(parsedArgs, "--minEigenvalue"));
+ boolean inMemory = Boolean.parseBoolean(AbstractJob.getOption(parsedArgs, "--inMemory"));
+ return run(inputPath,
+ outputPath,
+ outputTmpPath,
+ workingDirPath,
+ numRows,
+ numCols,
+ isSymmetric,
+ desiredRank,
+ maxError,
+ minEigenvalue,
+ inMemory);
+ }
+ return run(inputPath, outputPath, outputTmpPath, workingDirPath, numRows, numCols, isSymmetric, desiredRank);
+ }
+
+ /**
+ * Run the solver to produce raw eigenvectors, then run the EigenVerificationJob to clean them
+ *
+ * @param inputPath the Path to the input corpus
+ * @param outputPath the Path to the output
+ * @param outputTmpPath a Path to a temporary working directory
+ * @param numRows the int number of rows
+ * @param numCols the int number of columns
+ * @param isSymmetric true if the input matrix is symmetric
+ * @param desiredRank the int desired rank of eigenvectors to produce
+ * @param maxError the maximum allowable error
+ * @param minEigenvalue the minimum usable eigenvalue
+ * @param inMemory true if the verification can be done in memory
+ * @return an int indicating success (0) or otherwise
+ */
+ public int run(Path inputPath,
+ Path outputPath,
+ Path outputTmpPath,
+ Path workingDirPath,
+ int numRows,
+ int numCols,
+ boolean isSymmetric,
+ int desiredRank,
+ double maxError,
+ double minEigenvalue,
+ boolean inMemory) throws Exception {
+ int result = run(inputPath, outputPath, outputTmpPath, workingDirPath, numRows, numCols,
+ isSymmetric, desiredRank);
+ if (result != 0) {
+ return result;
+ }
+ Path rawEigenVectorPath = new Path(outputPath, RAW_EIGENVECTORS);
+ return new EigenVerificationJob().run(inputPath,
+ rawEigenVectorPath,
+ outputPath,
+ outputTmpPath,
+ maxError,
+ minEigenvalue,
+ inMemory,
+ getConf() != null ? new Configuration(getConf()) : new Configuration());
+ }
+
+ /**
+ * Run the solver to produce the raw eigenvectors
+ *
+ * @param inputPath the Path to the input corpus
+ * @param outputPath the Path to the output
+ * @param outputTmpPath a Path to a temporary working directory
+ * @param numRows the int number of rows
+ * @param numCols the int number of columns
+ * @param isSymmetric true if the input matrix is symmetric
+ * @param desiredRank the int desired rank of eigenvectors to produce
+ * @return an int indicating success (0) or otherwise
+ */
+ public int run(Path inputPath,
+ Path outputPath,
+ Path outputTmpPath,
+ Path workingDirPath,
+ int numRows,
+ int numCols,
+ boolean isSymmetric,
+ int desiredRank) throws Exception {
+ DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, outputTmpPath, numRows, numCols);
+ matrix.setConf(new Configuration(getConf() != null ? getConf() : new Configuration()));
+
+ LanczosState state;
+ if (workingDirPath == null) {
+ state = new LanczosState(matrix, desiredRank, getInitialVector(matrix));
+ } else {
+ HdfsBackedLanczosState hState =
+ new HdfsBackedLanczosState(matrix, desiredRank, getInitialVector(matrix), workingDirPath);
+ hState.setConf(matrix.getConf());
+ state = hState;
+ }
+ solve(state, desiredRank, isSymmetric);
+
+ Path outputEigenVectorPath = new Path(outputPath, RAW_EIGENVECTORS);
+ serializeOutput(state, outputEigenVectorPath);
+ return 0;
+ }
+
+ /**
+ * @param state The final LanczosState to be serialized
+ * @param outputPath The path (relative to the current Configuration's FileSystem) to save the output to.
+ */
+ public void serializeOutput(LanczosState state, Path outputPath) throws IOException {
+ int numEigenVectors = state.getIterationNumber();
+ log.info("Persisting {} eigenVectors and eigenValues to: {}", numEigenVectors, outputPath);
+ Configuration conf = getConf() != null ? getConf() : new Configuration();
+ FileSystem fs = FileSystem.get(outputPath.toUri(), conf);
+ SequenceFile.Writer seqWriter =
+ new SequenceFile.Writer(fs, conf, outputPath, IntWritable.class, VectorWritable.class);
+ try {
+ IntWritable iw = new IntWritable();
+ for (int i = 0; i < numEigenVectors; i++) {
+ // Persist eigenvectors sorted by eigenvalues in descending order\
+ NamedVector v = new NamedVector(state.getRightSingularVector(numEigenVectors - 1 - i),
+ "eigenVector" + i + ", eigenvalue = " + state.getSingularValue(numEigenVectors - 1 - i));
+ Writable vw = new VectorWritable(v);
+ iw.set(i);
+ seqWriter.append(iw, vw);
+ }
+ } finally {
+ Closeables.close(seqWriter, false);
+ }
+ }
+
+ @Override
+ public void setConf(Configuration configuration) {
+ conf = configuration;
+ }
+
+ @Override
+ public Configuration getConf() {
+ return conf;
+ }
+
+ public DistributedLanczosSolverJob job() {
+ return new DistributedLanczosSolverJob();
+ }
+
+ /**
+ * Inner subclass of AbstractJob so we get access to AbstractJob's functionality w.r.t. cmdline options, but still
+ * sublcass LanczosSolver.
+ */
+ public class DistributedLanczosSolverJob extends AbstractJob {
+ @Override
+ public void setConf(Configuration conf) {
+ DistributedLanczosSolver.this.setConf(conf);
+ }
+
+ @Override
+ public Configuration getConf() {
+ return DistributedLanczosSolver.this.getConf();
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption("numRows", "nr", "Number of rows of the input matrix");
+ addOption("numCols", "nc", "Number of columns of the input matrix");
+ addOption("rank", "r", "Desired decomposition rank (note: only roughly 1/4 to 1/3 "
+ + "of these will have the top portion of the spectrum)");
+ addOption("symmetric", "sym", "Is the input matrix square and symmetric?");
+ addOption("workingDir", "wd", "Working directory path to store Lanczos basis vectors "
+ + "(to be used on restarts, and to avoid too much RAM usage)");
+ // options required to run cleansvd job
+ addOption("cleansvd", "cl", "Run the EigenVerificationJob to clean the eigenvectors after SVD", false);
+ addOption("maxError", "err", "Maximum acceptable error", "0.05");
+ addOption("minEigenvalue", "mev", "Minimum eigenvalue to keep the vector for", "0.0");
+ addOption("inMemory", "mem", "Buffer eigen matrix into memory (if you have enough!)", "false");
+
+ DistributedLanczosSolver.this.parsedArgs = parseArguments(args);
+ if (DistributedLanczosSolver.this.parsedArgs == null) {
+ return -1;
+ } else {
+ return DistributedLanczosSolver.this.run(args);
+ }
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new DistributedLanczosSolver().job(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVector.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVector.java b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVector.java
new file mode 100644
index 0000000..d2f0c8c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVector.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.decomposer;
+
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+
+import java.util.regex.Pattern;
+
+/**
+ * TODO this is a horrible hack. Make a proper writable subclass also.
+ */
+public class EigenVector extends NamedVector {
+
+ private static final Pattern EQUAL_PATTERN = Pattern.compile(" = ");
+ private static final Pattern PIPE_PATTERN = Pattern.compile("\\|");
+
+ public EigenVector(Vector v, double eigenValue, double cosAngleError, int order) {
+ super(v instanceof DenseVector ? (DenseVector) v : new DenseVector(v),
+ "e|" + order + "| = |" + eigenValue + "|, err = " + cosAngleError);
+ }
+
+ public double getEigenValue() {
+ return getEigenValue(getName());
+ }
+
+ public double getCosAngleError() {
+ return getCosAngleError(getName());
+ }
+
+ public int getIndex() {
+ return getIndex(getName());
+ }
+
+ public static double getEigenValue(CharSequence name) {
+ return parseMetaData(name)[1];
+ }
+
+ public static double getCosAngleError(CharSequence name) {
+ return parseMetaData(name)[2];
+ }
+
+ public static int getIndex(CharSequence name) {
+ return (int)parseMetaData(name)[0];
+ }
+
+ public static double[] parseMetaData(CharSequence name) {
+ double[] m = new double[3];
+ String[] s = EQUAL_PATTERN.split(name);
+ m[0] = Double.parseDouble(PIPE_PATTERN.split(s[0])[1]);
+ m[1] = Double.parseDouble(PIPE_PATTERN.split(s[1])[1]);
+ m[2] = Double.parseDouble(s[2].substring(1));
+ return m;
+ }
+
+ protected double[] parseMetaData() {
+ return parseMetaData(getName());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVerificationJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVerificationJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVerificationJob.java
new file mode 100644
index 0000000..a7eaaed
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/EigenVerificationJob.java
@@ -0,0 +1,332 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.decomposer;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.decomposer.EigenStatus;
+import org.apache.mahout.math.decomposer.SimpleEigenVerifier;
+import org.apache.mahout.math.decomposer.SingularVectorVerifier;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>
+ * Class for taking the output of an eigendecomposition (specified as a Path location), and verifies correctness, in
+ * terms of the following: if you have a vector e, and a matrix m, then let e' = m.timesSquared(v); the error w.r.t.
+ * eigenvector-ness is the cosine of the angle between e and e':
+ * </p>
+ *
+ * <pre>
+ * error(e,e') = e.dot(e') / (e.norm(2)*e'.norm(2))
+ * </pre>
+ * <p>
+ * A set of eigenvectors should also all be very close to orthogonal, so this job computes all inner products between
+ * eigenvectors, and checks that this is close to the identity matrix.
+ * </p>
+ * <p>
+ * Parameters used in the cleanup (other than in the input/output path options) include --minEigenvalue, which specifies
+ * the value below which eigenvector/eigenvalue pairs will be discarded, and --maxError, which specifies the maximum
+ * error (as defined above) to be tolerated in an eigenvector.
+ * </p>
+ * <p>
+ * If all the eigenvectors can fit in memory, --inMemory allows for a speedier completion of this task by doing so.
+ * </p>
+ */
+public class EigenVerificationJob extends AbstractJob {
+
+ public static final String CLEAN_EIGENVECTORS = "cleanEigenvectors";
+
+ private static final Logger log = LoggerFactory.getLogger(EigenVerificationJob.class);
+
+ private SingularVectorVerifier eigenVerifier;
+
+ private VectorIterable eigensToVerify;
+
+ private VectorIterable corpus;
+
+ private double maxError;
+
+ private double minEigenValue;
+
+ // private boolean loadEigensInMemory;
+
+ private Path tmpOut;
+
+ private Path outPath;
+
+ private int maxEigensToKeep;
+
+ private Path cleanedEigensPath;
+
+ public void setEigensToVerify(VectorIterable eigens) {
+ eigensToVerify = eigens;
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ Map<String,List<String>> argMap = handleArgs(args);
+ if (argMap == null) {
+ return -1;
+ }
+ if (argMap.isEmpty()) {
+ return 0;
+ }
+ // parse out the arguments
+ runJob(getConf(), new Path(getOption("eigenInput")), new Path(getOption("corpusInput")), getOutputPath(),
+ getOption("inMemory") != null, Double.parseDouble(getOption("maxError")),
+ // Double.parseDouble(getOption("minEigenvalue")),
+ Integer.parseInt(getOption("maxEigens")));
+ return 0;
+ }
+
+ /**
+ * Run the job with the given arguments
+ *
+ * @param corpusInput
+ * the corpus input Path
+ * @param eigenInput
+ * the eigenvector input Path
+ * @param output
+ * the output Path
+ * @param tempOut
+ * temporary output Path
+ * @param maxError
+ * a double representing the maximum error
+ * @param minEigenValue
+ * a double representing the minimum eigenvalue
+ * @param inMemory
+ * a boolean requesting in-memory preparation
+ * @param conf
+ * the Configuration to use, or null if a default is ok (saves referencing Configuration in calling classes
+ * unless needed)
+ */
+ public int run(Path corpusInput, Path eigenInput, Path output, Path tempOut, double maxError, double minEigenValue,
+ boolean inMemory, Configuration conf) throws IOException {
+ this.outPath = output;
+ this.tmpOut = tempOut;
+ this.maxError = maxError;
+ this.minEigenValue = minEigenValue;
+
+ if (eigenInput != null && eigensToVerify == null) {
+ prepareEigens(conf, eigenInput, inMemory);
+ }
+ DistributedRowMatrix c = new DistributedRowMatrix(corpusInput, tempOut, 1, 1);
+ c.setConf(conf);
+ corpus = c;
+
+ // set up eigenverifier and orthoverifier TODO: allow multithreaded execution
+
+ eigenVerifier = new SimpleEigenVerifier();
+
+ // we don't currently verify orthonormality here.
+ // VectorIterable pairwiseInnerProducts = computePairwiseInnerProducts();
+
+ Map<MatrixSlice,EigenStatus> eigenMetaData = verifyEigens();
+
+ List<Map.Entry<MatrixSlice,EigenStatus>> prunedEigenMeta = pruneEigens(eigenMetaData);
+
+ saveCleanEigens(new Configuration(), prunedEigenMeta);
+ return 0;
+ }
+
+ private Map<String,List<String>> handleArgs(String[] args) throws IOException {
+ addOutputOption();
+ addOption("eigenInput", "ei",
+ "The Path for purported eigenVector input files (SequenceFile<WritableComparable,VectorWritable>.", null);
+ addOption("corpusInput", "ci", "The Path for corpus input files (SequenceFile<WritableComparable,VectorWritable>.");
+ addOption(DefaultOptionCreator.outputOption().create());
+ addOption(DefaultOptionCreator.helpOption());
+ addOption("inMemory", "mem", "Buffer eigen matrix into memory (if you have enough!)", "false");
+ addOption("maxError", "err", "Maximum acceptable error", "0.05");
+ addOption("minEigenvalue", "mev", "Minimum eigenvalue to keep the vector for", "0.0");
+ addOption("maxEigens", "max", "Maximum number of eigenvectors to keep (0 means all)", "0");
+
+ return parseArguments(args);
+ }
+
+ private void saveCleanEigens(Configuration conf, Collection<Map.Entry<MatrixSlice,EigenStatus>> prunedEigenMeta)
+ throws IOException {
+ Path path = new Path(outPath, CLEAN_EIGENVECTORS);
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ SequenceFile.Writer seqWriter = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
+ try {
+ IntWritable iw = new IntWritable();
+ int numEigensWritten = 0;
+ int index = 0;
+ for (Map.Entry<MatrixSlice,EigenStatus> pruneSlice : prunedEigenMeta) {
+ MatrixSlice s = pruneSlice.getKey();
+ EigenStatus meta = pruneSlice.getValue();
+ EigenVector ev = new EigenVector(s.vector(), meta.getEigenValue(), Math.abs(1 - meta.getCosAngle()), s.index());
+ // log.info("appending {} to {}", ev, path);
+ Writable vw = new VectorWritable(ev);
+ iw.set(index++);
+ seqWriter.append(iw, vw);
+
+ // increment the number of eigenvectors written and see if we've
+ // reached our specified limit, or if we wish to write all eigenvectors
+ // (latter is built-in, since numEigensWritten will always be > 0
+ numEigensWritten++;
+ if (numEigensWritten == maxEigensToKeep) {
+ log.info("{} of the {} total eigens have been written", maxEigensToKeep, prunedEigenMeta.size());
+ break;
+ }
+ }
+ } finally {
+ Closeables.close(seqWriter, false);
+ }
+ cleanedEigensPath = path;
+ }
+
+ private List<Map.Entry<MatrixSlice,EigenStatus>> pruneEigens(Map<MatrixSlice,EigenStatus> eigenMetaData) {
+ List<Map.Entry<MatrixSlice,EigenStatus>> prunedEigenMeta = Lists.newArrayList();
+
+ for (Map.Entry<MatrixSlice,EigenStatus> entry : eigenMetaData.entrySet()) {
+ if (Math.abs(1 - entry.getValue().getCosAngle()) < maxError && entry.getValue().getEigenValue() > minEigenValue) {
+ prunedEigenMeta.add(entry);
+ }
+ }
+
+ Collections.sort(prunedEigenMeta, new Comparator<Map.Entry<MatrixSlice,EigenStatus>>() {
+ @Override
+ public int compare(Map.Entry<MatrixSlice,EigenStatus> e1, Map.Entry<MatrixSlice,EigenStatus> e2) {
+ // sort eigens on eigenvalues in descending order
+ Double eg1 = e1.getValue().getEigenValue();
+ Double eg2 = e2.getValue().getEigenValue();
+ return eg1.compareTo(eg2);
+ }
+ });
+
+ // iterate thru' the eigens, pick up ones with max orthogonality with the selected ones
+ List<Map.Entry<MatrixSlice,EigenStatus>> selectedEigenMeta = Lists.newArrayList();
+ Map.Entry<MatrixSlice,EigenStatus> e1 = prunedEigenMeta.remove(0);
+ selectedEigenMeta.add(e1);
+ int selectedEigenMetaLength = selectedEigenMeta.size();
+ int prunedEigenMetaLength = prunedEigenMeta.size();
+
+ while (prunedEigenMetaLength > 0) {
+ double sum = Double.MAX_VALUE;
+ int index = 0;
+ for (int i = 0; i < prunedEigenMetaLength; i++) {
+ Map.Entry<MatrixSlice,EigenStatus> e = prunedEigenMeta.get(i);
+ double tmp = 0;
+ for (int j = 0; j < selectedEigenMetaLength; j++) {
+ Map.Entry<MatrixSlice,EigenStatus> ee = selectedEigenMeta.get(j);
+ tmp += ee.getKey().vector().times(e.getKey().vector()).norm(2);
+ }
+ if (tmp < sum) {
+ sum = tmp;
+ index = i;
+ }
+ }
+ Map.Entry<MatrixSlice,EigenStatus> e = prunedEigenMeta.remove(index);
+ selectedEigenMeta.add(e);
+ selectedEigenMetaLength++;
+ prunedEigenMetaLength--;
+ }
+
+ return selectedEigenMeta;
+ }
+
+ private Map<MatrixSlice,EigenStatus> verifyEigens() {
+ Map<MatrixSlice,EigenStatus> eigenMetaData = Maps.newHashMap();
+
+ for (MatrixSlice slice : eigensToVerify) {
+ EigenStatus status = eigenVerifier.verify(corpus, slice.vector());
+ eigenMetaData.put(slice, status);
+ }
+ return eigenMetaData;
+ }
+
+ private void prepareEigens(Configuration conf, Path eigenInput, boolean inMemory) {
+ DistributedRowMatrix eigens = new DistributedRowMatrix(eigenInput, tmpOut, 1, 1);
+ eigens.setConf(conf);
+ if (inMemory) {
+ List<Vector> eigenVectors = Lists.newArrayList();
+ for (MatrixSlice slice : eigens) {
+ eigenVectors.add(slice.vector());
+ }
+ eigensToVerify = new SparseRowMatrix(eigenVectors.size(), eigenVectors.get(0).size(),
+ eigenVectors.toArray(new Vector[eigenVectors.size()]), true, true);
+
+ } else {
+ eigensToVerify = eigens;
+ }
+ }
+
+ public Path getCleanedEigensPath() {
+ return cleanedEigensPath;
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new EigenVerificationJob(), args);
+ }
+
+ /**
+ * Progammatic invocation of run()
+ *
+ * @param eigenInput
+ * Output of LanczosSolver
+ * @param corpusInput
+ * Input of LanczosSolver
+ */
+ public void runJob(Configuration conf, Path eigenInput, Path corpusInput, Path output, boolean inMemory,
+ double maxError, int maxEigens) throws IOException {
+ // no need to handle command line arguments
+ outPath = output;
+ tmpOut = new Path(outPath, "tmp");
+ maxEigensToKeep = maxEigens;
+ this.maxError = maxError;
+ if (eigenInput != null && eigensToVerify == null) {
+ prepareEigens(new Configuration(conf), eigenInput, inMemory);
+ }
+
+ DistributedRowMatrix c = new DistributedRowMatrix(corpusInput, tmpOut, 1, 1);
+ c.setConf(new Configuration(conf));
+ corpus = c;
+
+ eigenVerifier = new SimpleEigenVerifier();
+
+ Map<MatrixSlice,EigenStatus> eigenMetaData = verifyEigens();
+ List<Map.Entry<MatrixSlice,EigenStatus>> prunedEigenMeta = pruneEigens(eigenMetaData);
+ saveCleanEigens(conf, prunedEigenMeta);
+ }
+}
[45/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
new file mode 100644
index 0000000..76e5239
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+
+import com.google.common.base.Preconditions;
+
+public class WeightedRunningAverage implements RunningAverage, Serializable {
+
+ private double totalWeight;
+ private double average;
+
+ public WeightedRunningAverage() {
+ totalWeight = 0.0;
+ average = Double.NaN;
+ }
+
+ @Override
+ public synchronized void addDatum(double datum) {
+ addDatum(datum, 1.0);
+ }
+
+ public synchronized void addDatum(double datum, double weight) {
+ double oldTotalWeight = totalWeight;
+ totalWeight += weight;
+ if (oldTotalWeight <= 0.0) {
+ average = datum;
+ } else {
+ average = average * oldTotalWeight / totalWeight + datum * weight / totalWeight;
+ }
+ }
+
+ @Override
+ public synchronized void removeDatum(double datum) {
+ removeDatum(datum, 1.0);
+ }
+
+ public synchronized void removeDatum(double datum, double weight) {
+ double oldTotalWeight = totalWeight;
+ totalWeight -= weight;
+ if (totalWeight <= 0.0) {
+ average = Double.NaN;
+ totalWeight = 0.0;
+ } else {
+ average = average * oldTotalWeight / totalWeight - datum * weight / totalWeight;
+ }
+ }
+
+ @Override
+ public synchronized void changeDatum(double delta) {
+ changeDatum(delta, 1.0);
+ }
+
+ public synchronized void changeDatum(double delta, double weight) {
+ Preconditions.checkArgument(weight <= totalWeight, "weight must be <= totalWeight");
+ average += delta * weight / totalWeight;
+ }
+
+ public synchronized double getTotalWeight() {
+ return totalWeight;
+ }
+
+ /** @return {@link #getTotalWeight()} */
+ @Override
+ public synchronized int getCount() {
+ return (int) totalWeight;
+ }
+
+ @Override
+ public synchronized double getAverage() {
+ return average;
+ }
+
+ @Override
+ public RunningAverage inverse() {
+ return new InvertedRunningAverage(this);
+ }
+
+ @Override
+ public synchronized String toString() {
+ return String.valueOf(average);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
new file mode 100644
index 0000000..bed5812
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * This subclass also provides for a weighted estimate of the sample standard deviation.
+ * See <a href="http://en.wikipedia.org/wiki/Mean_square_weighted_deviation">estimate formulae here</a>.
+ */
+public final class WeightedRunningAverageAndStdDev extends WeightedRunningAverage implements RunningAverageAndStdDev {
+
+ private double totalSquaredWeight;
+ private double totalWeightedData;
+ private double totalWeightedSquaredData;
+
+ public WeightedRunningAverageAndStdDev() {
+ totalSquaredWeight = 0.0;
+ totalWeightedData = 0.0;
+ totalWeightedSquaredData = 0.0;
+ }
+
+ @Override
+ public synchronized void addDatum(double datum, double weight) {
+ super.addDatum(datum, weight);
+ totalSquaredWeight += weight * weight;
+ double weightedData = datum * weight;
+ totalWeightedData += weightedData;
+ totalWeightedSquaredData += weightedData * datum;
+ }
+
+ @Override
+ public synchronized void removeDatum(double datum, double weight) {
+ super.removeDatum(datum, weight);
+ totalSquaredWeight -= weight * weight;
+ if (totalSquaredWeight <= 0.0) {
+ totalSquaredWeight = 0.0;
+ }
+ double weightedData = datum * weight;
+ totalWeightedData -= weightedData;
+ if (totalWeightedData <= 0.0) {
+ totalWeightedData = 0.0;
+ }
+ totalWeightedSquaredData -= weightedData * datum;
+ if (totalWeightedSquaredData <= 0.0) {
+ totalWeightedSquaredData = 0.0;
+ }
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public synchronized void changeDatum(double delta, double weight) {
+ throw new UnsupportedOperationException();
+ }
+
+
+ @Override
+ public synchronized double getStandardDeviation() {
+ double totalWeight = getTotalWeight();
+ return Math.sqrt((totalWeightedSquaredData * totalWeight - totalWeightedData * totalWeightedData)
+ / (totalWeight * totalWeight - totalSquaredWeight));
+ }
+
+ @Override
+ public RunningAverageAndStdDev inverse() {
+ return new InvertedRunningAverageAndStdDev(this);
+ }
+
+ @Override
+ public synchronized String toString() {
+ return String.valueOf(String.valueOf(getAverage()) + ',' + getStandardDeviation());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java
new file mode 100644
index 0000000..d1e93ab
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/AbstractJDBCComponent.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common.jdbc;
+
+import javax.naming.Context;
+import javax.naming.InitialContext;
+import javax.naming.NamingException;
+import javax.sql.DataSource;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * A helper class with common elements for several JDBC-related components.
+ */
+public abstract class AbstractJDBCComponent {
+
+ private static final Logger log = LoggerFactory.getLogger(AbstractJDBCComponent.class);
+
+ private static final int DEFAULT_FETCH_SIZE = 1000; // A max, "big" number of rows to buffer at once
+ protected static final String DEFAULT_DATASOURCE_NAME = "jdbc/taste";
+
+ protected static void checkNotNullAndLog(String argName, Object value) {
+ Preconditions.checkArgument(value != null && !value.toString().isEmpty(),
+ argName + " is null or empty");
+ log.debug("{}: {}", argName, value);
+ }
+
+ protected static void checkNotNullAndLog(String argName, Object[] values) {
+ Preconditions.checkArgument(values != null && values.length != 0, argName + " is null or zero-length");
+ for (Object value : values) {
+ checkNotNullAndLog(argName, value);
+ }
+ }
+
+ /**
+ * <p>
+ * Looks up a {@link DataSource} by name from JNDI. "java:comp/env/" is prepended to the argument before
+ * looking up the name in JNDI.
+ * </p>
+ *
+ * @param dataSourceName
+ * JNDI name where a {@link DataSource} is bound (e.g. "jdbc/taste")
+ * @return {@link DataSource} under that JNDI name
+ * @throws TasteException
+ * if a JNDI error occurs
+ */
+ public static DataSource lookupDataSource(String dataSourceName) throws TasteException {
+ Context context = null;
+ try {
+ context = new InitialContext();
+ return (DataSource) context.lookup("java:comp/env/" + dataSourceName);
+ } catch (NamingException ne) {
+ throw new TasteException(ne);
+ } finally {
+ if (context != null) {
+ try {
+ context.close();
+ } catch (NamingException ne) {
+ log.warn("Error while closing Context; continuing...", ne);
+ }
+ }
+ }
+ }
+
+ protected int getFetchSize() {
+ return DEFAULT_FETCH_SIZE;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java
new file mode 100644
index 0000000..3f024bc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/EachRowIterator.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common.jdbc;
+
+import javax.sql.DataSource;
+import java.io.Closeable;
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+
+import com.google.common.collect.AbstractIterator;
+import org.apache.mahout.common.IOUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Provides an {@link java.util.Iterator} over the result of an SQL query, as an iteration over the {@link ResultSet}.
+ * While the same object will be returned from the iteration each time, it will be returned once for each row
+ * of the result.
+ */
+final class EachRowIterator extends AbstractIterator<ResultSet> implements Closeable {
+
+ private static final Logger log = LoggerFactory.getLogger(EachRowIterator.class);
+
+ private final Connection connection;
+ private final PreparedStatement statement;
+ private final ResultSet resultSet;
+
+ EachRowIterator(DataSource dataSource, String sqlQuery) throws SQLException {
+ try {
+ connection = dataSource.getConnection();
+ statement = connection.prepareStatement(sqlQuery, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+ statement.setFetchDirection(ResultSet.FETCH_FORWARD);
+ //statement.setFetchSize(getFetchSize());
+ log.debug("Executing SQL query: {}", sqlQuery);
+ resultSet = statement.executeQuery();
+ } catch (SQLException sqle) {
+ close();
+ throw sqle;
+ }
+ }
+
+ @Override
+ protected ResultSet computeNext() {
+ try {
+ if (resultSet.next()) {
+ return resultSet;
+ } else {
+ close();
+ return null;
+ }
+ } catch (SQLException sqle) {
+ close();
+ throw new IllegalStateException(sqle);
+ }
+ }
+
+ public void skip(int n) throws SQLException {
+ try {
+ resultSet.relative(n);
+ } catch (SQLException sqle) {
+ // Can't use relative on MySQL Connector/J; try advancing manually
+ int i = 0;
+ while (i < n && resultSet.next()) {
+ i++;
+ }
+ }
+ }
+
+ @Override
+ public void close() {
+ IOUtils.quietClose(resultSet, statement, connection);
+ endOfData();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java
new file mode 100644
index 0000000..273ebd5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/jdbc/ResultSetIterator.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common.jdbc;
+
+import javax.sql.DataSource;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Iterator;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+
+public abstract class ResultSetIterator<T> extends ForwardingIterator<T> {
+
+ private final Iterator<T> delegate;
+ private final EachRowIterator rowDelegate;
+
+ protected ResultSetIterator(DataSource dataSource, String sqlQuery) throws SQLException {
+ this.rowDelegate = new EachRowIterator(dataSource, sqlQuery);
+ delegate = Iterators.transform(rowDelegate,
+ new Function<ResultSet, T>() {
+ @Override
+ public T apply(ResultSet from) {
+ try {
+ return parseElement(from);
+ } catch (SQLException sqle) {
+ throw new IllegalStateException(sqle);
+ }
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<T> delegate() {
+ return delegate;
+ }
+
+ protected abstract T parseElement(ResultSet resultSet) throws SQLException;
+
+ public void skip(int n) {
+ if (n >= 1) {
+ try {
+ rowDelegate.skip(n);
+ } catch (SQLException sqle) {
+ throw new IllegalStateException(sqle);
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java
new file mode 100644
index 0000000..f6598f3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AbstractDifferenceRecommenderEvaluator.java
@@ -0,0 +1,277 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.DataModelBuilder;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.eval.RecommenderEvaluator;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Abstract superclass of a couple implementations, providing shared functionality.
+ */
+public abstract class AbstractDifferenceRecommenderEvaluator implements RecommenderEvaluator {
+
+ private static final Logger log = LoggerFactory.getLogger(AbstractDifferenceRecommenderEvaluator.class);
+
+ private final Random random;
+ private float maxPreference;
+ private float minPreference;
+
+ protected AbstractDifferenceRecommenderEvaluator() {
+ random = RandomUtils.getRandom();
+ maxPreference = Float.NaN;
+ minPreference = Float.NaN;
+ }
+
+ @Override
+ public final float getMaxPreference() {
+ return maxPreference;
+ }
+
+ @Override
+ public final void setMaxPreference(float maxPreference) {
+ this.maxPreference = maxPreference;
+ }
+
+ @Override
+ public final float getMinPreference() {
+ return minPreference;
+ }
+
+ @Override
+ public final void setMinPreference(float minPreference) {
+ this.minPreference = minPreference;
+ }
+
+ @Override
+ public double evaluate(RecommenderBuilder recommenderBuilder,
+ DataModelBuilder dataModelBuilder,
+ DataModel dataModel,
+ double trainingPercentage,
+ double evaluationPercentage) throws TasteException {
+ Preconditions.checkNotNull(recommenderBuilder);
+ Preconditions.checkNotNull(dataModel);
+ Preconditions.checkArgument(trainingPercentage >= 0.0 && trainingPercentage <= 1.0,
+ "Invalid trainingPercentage: " + trainingPercentage + ". Must be: 0.0 <= trainingPercentage <= 1.0");
+ Preconditions.checkArgument(evaluationPercentage >= 0.0 && evaluationPercentage <= 1.0,
+ "Invalid evaluationPercentage: " + evaluationPercentage + ". Must be: 0.0 <= evaluationPercentage <= 1.0");
+
+ log.info("Beginning evaluation using {} of {}", trainingPercentage, dataModel);
+
+ int numUsers = dataModel.getNumUsers();
+ FastByIDMap<PreferenceArray> trainingPrefs = new FastByIDMap<>(
+ 1 + (int) (evaluationPercentage * numUsers));
+ FastByIDMap<PreferenceArray> testPrefs = new FastByIDMap<>(
+ 1 + (int) (evaluationPercentage * numUsers));
+
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ long userID = it.nextLong();
+ if (random.nextDouble() < evaluationPercentage) {
+ splitOneUsersPrefs(trainingPercentage, trainingPrefs, testPrefs, userID, dataModel);
+ }
+ }
+
+ DataModel trainingModel = dataModelBuilder == null ? new GenericDataModel(trainingPrefs)
+ : dataModelBuilder.buildDataModel(trainingPrefs);
+
+ Recommender recommender = recommenderBuilder.buildRecommender(trainingModel);
+
+ double result = getEvaluation(testPrefs, recommender);
+ log.info("Evaluation result: {}", result);
+ return result;
+ }
+
+ private void splitOneUsersPrefs(double trainingPercentage,
+ FastByIDMap<PreferenceArray> trainingPrefs,
+ FastByIDMap<PreferenceArray> testPrefs,
+ long userID,
+ DataModel dataModel) throws TasteException {
+ List<Preference> oneUserTrainingPrefs = null;
+ List<Preference> oneUserTestPrefs = null;
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+ int size = prefs.length();
+ for (int i = 0; i < size; i++) {
+ Preference newPref = new GenericPreference(userID, prefs.getItemID(i), prefs.getValue(i));
+ if (random.nextDouble() < trainingPercentage) {
+ if (oneUserTrainingPrefs == null) {
+ oneUserTrainingPrefs = Lists.newArrayListWithCapacity(3);
+ }
+ oneUserTrainingPrefs.add(newPref);
+ } else {
+ if (oneUserTestPrefs == null) {
+ oneUserTestPrefs = Lists.newArrayListWithCapacity(3);
+ }
+ oneUserTestPrefs.add(newPref);
+ }
+ }
+ if (oneUserTrainingPrefs != null) {
+ trainingPrefs.put(userID, new GenericUserPreferenceArray(oneUserTrainingPrefs));
+ if (oneUserTestPrefs != null) {
+ testPrefs.put(userID, new GenericUserPreferenceArray(oneUserTestPrefs));
+ }
+ }
+ }
+
+ private float capEstimatedPreference(float estimate) {
+ if (estimate > maxPreference) {
+ return maxPreference;
+ }
+ if (estimate < minPreference) {
+ return minPreference;
+ }
+ return estimate;
+ }
+
+ private double getEvaluation(FastByIDMap<PreferenceArray> testPrefs, Recommender recommender)
+ throws TasteException {
+ reset();
+ Collection<Callable<Void>> estimateCallables = Lists.newArrayList();
+ AtomicInteger noEstimateCounter = new AtomicInteger();
+ for (Map.Entry<Long,PreferenceArray> entry : testPrefs.entrySet()) {
+ estimateCallables.add(
+ new PreferenceEstimateCallable(recommender, entry.getKey(), entry.getValue(), noEstimateCounter));
+ }
+ log.info("Beginning evaluation of {} users", estimateCallables.size());
+ RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
+ execute(estimateCallables, noEstimateCounter, timing);
+ return computeFinalEvaluation();
+ }
+
+ protected static void execute(Collection<Callable<Void>> callables,
+ AtomicInteger noEstimateCounter,
+ RunningAverageAndStdDev timing) throws TasteException {
+
+ Collection<Callable<Void>> wrappedCallables = wrapWithStatsCallables(callables, noEstimateCounter, timing);
+ int numProcessors = Runtime.getRuntime().availableProcessors();
+ ExecutorService executor = Executors.newFixedThreadPool(numProcessors);
+ log.info("Starting timing of {} tasks in {} threads", wrappedCallables.size(), numProcessors);
+ try {
+ List<Future<Void>> futures = executor.invokeAll(wrappedCallables);
+ // Go look for exceptions here, really
+ for (Future<Void> future : futures) {
+ future.get();
+ }
+
+ } catch (InterruptedException ie) {
+ throw new TasteException(ie);
+ } catch (ExecutionException ee) {
+ throw new TasteException(ee.getCause());
+ }
+
+ executor.shutdown();
+ try {
+ executor.awaitTermination(10, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new TasteException(e.getCause());
+ }
+ }
+
+ private static Collection<Callable<Void>> wrapWithStatsCallables(Iterable<Callable<Void>> callables,
+ AtomicInteger noEstimateCounter,
+ RunningAverageAndStdDev timing) {
+ Collection<Callable<Void>> wrapped = Lists.newArrayList();
+ int count = 0;
+ for (Callable<Void> callable : callables) {
+ boolean logStats = count++ % 1000 == 0; // log every 1000 or so iterations
+ wrapped.add(new StatsCallable(callable, logStats, timing, noEstimateCounter));
+ }
+ return wrapped;
+ }
+
+ protected abstract void reset();
+
+ protected abstract void processOneEstimate(float estimatedPreference, Preference realPref);
+
+ protected abstract double computeFinalEvaluation();
+
+ public final class PreferenceEstimateCallable implements Callable<Void> {
+
+ private final Recommender recommender;
+ private final long testUserID;
+ private final PreferenceArray prefs;
+ private final AtomicInteger noEstimateCounter;
+
+ public PreferenceEstimateCallable(Recommender recommender,
+ long testUserID,
+ PreferenceArray prefs,
+ AtomicInteger noEstimateCounter) {
+ this.recommender = recommender;
+ this.testUserID = testUserID;
+ this.prefs = prefs;
+ this.noEstimateCounter = noEstimateCounter;
+ }
+
+ @Override
+ public Void call() throws TasteException {
+ for (Preference realPref : prefs) {
+ float estimatedPreference = Float.NaN;
+ try {
+ estimatedPreference = recommender.estimatePreference(testUserID, realPref.getItemID());
+ } catch (NoSuchUserException nsue) {
+ // It's possible that an item exists in the test data but not training data in which case
+ // NSEE will be thrown. Just ignore it and move on.
+ log.info("User exists in test data but not training data: {}", testUserID);
+ } catch (NoSuchItemException nsie) {
+ log.info("Item exists in test data but not training data: {}", realPref.getItemID());
+ }
+ if (Float.isNaN(estimatedPreference)) {
+ noEstimateCounter.incrementAndGet();
+ } else {
+ estimatedPreference = capEstimatedPreference(estimatedPreference);
+ processOneEstimate(estimatedPreference, realPref);
+ }
+ }
+ return null;
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java
new file mode 100644
index 0000000..4dad040
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/AverageAbsoluteDifferenceRecommenderEvaluator.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.Preference;
+
+/**
+ * <p>
+ * A {@link org.apache.mahout.cf.taste.eval.RecommenderEvaluator} which computes the average absolute
+ * difference between predicted and actual ratings for users.
+ * </p>
+ *
+ * <p>
+ * This algorithm is also called "mean average error".
+ * </p>
+ */
+public final class AverageAbsoluteDifferenceRecommenderEvaluator extends
+ AbstractDifferenceRecommenderEvaluator {
+
+ private RunningAverage average;
+
+ @Override
+ protected void reset() {
+ average = new FullRunningAverage();
+ }
+
+ @Override
+ protected void processOneEstimate(float estimatedPreference, Preference realPref) {
+ average.addDatum(Math.abs(realPref.getValue() - estimatedPreference));
+ }
+
+ @Override
+ protected double computeFinalEvaluation() {
+ return average.getAverage();
+ }
+
+ @Override
+ public String toString() {
+ return "AverageAbsoluteDifferenceRecommenderEvaluator";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java
new file mode 100644
index 0000000..0e121d1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluator.java
@@ -0,0 +1,237 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.DataModelBuilder;
+import org.apache.mahout.cf.taste.eval.IRStatistics;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator;
+import org.apache.mahout.cf.taste.eval.RelevantItemsDataSplitter;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.common.RandomUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * For each user, these implementation determine the top {@code n} preferences, then evaluate the IR
+ * statistics based on a {@link DataModel} that does not have these values. This number {@code n} is the
+ * "at" value, as in "precision at 5". For example, this would mean precision evaluated by removing the top 5
+ * preferences for a user and then finding the percentage of those 5 items included in the top 5
+ * recommendations for that user.
+ * </p>
+ */
+public final class GenericRecommenderIRStatsEvaluator implements RecommenderIRStatsEvaluator {
+
+ private static final Logger log = LoggerFactory.getLogger(GenericRecommenderIRStatsEvaluator.class);
+
+ private static final double LOG2 = Math.log(2.0);
+
+ /**
+ * Pass as "relevanceThreshold" argument to
+ * {@link #evaluate(RecommenderBuilder, DataModelBuilder, DataModel, IDRescorer, int, double, double)} to
+ * have it attempt to compute a reasonable threshold. Note that this will impact performance.
+ */
+ public static final double CHOOSE_THRESHOLD = Double.NaN;
+
+ private final Random random;
+ private final RelevantItemsDataSplitter dataSplitter;
+
+ public GenericRecommenderIRStatsEvaluator() {
+ this(new GenericRelevantItemsDataSplitter());
+ }
+
+ public GenericRecommenderIRStatsEvaluator(RelevantItemsDataSplitter dataSplitter) {
+ Preconditions.checkNotNull(dataSplitter);
+ random = RandomUtils.getRandom();
+ this.dataSplitter = dataSplitter;
+ }
+
+ @Override
+ public IRStatistics evaluate(RecommenderBuilder recommenderBuilder,
+ DataModelBuilder dataModelBuilder,
+ DataModel dataModel,
+ IDRescorer rescorer,
+ int at,
+ double relevanceThreshold,
+ double evaluationPercentage) throws TasteException {
+
+ Preconditions.checkArgument(recommenderBuilder != null, "recommenderBuilder is null");
+ Preconditions.checkArgument(dataModel != null, "dataModel is null");
+ Preconditions.checkArgument(at >= 1, "at must be at least 1");
+ Preconditions.checkArgument(evaluationPercentage > 0.0 && evaluationPercentage <= 1.0,
+ "Invalid evaluationPercentage: " + evaluationPercentage + ". Must be: 0.0 < evaluationPercentage <= 1.0");
+
+ int numItems = dataModel.getNumItems();
+ RunningAverage precision = new FullRunningAverage();
+ RunningAverage recall = new FullRunningAverage();
+ RunningAverage fallOut = new FullRunningAverage();
+ RunningAverage nDCG = new FullRunningAverage();
+ int numUsersRecommendedFor = 0;
+ int numUsersWithRecommendations = 0;
+
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+
+ long userID = it.nextLong();
+
+ if (random.nextDouble() >= evaluationPercentage) {
+ // Skipped
+ continue;
+ }
+
+ long start = System.currentTimeMillis();
+
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+
+ // List some most-preferred items that would count as (most) "relevant" results
+ double theRelevanceThreshold = Double.isNaN(relevanceThreshold) ? computeThreshold(prefs) : relevanceThreshold;
+ FastIDSet relevantItemIDs = dataSplitter.getRelevantItemsIDs(userID, at, theRelevanceThreshold, dataModel);
+
+ int numRelevantItems = relevantItemIDs.size();
+ if (numRelevantItems <= 0) {
+ continue;
+ }
+
+ FastByIDMap<PreferenceArray> trainingUsers = new FastByIDMap<>(dataModel.getNumUsers());
+ LongPrimitiveIterator it2 = dataModel.getUserIDs();
+ while (it2.hasNext()) {
+ dataSplitter.processOtherUser(userID, relevantItemIDs, trainingUsers, it2.nextLong(), dataModel);
+ }
+
+ DataModel trainingModel = dataModelBuilder == null ? new GenericDataModel(trainingUsers)
+ : dataModelBuilder.buildDataModel(trainingUsers);
+ try {
+ trainingModel.getPreferencesFromUser(userID);
+ } catch (NoSuchUserException nsee) {
+ continue; // Oops we excluded all prefs for the user -- just move on
+ }
+
+ int size = numRelevantItems + trainingModel.getItemIDsFromUser(userID).size();
+ if (size < 2 * at) {
+ // Really not enough prefs to meaningfully evaluate this user
+ continue;
+ }
+
+ Recommender recommender = recommenderBuilder.buildRecommender(trainingModel);
+
+ int intersectionSize = 0;
+ List<RecommendedItem> recommendedItems = recommender.recommend(userID, at, rescorer);
+ for (RecommendedItem recommendedItem : recommendedItems) {
+ if (relevantItemIDs.contains(recommendedItem.getItemID())) {
+ intersectionSize++;
+ }
+ }
+
+ int numRecommendedItems = recommendedItems.size();
+
+ // Precision
+ if (numRecommendedItems > 0) {
+ precision.addDatum((double) intersectionSize / (double) numRecommendedItems);
+ }
+
+ // Recall
+ recall.addDatum((double) intersectionSize / (double) numRelevantItems);
+
+ // Fall-out
+ if (numRelevantItems < size) {
+ fallOut.addDatum((double) (numRecommendedItems - intersectionSize)
+ / (double) (numItems - numRelevantItems));
+ }
+
+ // nDCG
+ // In computing, assume relevant IDs have relevance 1 and others 0
+ double cumulativeGain = 0.0;
+ double idealizedGain = 0.0;
+ for (int i = 0; i < numRecommendedItems; i++) {
+ RecommendedItem item = recommendedItems.get(i);
+ double discount = 1.0 / log2(i + 2.0); // Classical formulation says log(i+1), but i is 0-based here
+ if (relevantItemIDs.contains(item.getItemID())) {
+ cumulativeGain += discount;
+ }
+ // otherwise we're multiplying discount by relevance 0 so it doesn't do anything
+
+ // Ideally results would be ordered with all relevant ones first, so this theoretical
+ // ideal list starts with number of relevant items equal to the total number of relevant items
+ if (i < numRelevantItems) {
+ idealizedGain += discount;
+ }
+ }
+ if (idealizedGain > 0.0) {
+ nDCG.addDatum(cumulativeGain / idealizedGain);
+ }
+
+ // Reach
+ numUsersRecommendedFor++;
+ if (numRecommendedItems > 0) {
+ numUsersWithRecommendations++;
+ }
+
+ long end = System.currentTimeMillis();
+
+ log.info("Evaluated with user {} in {}ms", userID, end - start);
+ log.info("Precision/recall/fall-out/nDCG/reach: {} / {} / {} / {} / {}",
+ precision.getAverage(), recall.getAverage(), fallOut.getAverage(), nDCG.getAverage(),
+ (double) numUsersWithRecommendations / (double) numUsersRecommendedFor);
+ }
+
+ return new IRStatisticsImpl(
+ precision.getAverage(),
+ recall.getAverage(),
+ fallOut.getAverage(),
+ nDCG.getAverage(),
+ (double) numUsersWithRecommendations / (double) numUsersRecommendedFor);
+ }
+
+ private static double computeThreshold(PreferenceArray prefs) {
+ if (prefs.length() < 2) {
+ // Not enough data points -- return a threshold that allows everything
+ return Double.NEGATIVE_INFINITY;
+ }
+ RunningAverageAndStdDev stdDev = new FullRunningAverageAndStdDev();
+ int size = prefs.length();
+ for (int i = 0; i < size; i++) {
+ stdDev.addDatum(prefs.getValue(i));
+ }
+ return stdDev.getAverage() + stdDev.getStandardDeviation();
+ }
+
+ private static double log2(double value) {
+ return Math.log(value) / LOG2;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java
new file mode 100644
index 0000000..b0ef18c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/GenericRelevantItemsDataSplitter.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.eval.RelevantItemsDataSplitter;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Picks relevant items to be those with the strongest preference, and
+ * includes the other users' preferences in full.
+ */
+public final class GenericRelevantItemsDataSplitter implements RelevantItemsDataSplitter {
+
+ @Override
+ public FastIDSet getRelevantItemsIDs(long userID,
+ int at,
+ double relevanceThreshold,
+ DataModel dataModel) throws TasteException {
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(userID);
+ FastIDSet relevantItemIDs = new FastIDSet(at);
+ prefs.sortByValueReversed();
+ for (int i = 0; i < prefs.length() && relevantItemIDs.size() < at; i++) {
+ if (prefs.getValue(i) >= relevanceThreshold) {
+ relevantItemIDs.add(prefs.getItemID(i));
+ }
+ }
+ return relevantItemIDs;
+ }
+
+ @Override
+ public void processOtherUser(long userID,
+ FastIDSet relevantItemIDs,
+ FastByIDMap<PreferenceArray> trainingUsers,
+ long otherUserID,
+ DataModel dataModel) throws TasteException {
+ PreferenceArray prefs2Array = dataModel.getPreferencesFromUser(otherUserID);
+ // If we're dealing with the very user that we're evaluating for precision/recall,
+ if (userID == otherUserID) {
+ // then must remove all the test IDs, the "relevant" item IDs
+ List<Preference> prefs2 = Lists.newArrayListWithCapacity(prefs2Array.length());
+ for (Preference pref : prefs2Array) {
+ prefs2.add(pref);
+ }
+ for (Iterator<Preference> iterator = prefs2.iterator(); iterator.hasNext();) {
+ Preference pref = iterator.next();
+ if (relevantItemIDs.contains(pref.getItemID())) {
+ iterator.remove();
+ }
+ }
+ if (!prefs2.isEmpty()) {
+ trainingUsers.put(otherUserID, new GenericUserPreferenceArray(prefs2));
+ }
+ } else {
+ // otherwise just add all those other user's prefs
+ trainingUsers.put(otherUserID, prefs2Array);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java
new file mode 100644
index 0000000..2838b08
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/IRStatisticsImpl.java
@@ -0,0 +1,95 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.io.Serializable;
+
+import org.apache.mahout.cf.taste.eval.IRStatistics;
+
+import com.google.common.base.Preconditions;
+
+public final class IRStatisticsImpl implements IRStatistics, Serializable {
+
+ private final double precision;
+ private final double recall;
+ private final double fallOut;
+ private final double ndcg;
+ private final double reach;
+
+ IRStatisticsImpl(double precision, double recall, double fallOut, double ndcg, double reach) {
+ Preconditions.checkArgument(Double.isNaN(precision) || (precision >= 0.0 && precision <= 1.0),
+ "Illegal precision: " + precision + ". Must be: 0.0 <= precision <= 1.0 or NaN");
+ Preconditions.checkArgument(Double.isNaN(recall) || (recall >= 0.0 && recall <= 1.0),
+ "Illegal recall: " + recall + ". Must be: 0.0 <= recall <= 1.0 or NaN");
+ Preconditions.checkArgument(Double.isNaN(fallOut) || (fallOut >= 0.0 && fallOut <= 1.0),
+ "Illegal fallOut: " + fallOut + ". Must be: 0.0 <= fallOut <= 1.0 or NaN");
+ Preconditions.checkArgument(Double.isNaN(ndcg) || (ndcg >= 0.0 && ndcg <= 1.0),
+ "Illegal nDCG: " + ndcg + ". Must be: 0.0 <= nDCG <= 1.0 or NaN");
+ Preconditions.checkArgument(Double.isNaN(reach) || (reach >= 0.0 && reach <= 1.0),
+ "Illegal reach: " + reach + ". Must be: 0.0 <= reach <= 1.0 or NaN");
+ this.precision = precision;
+ this.recall = recall;
+ this.fallOut = fallOut;
+ this.ndcg = ndcg;
+ this.reach = reach;
+ }
+
+ @Override
+ public double getPrecision() {
+ return precision;
+ }
+
+ @Override
+ public double getRecall() {
+ return recall;
+ }
+
+ @Override
+ public double getFallOut() {
+ return fallOut;
+ }
+
+ @Override
+ public double getF1Measure() {
+ return getFNMeasure(1.0);
+ }
+
+ @Override
+ public double getFNMeasure(double b) {
+ double b2 = b * b;
+ double sum = b2 * precision + recall;
+ return sum == 0.0 ? Double.NaN : (1.0 + b2) * precision * recall / sum;
+ }
+
+ @Override
+ public double getNormalizedDiscountedCumulativeGain() {
+ return ndcg;
+ }
+
+ @Override
+ public double getReach() {
+ return reach;
+ }
+
+ @Override
+ public String toString() {
+ return "IRStatisticsImpl[precision:" + precision + ",recall:" + recall + ",fallOut:"
+ + fallOut + ",nDCG:" + ndcg + ",reach:" + reach + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java
new file mode 100644
index 0000000..213f7f9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadCallable.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+import java.util.concurrent.Callable;
+
+final class LoadCallable implements Callable<Void> {
+
+ private final Recommender recommender;
+ private final long userID;
+
+ LoadCallable(Recommender recommender, long userID) {
+ this.recommender = recommender;
+ this.userID = userID;
+ }
+
+ @Override
+ public Void call() throws Exception {
+ recommender.recommend(userID, 10);
+ return null;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java
new file mode 100644
index 0000000..abb5ed8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluator.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+/**
+ * Simple helper class for running load on a Recommender.
+ */
+public final class LoadEvaluator {
+
+ private LoadEvaluator() { }
+
+ public static LoadStatistics runLoad(Recommender recommender) throws TasteException {
+ return runLoad(recommender, 10);
+ }
+
+ public static LoadStatistics runLoad(Recommender recommender, int howMany) throws TasteException {
+ DataModel dataModel = recommender.getDataModel();
+ int numUsers = dataModel.getNumUsers();
+ double sampleRate = 1000.0 / numUsers;
+ LongPrimitiveIterator userSampler =
+ SamplingLongPrimitiveIterator.maybeWrapIterator(dataModel.getUserIDs(), sampleRate);
+ recommender.recommend(userSampler.next(), howMany); // Warm up
+ Collection<Callable<Void>> callables = Lists.newArrayList();
+ while (userSampler.hasNext()) {
+ callables.add(new LoadCallable(recommender, userSampler.next()));
+ }
+ AtomicInteger noEstimateCounter = new AtomicInteger();
+ RunningAverageAndStdDev timing = new FullRunningAverageAndStdDev();
+ AbstractDifferenceRecommenderEvaluator.execute(callables, noEstimateCounter, timing);
+ return new LoadStatistics(timing);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java
new file mode 100644
index 0000000..f89160c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/LoadStatistics.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+
+public final class LoadStatistics {
+
+ private final RunningAverage timing;
+
+ LoadStatistics(RunningAverage timing) {
+ this.timing = timing;
+ }
+
+ public RunningAverage getTiming() {
+ return timing;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java
new file mode 100644
index 0000000..e267a39
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/OrderBasedRecommenderEvaluator.java
@@ -0,0 +1,431 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Evaluate recommender by comparing order of all raw prefs with order in
+ * recommender's output for that user. Can also compare data models.
+ */
+public final class OrderBasedRecommenderEvaluator {
+
+ private static final Logger log = LoggerFactory.getLogger(OrderBasedRecommenderEvaluator.class);
+
+ private OrderBasedRecommenderEvaluator() {
+ }
+
+ public static void evaluate(Recommender recommender1,
+ Recommender recommender2,
+ int samples,
+ RunningAverage tracker,
+ String tag) throws TasteException {
+ printHeader();
+ LongPrimitiveIterator users = recommender1.getDataModel().getUserIDs();
+
+ while (users.hasNext()) {
+ long userID = users.nextLong();
+ List<RecommendedItem> recs1 = recommender1.recommend(userID, samples);
+ List<RecommendedItem> recs2 = recommender2.recommend(userID, samples);
+ FastIDSet commonSet = new FastIDSet();
+ long maxItemID = setBits(commonSet, recs1, samples);
+ FastIDSet otherSet = new FastIDSet();
+ maxItemID = Math.max(maxItemID, setBits(otherSet, recs2, samples));
+ int max = mask(commonSet, otherSet, maxItemID);
+ max = Math.min(max, samples);
+ if (max < 2) {
+ continue;
+ }
+ Long[] items1 = getCommonItems(commonSet, recs1, max);
+ Long[] items2 = getCommonItems(commonSet, recs2, max);
+ double variance = scoreCommonSubset(tag, userID, samples, max, items1, items2);
+ tracker.addDatum(variance);
+ }
+ }
+
+ public static void evaluate(Recommender recommender,
+ DataModel model,
+ int samples,
+ RunningAverage tracker,
+ String tag) throws TasteException {
+ printHeader();
+ LongPrimitiveIterator users = recommender.getDataModel().getUserIDs();
+ while (users.hasNext()) {
+ long userID = users.nextLong();
+ List<RecommendedItem> recs1 = recommender.recommend(userID, model.getNumItems());
+ PreferenceArray prefs2 = model.getPreferencesFromUser(userID);
+ prefs2.sortByValueReversed();
+ FastIDSet commonSet = new FastIDSet();
+ long maxItemID = setBits(commonSet, recs1, samples);
+ FastIDSet otherSet = new FastIDSet();
+ maxItemID = Math.max(maxItemID, setBits(otherSet, prefs2, samples));
+ int max = mask(commonSet, otherSet, maxItemID);
+ max = Math.min(max, samples);
+ if (max < 2) {
+ continue;
+ }
+ Long[] items1 = getCommonItems(commonSet, recs1, max);
+ Long[] items2 = getCommonItems(commonSet, prefs2, max);
+ double variance = scoreCommonSubset(tag, userID, samples, max, items1, items2);
+ tracker.addDatum(variance);
+ }
+ }
+
+ public static void evaluate(DataModel model1,
+ DataModel model2,
+ int samples,
+ RunningAverage tracker,
+ String tag) throws TasteException {
+ printHeader();
+ LongPrimitiveIterator users = model1.getUserIDs();
+ while (users.hasNext()) {
+ long userID = users.nextLong();
+ PreferenceArray prefs1 = model1.getPreferencesFromUser(userID);
+ PreferenceArray prefs2 = model2.getPreferencesFromUser(userID);
+ prefs1.sortByValueReversed();
+ prefs2.sortByValueReversed();
+ FastIDSet commonSet = new FastIDSet();
+ long maxItemID = setBits(commonSet, prefs1, samples);
+ FastIDSet otherSet = new FastIDSet();
+ maxItemID = Math.max(maxItemID, setBits(otherSet, prefs2, samples));
+ int max = mask(commonSet, otherSet, maxItemID);
+ max = Math.min(max, samples);
+ if (max < 2) {
+ continue;
+ }
+ Long[] items1 = getCommonItems(commonSet, prefs1, max);
+ Long[] items2 = getCommonItems(commonSet, prefs2, max);
+ double variance = scoreCommonSubset(tag, userID, samples, max, items1, items2);
+ tracker.addDatum(variance);
+ }
+ }
+
+ /**
+ * This exists because FastIDSet has 'retainAll' as MASK, but there is
+ * no count of the number of items in the set. size() is supposed to do
+ * this but does not work.
+ */
+ private static int mask(FastIDSet commonSet, FastIDSet otherSet, long maxItemID) {
+ int count = 0;
+ for (int i = 0; i <= maxItemID; i++) {
+ if (commonSet.contains(i)) {
+ if (otherSet.contains(i)) {
+ count++;
+ } else {
+ commonSet.remove(i);
+ }
+ }
+ }
+ return count;
+ }
+
+ private static Long[] getCommonItems(FastIDSet commonSet, Iterable<RecommendedItem> recs, int max) {
+ Long[] commonItems = new Long[max];
+ int index = 0;
+ for (RecommendedItem rec : recs) {
+ Long item = rec.getItemID();
+ if (commonSet.contains(item)) {
+ commonItems[index++] = item;
+ }
+ if (index == max) {
+ break;
+ }
+ }
+ return commonItems;
+ }
+
+ private static Long[] getCommonItems(FastIDSet commonSet, PreferenceArray prefs1, int max) {
+ Long[] commonItems = new Long[max];
+ int index = 0;
+ for (int i = 0; i < prefs1.length(); i++) {
+ Long item = prefs1.getItemID(i);
+ if (commonSet.contains(item)) {
+ commonItems[index++] = item;
+ }
+ if (index == max) {
+ break;
+ }
+ }
+ return commonItems;
+ }
+
+ private static long setBits(FastIDSet modelSet, List<RecommendedItem> items, int max) {
+ long maxItem = -1;
+ for (int i = 0; i < items.size() && i < max; i++) {
+ long itemID = items.get(i).getItemID();
+ modelSet.add(itemID);
+ if (itemID > maxItem) {
+ maxItem = itemID;
+ }
+ }
+ return maxItem;
+ }
+
+ private static long setBits(FastIDSet modelSet, PreferenceArray prefs, int max) {
+ long maxItem = -1;
+ for (int i = 0; i < prefs.length() && i < max; i++) {
+ long itemID = prefs.getItemID(i);
+ modelSet.add(itemID);
+ if (itemID > maxItem) {
+ maxItem = itemID;
+ }
+ }
+ return maxItem;
+ }
+
+ private static void printHeader() {
+ log.info("tag,user,samples,common,hamming,bubble,rank,normal,score");
+ }
+
+ /**
+ * Common Subset Scoring
+ *
+ * These measurements are given the set of results that are common to both
+ * recommendation lists. They only get ordered lists.
+ *
+ * These measures all return raw numbers do not correlate among the tests.
+ * The numbers are not corrected against the total number of samples or the
+ * number of common items.
+ * The one contract is that all measures are 0 for an exact match and an
+ * increasing positive number as differences increase.
+ */
+ private static double scoreCommonSubset(String tag,
+ long userID,
+ int samples,
+ int subset,
+ Long[] itemsL,
+ Long[] itemsR) {
+ int[] vectorZ = new int[subset];
+ int[] vectorZabs = new int[subset];
+
+ long bubble = sort(itemsL, itemsR);
+ int hamming = slidingWindowHamming(itemsR, itemsL);
+ if (hamming > samples) {
+ throw new IllegalStateException();
+ }
+ getVectorZ(itemsR, itemsL, vectorZ, vectorZabs);
+ double normalW = normalWilcoxon(vectorZ, vectorZabs);
+ double meanRank = getMeanRank(vectorZabs);
+ // case statement for requested value
+ double variance = Math.sqrt(meanRank);
+ log.info("{},{},{},{},{},{},{},{},{}",
+ tag, userID, samples, subset, hamming, bubble, meanRank, normalW, variance);
+ return variance;
+ }
+
+ // simple sliding-window hamming distance: a[i or plus/minus 1] == b[i]
+ private static int slidingWindowHamming(Long[] itemsR, Long[] itemsL) {
+ int count = 0;
+ int samples = itemsR.length;
+
+ if (itemsR[0].equals(itemsL[0]) || itemsR[0].equals(itemsL[1])) {
+ count++;
+ }
+ for (int i = 1; i < samples - 1; i++) {
+ long itemID = itemsL[i];
+ if (itemsR[i] == itemID || itemsR[i - 1] == itemID || itemsR[i + 1] == itemID) {
+ count++;
+ }
+ }
+ if (itemsR[samples - 1].equals(itemsL[samples - 1]) || itemsR[samples - 1].equals(itemsL[samples - 2])) {
+ count++;
+ }
+ return count;
+ }
+
+ /**
+ * Normal-distribution probability value for matched sets of values.
+ * Based upon:
+ * http://comp9.psych.cornell.edu/Darlington/normscor.htm
+ *
+ * The Standard Wilcoxon is not used because it requires a lookup table.
+ */
+ static double normalWilcoxon(int[] vectorZ, int[] vectorZabs) {
+ int nitems = vectorZ.length;
+
+ double[] ranks = new double[nitems];
+ double[] ranksAbs = new double[nitems];
+ wilcoxonRanks(vectorZ, vectorZabs, ranks, ranksAbs);
+ return Math.min(getMeanWplus(ranks), getMeanWminus(ranks));
+ }
+
+ /**
+ * vector Z is a list of distances between the correct value and the recommended value
+ * Z[i] = position i of correct itemID - position of correct itemID in recommendation list
+ * can be positive or negative
+ * the smaller the better - means recommendations are closer
+ * both are the same length, and both sample from the same set
+ *
+ * destructive to items arrays - allows N log N instead of N^2 order
+ */
+ private static void getVectorZ(Long[] itemsR, Long[] itemsL, int[] vectorZ, int[] vectorZabs) {
+ int nitems = itemsR.length;
+ int bottom = 0;
+ int top = nitems - 1;
+ for (int i = 0; i < nitems; i++) {
+ long itemID = itemsR[i];
+ for (int j = bottom; j <= top; j++) {
+ if (itemsL[j] == null) {
+ continue;
+ }
+ long test = itemsL[j];
+ if (itemID == test) {
+ vectorZ[i] = i - j;
+ vectorZabs[i] = Math.abs(i - j);
+ if (j == bottom) {
+ bottom++;
+ } else if (j == top) {
+ top--;
+ } else {
+ itemsL[j] = null;
+ }
+ break;
+ }
+ }
+ }
+ }
+
+ /**
+ * Ranks are the position of the value from low to high, divided by the # of values.
+ * I had to walk through it a few times.
+ */
+ private static void wilcoxonRanks(int[] vectorZ, int[] vectorZabs, double[] ranks, double[] ranksAbs) {
+ int nitems = vectorZ.length;
+ int[] sorted = vectorZabs.clone();
+ Arrays.sort(sorted);
+ int zeros = 0;
+ for (; zeros < nitems; zeros++) {
+ if (sorted[zeros] > 0) {
+ break;
+ }
+ }
+ for (int i = 0; i < nitems; i++) {
+ double rank = 0.0;
+ int count = 0;
+ int score = vectorZabs[i];
+ for (int j = 0; j < nitems; j++) {
+ if (score == sorted[j]) {
+ rank += j + 1 - zeros;
+ count++;
+ } else if (score < sorted[j]) {
+ break;
+ }
+ }
+ if (vectorZ[i] != 0) {
+ ranks[i] = (rank / count) * (vectorZ[i] < 0 ? -1 : 1); // better be at least 1
+ ranksAbs[i] = Math.abs(ranks[i]);
+ }
+ }
+ }
+
+ private static double getMeanRank(int[] ranks) {
+ int nitems = ranks.length;
+ double sum = 0.0;
+ for (int rank : ranks) {
+ sum += rank;
+ }
+ return sum / nitems;
+ }
+
+ private static double getMeanWplus(double[] ranks) {
+ int nitems = ranks.length;
+ double sum = 0.0;
+ for (double rank : ranks) {
+ if (rank > 0) {
+ sum += rank;
+ }
+ }
+ return sum / nitems;
+ }
+
+ private static double getMeanWminus(double[] ranks) {
+ int nitems = ranks.length;
+ double sum = 0.0;
+ for (double rank : ranks) {
+ if (rank < 0) {
+ sum -= rank;
+ }
+ }
+ return sum / nitems;
+ }
+
+ /**
+ * Do bubble sort and return number of swaps needed to match preference lists.
+ * Sort itemsR using itemsL as the reference order.
+ */
+ static long sort(Long[] itemsL, Long[] itemsR) {
+ int length = itemsL.length;
+ if (length < 2) {
+ return 0;
+ }
+ if (length == 2) {
+ return itemsL[0].longValue() == itemsR[0].longValue() ? 0 : 1;
+ }
+ // 1) avoid changing originals; 2) primitive type is more efficient
+ long[] reference = new long[length];
+ long[] sortable = new long[length];
+ for (int i = 0; i < length; i++) {
+ reference[i] = itemsL[i];
+ sortable[i] = itemsR[i];
+ }
+ int sorted = 0;
+ long swaps = 0;
+ while (sorted < length - 1) {
+ // opportunistically trim back the top
+ while (length > 0 && reference[length - 1] == sortable[length - 1]) {
+ length--;
+ }
+ if (length == 0) {
+ break;
+ }
+ if (reference[sorted] == sortable[sorted]) {
+ sorted++;
+ } else {
+ for (int j = sorted; j < length - 1; j++) {
+ // do not swap anything already in place
+ int jump = 1;
+ if (reference[j] == sortable[j]) {
+ while (j + jump < length && reference[j + jump] == sortable[j + jump]) {
+ jump++;
+ }
+ }
+ if (j + jump < length && !(reference[j] == sortable[j] && reference[j + jump] == sortable[j + jump])) {
+ long tmp = sortable[j];
+ sortable[j] = sortable[j + 1];
+ sortable[j + 1] = tmp;
+ swaps++;
+ }
+ }
+ }
+ }
+ return swaps;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java
new file mode 100644
index 0000000..97eda10
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/RMSRecommenderEvaluator.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.Preference;
+
+/**
+ * <p>
+ * A {@link org.apache.mahout.cf.taste.eval.RecommenderEvaluator} which computes the "root mean squared"
+ * difference between predicted and actual ratings for users. This is the square root of the average of this
+ * difference, squared.
+ * </p>
+ */
+public final class RMSRecommenderEvaluator extends AbstractDifferenceRecommenderEvaluator {
+
+ private RunningAverage average;
+
+ @Override
+ protected void reset() {
+ average = new FullRunningAverage();
+ }
+
+ @Override
+ protected void processOneEstimate(float estimatedPreference, Preference realPref) {
+ double diff = realPref.getValue() - estimatedPreference;
+ average.addDatum(diff * diff);
+ }
+
+ @Override
+ protected double computeFinalEvaluation() {
+ return Math.sqrt(average.getAverage());
+ }
+
+ @Override
+ public String toString() {
+ return "RMSRecommenderEvaluator";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java
new file mode 100644
index 0000000..036d0b4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/eval/StatsCallable.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.Callable;
+import java.util.concurrent.atomic.AtomicInteger;
+
+final class StatsCallable implements Callable<Void> {
+
+ private static final Logger log = LoggerFactory.getLogger(StatsCallable.class);
+
+ private final Callable<Void> delegate;
+ private final boolean logStats;
+ private final RunningAverageAndStdDev timing;
+ private final AtomicInteger noEstimateCounter;
+
+ StatsCallable(Callable<Void> delegate,
+ boolean logStats,
+ RunningAverageAndStdDev timing,
+ AtomicInteger noEstimateCounter) {
+ this.delegate = delegate;
+ this.logStats = logStats;
+ this.timing = timing;
+ this.noEstimateCounter = noEstimateCounter;
+ }
+
+ @Override
+ public Void call() throws Exception {
+ long start = System.currentTimeMillis();
+ delegate.call();
+ long end = System.currentTimeMillis();
+ timing.addDatum(end - start);
+ if (logStats) {
+ Runtime runtime = Runtime.getRuntime();
+ int average = (int) timing.getAverage();
+ log.info("Average time per recommendation: {}ms", average);
+ long totalMemory = runtime.totalMemory();
+ long memory = totalMemory - runtime.freeMemory();
+ log.info("Approximate memory used: {}MB / {}MB", memory / 1000000L, totalMemory / 1000000L);
+ log.info("Unable to recommend in {} cases", noEstimateCounter.get());
+ }
+ return null;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java
new file mode 100644
index 0000000..a1a2a1f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractDataModel.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * Contains some features common to all implementations.
+ */
+public abstract class AbstractDataModel implements DataModel {
+
+ private float maxPreference;
+ private float minPreference;
+
+ protected AbstractDataModel() {
+ maxPreference = Float.NaN;
+ minPreference = Float.NaN;
+ }
+
+ @Override
+ public float getMaxPreference() {
+ return maxPreference;
+ }
+
+ protected void setMaxPreference(float maxPreference) {
+ this.maxPreference = maxPreference;
+ }
+
+ @Override
+ public float getMinPreference() {
+ return minPreference;
+ }
+
+ protected void setMinPreference(float minPreference) {
+ this.minPreference = minPreference;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java
new file mode 100644
index 0000000..94f2d0b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractIDMigrator.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+
+import java.util.Collection;
+
+import com.google.common.base.Charsets;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.model.IDMigrator;
+
+public abstract class AbstractIDMigrator implements IDMigrator {
+
+ private final MessageDigest md5Digest;
+
+ protected AbstractIDMigrator() {
+ try {
+ md5Digest = MessageDigest.getInstance("MD5");
+ } catch (NoSuchAlgorithmException nsae) {
+ // Can't happen
+ throw new IllegalStateException(nsae);
+ }
+ }
+
+ /**
+ * @return most significant 8 bytes of the MD5 hash of the string, as a long
+ */
+ protected final long hash(String value) {
+ byte[] md5hash;
+ synchronized (md5Digest) {
+ md5hash = md5Digest.digest(value.getBytes(Charsets.UTF_8));
+ md5Digest.reset();
+ }
+ long hash = 0L;
+ for (int i = 0; i < 8; i++) {
+ hash = hash << 8 | md5hash[i] & 0x00000000000000FFL;
+ }
+ return hash;
+ }
+
+ @Override
+ public long toLongID(String stringID) {
+ return hash(stringID);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractJDBCIDMigrator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractJDBCIDMigrator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractJDBCIDMigrator.java
new file mode 100644
index 0000000..cd3a434
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/AbstractJDBCIDMigrator.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+
+import javax.sql.DataSource;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.UpdatableIDMigrator;
+import org.apache.mahout.common.IOUtils;
+
+/**
+ * Implementation which stores the reverse long-to-String mapping in a database. Subclasses can override and
+ * configure the class to operate with particular databases by supplying appropriate SQL statements to the
+ * constructor.
+ */
+public abstract class AbstractJDBCIDMigrator extends AbstractIDMigrator implements UpdatableIDMigrator {
+
+ public static final String DEFAULT_MAPPING_TABLE = "taste_id_mapping";
+ public static final String DEFAULT_LONG_ID_COLUMN = "long_id";
+ public static final String DEFAULT_STRING_ID_COLUMN = "string_id";
+
+ private final DataSource dataSource;
+ private final String getStringIDSQL;
+ private final String storeMappingSQL;
+
+ /**
+ * @param getStringIDSQL
+ * SQL statement which selects one column, the String ID, from a mapping table. The statement
+ * should take one long parameter.
+ * @param storeMappingSQL
+ * SQL statement which saves a mapping from long to String. It should take two parameters, a long
+ * and a String.
+ */
+ protected AbstractJDBCIDMigrator(DataSource dataSource, String getStringIDSQL, String storeMappingSQL) {
+ this.dataSource = dataSource;
+ this.getStringIDSQL = getStringIDSQL;
+ this.storeMappingSQL = storeMappingSQL;
+ }
+
+ @Override
+ public final void storeMapping(long longID, String stringID) throws TasteException {
+ Connection conn = null;
+ PreparedStatement stmt = null;
+ try {
+ conn = dataSource.getConnection();
+ stmt = conn.prepareStatement(storeMappingSQL);
+ stmt.setLong(1, longID);
+ stmt.setString(2, stringID);
+ stmt.executeUpdate();
+ } catch (SQLException sqle) {
+ throw new TasteException(sqle);
+ } finally {
+ IOUtils.quietClose(null, stmt, conn);
+ }
+ }
+
+ @Override
+ public final String toStringID(long longID) throws TasteException {
+ Connection conn = null;
+ PreparedStatement stmt = null;
+ ResultSet rs = null;
+ try {
+ conn = dataSource.getConnection();
+ stmt = conn.prepareStatement(getStringIDSQL, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+ stmt.setFetchDirection(ResultSet.FETCH_FORWARD);
+ stmt.setFetchSize(1);
+ stmt.setLong(1, longID);
+ rs = stmt.executeQuery();
+ if (rs.next()) {
+ return rs.getString(1);
+ } else {
+ return null;
+ }
+ } catch (SQLException sqle) {
+ throw new TasteException(sqle);
+ } finally {
+ IOUtils.quietClose(rs, stmt, conn);
+ }
+ }
+
+ @Override
+ public void initialize(Iterable<String> stringIDs) throws TasteException {
+ for (String stringID : stringIDs) {
+ storeMapping(toLongID(stringID), stringID);
+ }
+ }
+
+}
[17/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
new file mode 100644
index 0000000..378a885
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/YtYJob.java
@@ -0,0 +1,220 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import org.apache.commons.lang3.Validate;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.UpperTriangular;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+/**
+ * Job that accumulates Y'Y output
+ */
+public final class YtYJob {
+
+ public static final String PROP_OMEGA_SEED = "ssvd.omegaseed";
+ public static final String PROP_K = "ssvd.k";
+ public static final String PROP_P = "ssvd.p";
+
+ // we have single output, so we use standard output
+ public static final String OUTPUT_YT_Y = "part-";
+
+ private YtYJob() {
+ }
+
+ public static class YtYMapper extends
+ Mapper<Writable, VectorWritable, IntWritable, VectorWritable> {
+
+ private int kp;
+ private Omega omega;
+ private UpperTriangular mYtY;
+
+ /*
+ * we keep yRow in a dense form here but keep an eye not to dense up while
+ * doing YtY products. I am not sure that sparse vector would create much
+ * performance benefits since we must to assume that y would be more often
+ * dense than sparse, so for bulk dense operations that would perform
+ * somewhat better than a RandomAccessSparse vector frequent updates.
+ */
+ private Vector yRow;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ int k = context.getConfiguration().getInt(PROP_K, -1);
+ int p = context.getConfiguration().getInt(PROP_P, -1);
+
+ Validate.isTrue(k > 0, "invalid k parameter");
+ Validate.isTrue(p > 0, "invalid p parameter");
+
+ kp = k + p;
+ long omegaSeed =
+ Long.parseLong(context.getConfiguration().get(PROP_OMEGA_SEED));
+
+ omega = new Omega(omegaSeed, k + p);
+
+ mYtY = new UpperTriangular(kp);
+
+ // see which one works better!
+ // yRow = new RandomAccessSparseVector(kp);
+ yRow = new DenseVector(kp);
+ }
+
+ @Override
+ protected void map(Writable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ omega.computeYRow(value.get(), yRow);
+ // compute outer product update for YtY
+
+ if (yRow.isDense()) {
+ for (int i = 0; i < kp; i++) {
+ double yi;
+ if ((yi = yRow.getQuick(i)) == 0.0) {
+ continue; // avoid densing up here unnecessarily
+ }
+ for (int j = i; j < kp; j++) {
+ double yj;
+ if ((yj = yRow.getQuick(j)) != 0.0) {
+ mYtY.setQuick(i, j, mYtY.getQuick(i, j) + yi * yj);
+ }
+ }
+ }
+ } else {
+ /*
+ * the disadvantage of using sparse vector (aside from the fact that we
+ * are creating some short-lived references) here is that we obviously
+ * do two times more iterations then necessary if y row is pretty dense.
+ */
+ for (Vector.Element eli : yRow.nonZeroes()) {
+ int i = eli.index();
+ for (Vector.Element elj : yRow.nonZeroes()) {
+ int j = elj.index();
+ if (j < i) {
+ continue;
+ }
+ mYtY.setQuick(i, j, mYtY.getQuick(i, j) + eli.get() * elj.get());
+ }
+ }
+ }
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ context.write(new IntWritable(context.getTaskAttemptID().getTaskID()
+ .getId()),
+ new VectorWritable(new DenseVector(mYtY.getData())));
+ }
+ }
+
+ public static class YtYReducer extends
+ Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+ private final VectorWritable accum = new VectorWritable();
+ private DenseVector acc;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ int k = context.getConfiguration().getInt(PROP_K, -1);
+ int p = context.getConfiguration().getInt(PROP_P, -1);
+
+ Validate.isTrue(k > 0, "invalid k parameter");
+ Validate.isTrue(p > 0, "invalid p parameter");
+ accum.set(acc = new DenseVector(k + p));
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ context.write(new IntWritable(), accum);
+ }
+
+ @Override
+ protected void reduce(IntWritable key,
+ Iterable<VectorWritable> values,
+ Context arg2) throws IOException,
+ InterruptedException {
+ for (VectorWritable vw : values) {
+ acc.addAll(vw.get());
+ }
+ }
+ }
+
+ public static void run(Configuration conf,
+ Path[] inputPaths,
+ Path outputPath,
+ int k,
+ int p,
+ long seed) throws ClassNotFoundException,
+ InterruptedException, IOException {
+
+ Job job = new Job(conf);
+ job.setJobName("YtY-job");
+ job.setJarByClass(YtYJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ FileInputFormat.setInputPaths(job, inputPaths);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ SequenceFileOutputFormat.setOutputCompressionType(job,
+ CompressionType.BLOCK);
+
+ job.setMapOutputKeyClass(IntWritable.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.setMapperClass(YtYMapper.class);
+
+ job.getConfiguration().setLong(PROP_OMEGA_SEED, seed);
+ job.getConfiguration().setInt(PROP_K, k);
+ job.getConfiguration().setInt(PROP_P, p);
+
+ /*
+ * we must reduce to just one matrix which means we need only one reducer.
+ * But it's ok since each mapper outputs only one vector (a packed
+ * UpperTriangular) so even if there're thousands of mappers, one reducer
+ * should cope just fine.
+ */
+ job.setNumReduceTasks(1);
+
+ job.submit();
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("YtY job unsuccessful.");
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
new file mode 100644
index 0000000..7033efe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GivensThinSolver.java
@@ -0,0 +1,638 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.math.AbstractVector;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.OrderedIntDoubleMapping;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.UpperTriangular;
+
+/**
+ * Givens Thin solver. Standard Givens operations are reordered in a way that
+ * helps us to push them thru MapReduce operations in a block fashion.
+ */
+public class GivensThinSolver {
+
+ private double[] vARow;
+ private double[] vQtRow;
+ private final double[][] mQt;
+ private final double[][] mR;
+ private int qtStartRow;
+ private int rStartRow;
+ private int m;
+ private final int n; // m-row cnt, n- column count, m>=n
+ private int cnt;
+ private final double[] cs = new double[2];
+
+ public GivensThinSolver(int m, int n) {
+ if (!(m >= n)) {
+ throw new IllegalArgumentException("Givens thin QR: must be true: m>=n");
+ }
+
+ this.m = m;
+ this.n = n;
+
+ mQt = new double[n][];
+ mR = new double[n][];
+ vARow = new double[n];
+ vQtRow = new double[m];
+
+ for (int i = 0; i < n; i++) {
+ mQt[i] = new double[this.m];
+ mR[i] = new double[this.n];
+ }
+ cnt = 0;
+ }
+
+ public void reset() {
+ cnt = 0;
+ }
+
+ public void solve(Matrix a) {
+
+ assert a.rowSize() == m;
+ assert a.columnSize() == n;
+
+ double[] aRow = new double[n];
+ for (int i = 0; i < m; i++) {
+ Vector aRowV = a.viewRow(i);
+ for (int j = 0; j < n; j++) {
+ aRow[j] = aRowV.getQuick(j);
+ }
+ appendRow(aRow);
+ }
+ }
+
+ public boolean isFull() {
+ return cnt == m;
+ }
+
+ public int getM() {
+ return m;
+ }
+
+ public int getN() {
+ return n;
+ }
+
+ public int getCnt() {
+ return cnt;
+ }
+
+ public void adjust(int newM) {
+ if (newM == m) {
+ // no adjustment is required.
+ return;
+ }
+ if (newM < n) {
+ throw new IllegalArgumentException("new m can't be less than n");
+ }
+ if (newM < cnt) {
+ throw new IllegalArgumentException(
+ "new m can't be less than rows accumulated");
+ }
+ vQtRow = new double[newM];
+
+ // grow or shrink qt rows
+ if (newM > m) {
+ // grow qt rows
+ for (int i = 0; i < n; i++) {
+ mQt[i] = Arrays.copyOf(mQt[i], newM);
+ System.arraycopy(mQt[i], 0, mQt[i], newM - m, m);
+ Arrays.fill(mQt[i], 0, newM - m, 0);
+ }
+ } else {
+ // shrink qt rows
+ for (int i = 0; i < n; i++) {
+ mQt[i] = Arrays.copyOfRange(mQt[i], m - newM, m);
+ }
+ }
+
+ m = newM;
+
+ }
+
+ public void trim() {
+ adjust(cnt);
+ }
+
+ /**
+ * api for row-by-row addition
+ *
+ * @param aRow
+ */
+ public void appendRow(double[] aRow) {
+ if (cnt >= m) {
+ throw new IllegalStateException("thin QR solver fed more rows than initialized for");
+ }
+ try {
+ /*
+ * moving pointers around is inefficient but for the sanity's sake i am
+ * keeping it this way so i don't have to guess how R-tilde index maps to
+ * actual block index
+ */
+ Arrays.fill(vQtRow, 0);
+ vQtRow[m - cnt - 1] = 1;
+ int height = cnt > n ? n : cnt;
+ System.arraycopy(aRow, 0, vARow, 0, n);
+
+ if (height > 0) {
+ givens(vARow[0], getRRow(0)[0], cs);
+ applyGivensInPlace(cs[0], cs[1], vARow, getRRow(0), 0, n);
+ applyGivensInPlace(cs[0], cs[1], vQtRow, getQtRow(0), 0, m);
+ }
+
+ for (int i = 1; i < height; i++) {
+ givens(getRRow(i - 1)[i], getRRow(i)[i], cs);
+ applyGivensInPlace(cs[0], cs[1], getRRow(i - 1), getRRow(i), i,
+ n - i);
+ applyGivensInPlace(cs[0], cs[1], getQtRow(i - 1), getQtRow(i), 0,
+ m);
+ }
+ /*
+ * push qt and r-tilde 1 row down
+ *
+ * just swap the references to reduce GC churning
+ */
+ pushQtDown();
+ double[] swap = getQtRow(0);
+ setQtRow(0, vQtRow);
+ vQtRow = swap;
+
+ pushRDown();
+ swap = getRRow(0);
+ setRRow(0, vARow);
+ vARow = swap;
+
+ } finally {
+ cnt++;
+ }
+ }
+
+ private double[] getQtRow(int row) {
+
+ return mQt[(row += qtStartRow) >= n ? row - n : row];
+ }
+
+ private void setQtRow(int row, double[] qtRow) {
+ mQt[(row += qtStartRow) >= n ? row - n : row] = qtRow;
+ }
+
+ private void pushQtDown() {
+ qtStartRow = qtStartRow == 0 ? n - 1 : qtStartRow - 1;
+ }
+
+ private double[] getRRow(int row) {
+ row += rStartRow;
+ return mR[row >= n ? row - n : row];
+ }
+
+ private void setRRow(int row, double[] rrow) {
+ mR[(row += rStartRow) >= n ? row - n : row] = rrow;
+ }
+
+ private void pushRDown() {
+ rStartRow = rStartRow == 0 ? n - 1 : rStartRow - 1;
+ }
+
+ /*
+ * warning: both of these return actually n+1 rows with the last one being //
+ * not interesting.
+ */
+ public UpperTriangular getRTilde() {
+ UpperTriangular packedR = new UpperTriangular(n);
+ for (int i = 0; i < n; i++) {
+ packedR.assignNonZeroElementsInRow(i, getRRow(i));
+ }
+ return packedR;
+ }
+
+ public double[][] getThinQtTilde() {
+ if (qtStartRow != 0) {
+ /*
+ * rotate qt rows into place
+ *
+ * double[~500][], once per block, not a big deal.
+ */
+ double[][] qt = new double[n][];
+ System.arraycopy(mQt, qtStartRow, qt, 0, n - qtStartRow);
+ System.arraycopy(mQt, 0, qt, n - qtStartRow, qtStartRow);
+ return qt;
+ }
+ return mQt;
+ }
+
+ public static void applyGivensInPlace(double c, double s, double[] row1,
+ double[] row2, int offset, int len) {
+
+ int n = offset + len;
+ for (int j = offset; j < n; j++) {
+ double tau1 = row1[j];
+ double tau2 = row2[j];
+ row1[j] = c * tau1 - s * tau2;
+ row2[j] = s * tau1 + c * tau2;
+ }
+ }
+
+ public static void applyGivensInPlace(double c, double s, Vector row1,
+ Vector row2, int offset, int len) {
+
+ int n = offset + len;
+ for (int j = offset; j < n; j++) {
+ double tau1 = row1.getQuick(j);
+ double tau2 = row2.getQuick(j);
+ row1.setQuick(j, c * tau1 - s * tau2);
+ row2.setQuick(j, s * tau1 + c * tau2);
+ }
+ }
+
+ public static void applyGivensInPlace(double c, double s, int i, int k,
+ Matrix mx) {
+ int n = mx.columnSize();
+
+ for (int j = 0; j < n; j++) {
+ double tau1 = mx.get(i, j);
+ double tau2 = mx.get(k, j);
+ mx.set(i, j, c * tau1 - s * tau2);
+ mx.set(k, j, s * tau1 + c * tau2);
+ }
+ }
+
+ public static void fromRho(double rho, double[] csOut) {
+ if (rho == 1) {
+ csOut[0] = 0;
+ csOut[1] = 1;
+ return;
+ }
+ if (Math.abs(rho) < 1) {
+ csOut[1] = 2 * rho;
+ csOut[0] = Math.sqrt(1 - csOut[1] * csOut[1]);
+ return;
+ }
+ csOut[0] = 2 / rho;
+ csOut[1] = Math.sqrt(1 - csOut[0] * csOut[0]);
+ }
+
+ public static void givens(double a, double b, double[] csOut) {
+ if (b == 0) {
+ csOut[0] = 1;
+ csOut[1] = 0;
+ return;
+ }
+ if (Math.abs(b) > Math.abs(a)) {
+ double tau = -a / b;
+ csOut[1] = 1 / Math.sqrt(1 + tau * tau);
+ csOut[0] = csOut[1] * tau;
+ } else {
+ double tau = -b / a;
+ csOut[0] = 1 / Math.sqrt(1 + tau * tau);
+ csOut[1] = csOut[0] * tau;
+ }
+ }
+
+ public static double toRho(double c, double s) {
+ if (c == 0) {
+ return 1;
+ }
+ if (Math.abs(s) < Math.abs(c)) {
+ return Math.signum(c) * s / 2;
+ } else {
+ return Math.signum(s) * 2 / c;
+ }
+ }
+
+ public static void mergeR(UpperTriangular r1, UpperTriangular r2) {
+ TriangularRowView r1Row = new TriangularRowView(r1);
+ TriangularRowView r2Row = new TriangularRowView(r2);
+
+ int kp = r1Row.size();
+ assert kp == r2Row.size();
+
+ double[] cs = new double[2];
+
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1Row.setViewedRow(u).get(u), r2Row.setViewedRow(u - v).get(u),
+ cs);
+ applyGivensInPlace(cs[0], cs[1], r1Row, r2Row, u, kp - u);
+ }
+ }
+ }
+
+ public static void mergeR(double[][] r1, double[][] r2) {
+ int kp = r1[0].length;
+ assert kp == r2[0].length;
+
+ double[] cs = new double[2];
+
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1[u][u], r2[u - v][u], cs);
+ applyGivensInPlace(cs[0], cs[1], r1[u], r2[u - v], u, kp - u);
+ }
+ }
+
+ }
+
+ public static void mergeRonQ(UpperTriangular r1, UpperTriangular r2,
+ double[][] qt1, double[][] qt2) {
+ TriangularRowView r1Row = new TriangularRowView(r1);
+ TriangularRowView r2Row = new TriangularRowView(r2);
+ int kp = r1Row.size();
+ assert kp == r2Row.size();
+ assert kp == qt1.length;
+ assert kp == qt2.length;
+
+ int r = qt1[0].length;
+ assert qt2[0].length == r;
+
+ double[] cs = new double[2];
+
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1Row.setViewedRow(u).get(u), r2Row.setViewedRow(u - v).get(u),
+ cs);
+ applyGivensInPlace(cs[0], cs[1], r1Row, r2Row, u, kp - u);
+ applyGivensInPlace(cs[0], cs[1], qt1[u], qt2[u - v], 0, r);
+ }
+ }
+ }
+
+ public static void mergeRonQ(double[][] r1, double[][] r2, double[][] qt1,
+ double[][] qt2) {
+
+ int kp = r1[0].length;
+ assert kp == r2[0].length;
+ assert kp == qt1.length;
+ assert kp == qt2.length;
+
+ int r = qt1[0].length;
+ assert qt2[0].length == r;
+ double[] cs = new double[2];
+
+ /*
+ * pairwise givens(a,b) so that a come off main diagonal in r1 and bs come
+ * off u-th upper subdiagonal in r2.
+ */
+ for (int v = 0; v < kp; v++) {
+ for (int u = v; u < kp; u++) {
+ givens(r1[u][u], r2[u - v][u], cs);
+ applyGivensInPlace(cs[0], cs[1], r1[u], r2[u - v], u, kp - u);
+ applyGivensInPlace(cs[0], cs[1], qt1[u], qt2[u - v], 0, r);
+ }
+ }
+ }
+
+ // returns merged Q (which in this case is the qt1)
+ public static double[][] mergeQrUp(double[][] qt1, double[][] r1,
+ double[][] r2) {
+ int kp = qt1.length;
+ int r = qt1[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qt1, qTilde);
+ return qt1;
+ }
+
+ // returns merged Q (which in this case is the qt1)
+ public static double[][] mergeQrUp(double[][] qt1, UpperTriangular r1, UpperTriangular r2) {
+ int kp = qt1.length;
+ int r = qt1[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qt1, qTilde);
+ return qt1;
+ }
+
+ public static double[][] mergeQrDown(double[][] r1, double[][] qt2, double[][] r2) {
+ int kp = qt2.length;
+ int r = qt2[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qTilde, qt2);
+ return qTilde;
+
+ }
+
+ public static double[][] mergeQrDown(UpperTriangular r1, double[][] qt2, UpperTriangular r2) {
+ int kp = qt2.length;
+ int r = qt2[0].length;
+
+ double[][] qTilde = new double[kp][];
+ for (int i = 0; i < kp; i++) {
+ qTilde[i] = new double[r];
+ }
+ mergeRonQ(r1, r2, qTilde, qt2);
+ return qTilde;
+
+ }
+
+ public static double[][] computeQtHat(double[][] qt, int i,
+ Iterator<UpperTriangular> rIter) {
+ UpperTriangular rTilde = rIter.next();
+ for (int j = 1; j < i; j++) {
+ mergeR(rTilde, rIter.next());
+ }
+ if (i > 0) {
+ qt = mergeQrDown(rTilde, qt, rIter.next());
+ }
+ while (rIter.hasNext()) {
+ qt = mergeQrUp(qt, rTilde, rIter.next());
+ }
+ return qt;
+ }
+
+ // test helpers
+ public static boolean isOrthonormal(double[][] qt, boolean insufficientRank, double epsilon) {
+ int n = qt.length;
+ int rank = 0;
+ for (int i = 0; i < n; i++) {
+ Vector ei = new DenseVector(qt[i], true);
+
+ double norm = ei.norm(2);
+
+ if (Math.abs(1.0 - norm) < epsilon) {
+ rank++;
+ } else if (Math.abs(norm) > epsilon) {
+ return false; // not a rank deficiency, either
+ }
+
+ for (int j = 0; j <= i; j++) {
+ Vector ej = new DenseVector(qt[j], true);
+ double dot = ei.dot(ej);
+ if (!(Math.abs((i == j && rank > j ? 1.0 : 0.0) - dot) < epsilon)) {
+ return false;
+ }
+ }
+ }
+ return insufficientRank ? rank < n : rank == n;
+ }
+
+ public static boolean isOrthonormalBlocked(Iterable<double[][]> qtHats,
+ boolean insufficientRank, double epsilon) {
+ int n = qtHats.iterator().next().length;
+ int rank = 0;
+ for (int i = 0; i < n; i++) {
+ List<Vector> ei = Lists.newArrayList();
+ // Vector e_i=new DenseVector (qt[i],true);
+ for (double[][] qtHat : qtHats) {
+ ei.add(new DenseVector(qtHat[i], true));
+ }
+
+ double norm = 0;
+ for (Vector v : ei) {
+ norm += v.dot(v);
+ }
+ norm = Math.sqrt(norm);
+ if (Math.abs(1 - norm) < epsilon) {
+ rank++;
+ } else if (Math.abs(norm) > epsilon) {
+ return false; // not a rank deficiency, either
+ }
+
+ for (int j = 0; j <= i; j++) {
+ List<Vector> ej = Lists.newArrayList();
+ for (double[][] qtHat : qtHats) {
+ ej.add(new DenseVector(qtHat[j], true));
+ }
+
+ // Vector e_j = new DenseVector ( qt[j], true);
+ double dot = 0;
+ for (int k = 0; k < ei.size(); k++) {
+ dot += ei.get(k).dot(ej.get(k));
+ }
+ if (!(Math.abs((i == j && rank > j ? 1 : 0) - dot) < epsilon)) {
+ return false;
+ }
+ }
+ }
+ return insufficientRank ? rank < n : rank == n;
+ }
+
+ private static final class TriangularRowView extends AbstractVector {
+ private final UpperTriangular viewed;
+ private int rowNum;
+
+ private TriangularRowView(UpperTriangular viewed) {
+ super(viewed.columnSize());
+ this.viewed = viewed;
+
+ }
+
+ TriangularRowView setViewedRow(int row) {
+ rowNum = row;
+ return this;
+ }
+
+ @Override
+ public boolean isDense() {
+ return true;
+ }
+
+ @Override
+ public boolean isSequentialAccess() {
+ return false;
+ }
+
+ @Override
+ public Iterator<Element> iterator() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Iterator<Element> iterateNonZero() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getQuick(int index) {
+ return viewed.getQuick(rowNum, index);
+ }
+
+ @Override
+ public Vector like() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setQuick(int index, double value) {
+ viewed.setQuick(rowNum, index, value);
+
+ }
+
+ @Override
+ public int getNumNondefaultElements() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getLookupCost() {
+ return 1;
+ }
+
+ @Override
+ public double getIteratorAdvanceCost() {
+ return 1;
+ }
+
+ @Override
+ public boolean isAddConstantTime() {
+ return true;
+ }
+
+ @Override
+ public Matrix matrixLike(int rows, int columns) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * Used internally by assign() to update multiple indices and values at once.
+ * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
+ * <p/>
+ * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
+ *
+ * @param updates a mapping of indices to values to merge in the vector.
+ */
+ @Override
+ public void mergeUpdates(OrderedIntDoubleMapping updates) {
+ int[] indices = updates.getIndices();
+ double[] values = updates.getValues();
+ for (int i = 0; i < updates.getNumMappings(); ++i) {
+ viewed.setQuick(rowNum, indices[i], values[i]);
+ }
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
new file mode 100644
index 0000000..09be91f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/GramSchmidt.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Gram Schmidt quick helper.
+ */
+public final class GramSchmidt {
+
+ private GramSchmidt() {
+ }
+
+ public static void orthonormalizeColumns(Matrix mx) {
+
+ int n = mx.numCols();
+
+ for (int c = 0; c < n; c++) {
+ Vector col = mx.viewColumn(c);
+ for (int c1 = 0; c1 < c; c1++) {
+ Vector viewC1 = mx.viewColumn(c1);
+ col.assign(col.minus(viewC1.times(viewC1.dot(col))));
+
+ }
+ final double norm2 = col.norm(2);
+ col.assign(new DoubleFunction() {
+ @Override
+ public double apply(double x) {
+ return x / norm2;
+ }
+ });
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
new file mode 100644
index 0000000..8509e0a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRFirstStep.java
@@ -0,0 +1,284 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.lib.MultipleOutputs;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.iterator.CopyConstructorIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable;
+import org.apache.mahout.math.UpperTriangular;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * QR first step without MR abstractions and doing it just in terms of iterators
+ * and collectors. (although Collector is probably an outdated api).
+ *
+ *
+ */
+@SuppressWarnings("deprecation")
+public class QRFirstStep implements Closeable, OutputCollector<Writable, Vector> {
+
+ public static final String PROP_K = "ssvd.k";
+ public static final String PROP_P = "ssvd.p";
+ public static final String PROP_AROWBLOCK_SIZE = "ssvd.arowblock.size";
+
+ private int kp;
+ private List<double[]> yLookahead;
+ private GivensThinSolver qSolver;
+ private int blockCnt;
+ private final DenseBlockWritable value = new DenseBlockWritable();
+ private final Writable tempKey = new IntWritable();
+ private MultipleOutputs outputs;
+ private final Deque<Closeable> closeables = Lists.newLinkedList();
+ private SequenceFile.Writer tempQw;
+ private Path tempQPath;
+ private final List<UpperTriangular> rSubseq = Lists.newArrayList();
+ private final Configuration jobConf;
+
+ private final OutputCollector<? super Writable, ? super DenseBlockWritable> qtHatOut;
+ private final OutputCollector<? super Writable, ? super VectorWritable> rHatOut;
+
+ public QRFirstStep(Configuration jobConf,
+ OutputCollector<? super Writable, ? super DenseBlockWritable> qtHatOut,
+ OutputCollector<? super Writable, ? super VectorWritable> rHatOut) {
+ this.jobConf = jobConf;
+ this.qtHatOut = qtHatOut;
+ this.rHatOut = rHatOut;
+ setup();
+ }
+
+ @Override
+ public void close() throws IOException {
+ cleanup();
+ }
+
+ public int getKP() {
+ return kp;
+ }
+
+ private void flushSolver() throws IOException {
+ UpperTriangular r = qSolver.getRTilde();
+ double[][] qt = qSolver.getThinQtTilde();
+
+ rSubseq.add(r);
+
+ value.setBlock(qt);
+ getTempQw().append(tempKey, value);
+
+ /*
+ * this probably should be a sparse row matrix, but compressor should get it
+ * for disk and in memory we want it dense anyway, sparse random
+ * implementations would be a mostly a memory management disaster consisting
+ * of rehashes and GC // thrashing. (IMHO)
+ */
+ value.setBlock(null);
+ qSolver.reset();
+ }
+
+ // second pass to run a modified version of computeQHatSequence.
+ private void flushQBlocks() throws IOException {
+ if (blockCnt == 1) {
+ /*
+ * only one block, no temp file, no second pass. should be the default
+ * mode for efficiency in most cases. Sure mapper should be able to load
+ * the entire split in memory -- and we don't require even that.
+ */
+ value.setBlock(qSolver.getThinQtTilde());
+ outputQHat(value);
+ outputR(new VectorWritable(new DenseVector(qSolver.getRTilde().getData(),
+ true)));
+
+ } else {
+ secondPass();
+ }
+ }
+
+ private void outputQHat(DenseBlockWritable value) throws IOException {
+ qtHatOut.collect(NullWritable.get(), value);
+ }
+
+ private void outputR(VectorWritable value) throws IOException {
+ rHatOut.collect(NullWritable.get(), value);
+ }
+
+ private void secondPass() throws IOException {
+ qSolver = null; // release mem
+ FileSystem localFs = FileSystem.getLocal(jobConf);
+ SequenceFile.Reader tempQr =
+ new SequenceFile.Reader(localFs, tempQPath, jobConf);
+ closeables.addFirst(tempQr);
+ int qCnt = 0;
+ while (tempQr.next(tempKey, value)) {
+ value
+ .setBlock(GivensThinSolver.computeQtHat(value.getBlock(),
+ qCnt,
+ new CopyConstructorIterator<>(rSubseq.iterator())));
+ if (qCnt == 1) {
+ /*
+ * just merge r[0] <- r[1] so it doesn't have to repeat in subsequent
+ * computeQHat iterators
+ */
+ GivensThinSolver.mergeR(rSubseq.get(0), rSubseq.remove(1));
+ } else {
+ qCnt++;
+ }
+ outputQHat(value);
+ }
+
+ assert rSubseq.size() == 1;
+
+ outputR(new VectorWritable(new DenseVector(rSubseq.get(0).getData(), true)));
+
+ }
+
+ protected void map(Vector incomingYRow) throws IOException {
+ double[] yRow;
+ if (yLookahead.size() == kp) {
+ if (qSolver.isFull()) {
+
+ flushSolver();
+ blockCnt++;
+
+ }
+ yRow = yLookahead.remove(0);
+
+ qSolver.appendRow(yRow);
+ } else {
+ yRow = new double[kp];
+ }
+
+ if (incomingYRow.isDense()) {
+ for (int i = 0; i < kp; i++) {
+ yRow[i] = incomingYRow.get(i);
+ }
+ } else {
+ Arrays.fill(yRow, 0);
+ for (Element yEl : incomingYRow.nonZeroes()) {
+ yRow[yEl.index()] = yEl.get();
+ }
+ }
+
+ yLookahead.add(yRow);
+ }
+
+ protected void setup() {
+
+ int r = Integer.parseInt(jobConf.get(PROP_AROWBLOCK_SIZE));
+ int k = Integer.parseInt(jobConf.get(PROP_K));
+ int p = Integer.parseInt(jobConf.get(PROP_P));
+ kp = k + p;
+
+ yLookahead = Lists.newArrayListWithCapacity(kp);
+ qSolver = new GivensThinSolver(r, kp);
+ outputs = new MultipleOutputs(new JobConf(jobConf));
+ closeables.addFirst(new Closeable() {
+ @Override
+ public void close() throws IOException {
+ outputs.close();
+ }
+ });
+
+ }
+
+ protected void cleanup() throws IOException {
+ try {
+ if (qSolver == null && yLookahead.isEmpty()) {
+ return;
+ }
+ if (qSolver == null) {
+ qSolver = new GivensThinSolver(yLookahead.size(), kp);
+ }
+ // grow q solver up if necessary
+
+ qSolver.adjust(qSolver.getCnt() + yLookahead.size());
+ while (!yLookahead.isEmpty()) {
+
+ qSolver.appendRow(yLookahead.remove(0));
+
+ }
+ assert qSolver.isFull();
+ if (++blockCnt > 1) {
+ flushSolver();
+ assert tempQw != null;
+ closeables.remove(tempQw);
+ Closeables.close(tempQw, false);
+ }
+ flushQBlocks();
+
+ } finally {
+ IOUtils.close(closeables);
+ }
+
+ }
+
+ private SequenceFile.Writer getTempQw() throws IOException {
+ if (tempQw == null) {
+ /*
+ * temporary Q output hopefully will not exceed size of IO cache in which
+ * case it is only good since it is going to be managed by kernel, not
+ * java GC. And if IO cache is not good enough, then at least it is always
+ * sequential.
+ */
+ String taskTmpDir = System.getProperty("java.io.tmpdir");
+
+ FileSystem localFs = FileSystem.getLocal(jobConf);
+ Path parent = new Path(taskTmpDir);
+ Path sub = new Path(parent, "qw_" + System.currentTimeMillis());
+ tempQPath = new Path(sub, "q-temp.seq");
+ tempQw =
+ SequenceFile.createWriter(localFs,
+ jobConf,
+ tempQPath,
+ IntWritable.class,
+ DenseBlockWritable.class,
+ CompressionType.BLOCK);
+ closeables.addFirst(tempQw);
+ closeables.addFirst(new IOUtils.DeleteFileOnClose(new File(tempQPath
+ .toString())));
+ }
+ return tempQw;
+ }
+
+ @Override
+ public void collect(Writable key, Vector vw) throws IOException {
+ map(vw);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
new file mode 100644
index 0000000..545f1f9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/qr/QRLastStep.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd.qr;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+import org.apache.commons.lang3.Validate;
+import org.apache.mahout.common.iterator.CopyConstructorIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.stochasticsvd.DenseBlockWritable;
+import org.apache.mahout.math.UpperTriangular;
+
+import com.google.common.collect.Lists;
+
+/**
+ * Second/last step of QR iterations. Takes input of qtHats and rHats and
+ * provides iterator to pull ready rows of final Q.
+ *
+ */
+public class QRLastStep implements Closeable, Iterator<Vector> {
+
+ private final Iterator<DenseBlockWritable> qHatInput;
+
+ private final List<UpperTriangular> mRs = Lists.newArrayList();
+ private final int blockNum;
+ private double[][] mQt;
+ private int cnt;
+ private int r;
+ private int kp;
+ private Vector qRow;
+
+ /**
+ *
+ * @param qHatInput
+ * the Q-Hat input that was output in the first step
+ * @param rHatInput
+ * all RHat outputs int the group in order of groups
+ * @param blockNum
+ * our RHat number in the group
+ */
+ public QRLastStep(Iterator<DenseBlockWritable> qHatInput,
+ Iterator<VectorWritable> rHatInput,
+ int blockNum) {
+ this.blockNum = blockNum;
+ this.qHatInput = qHatInput;
+ /*
+ * in this implementation we actually preload all Rs into memory to make R
+ * sequence modifications more efficient.
+ */
+ int block = 0;
+ while (rHatInput.hasNext()) {
+ Vector value = rHatInput.next().get();
+ if (block < blockNum && block > 0) {
+ GivensThinSolver.mergeR(mRs.get(0), new UpperTriangular(value));
+ } else {
+ mRs.add(new UpperTriangular(value));
+ }
+ block++;
+ }
+
+ }
+
+ private boolean loadNextQt() {
+ boolean more = qHatInput.hasNext();
+ if (!more) {
+ return false;
+ }
+ DenseBlockWritable v = qHatInput.next();
+ mQt =
+ GivensThinSolver
+ .computeQtHat(v.getBlock(),
+ blockNum == 0 ? 0 : 1,
+ new CopyConstructorIterator<>(mRs.iterator()));
+ r = mQt[0].length;
+ kp = mQt.length;
+ if (qRow == null) {
+ qRow = new DenseVector(kp);
+ }
+ return true;
+ }
+
+ @Override
+ public boolean hasNext() {
+ if (mQt != null && cnt == r) {
+ mQt = null;
+ }
+ boolean result = true;
+ if (mQt == null) {
+ result = loadNextQt();
+ cnt = 0;
+ }
+ return result;
+ }
+
+ @Override
+ public Vector next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ Validate.isTrue(hasNext(), "Q input overrun");
+ /*
+ * because Q blocks are initially stored in inverse order
+ */
+ int qRowIndex = r - cnt - 1;
+ for (int j = 0; j < kp; j++) {
+ qRow.setQuick(j, mQt[j][qRowIndex]);
+ }
+ cnt++;
+ return qRow;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void close() throws IOException {
+ mQt = null;
+ mRs.clear();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java
new file mode 100644
index 0000000..51484c7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/BruteSearch.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.PriorityQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Ordering;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Search for nearest neighbors using a complete search (i.e. looping through
+ * the references and comparing each vector to the query).
+ */
+public class BruteSearch extends UpdatableSearcher {
+ /**
+ * The list of reference vectors.
+ */
+ private final List<Vector> referenceVectors;
+
+ public BruteSearch(DistanceMeasure distanceMeasure) {
+ super(distanceMeasure);
+ referenceVectors = Lists.newArrayList();
+ }
+
+ @Override
+ public void add(Vector vector) {
+ referenceVectors.add(vector);
+ }
+
+ @Override
+ public int size() {
+ return referenceVectors.size();
+ }
+
+ /**
+ * Scans the list of reference vectors one at a time for @limit neighbors of
+ * the query vector.
+ * The weights of the WeightedVectors are not taken into account.
+ *
+ * @param query The query vector.
+ * @param limit The number of results to returned; must be at least 1.
+ * @return A list of the closest @limit neighbors for the given query.
+ */
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ Preconditions.checkArgument(limit > 0, "limit must be greater then 0!");
+ limit = Math.min(limit, referenceVectors.size());
+ // A priority queue of the best @limit elements, ordered from worst to best so that the worst
+ // element is always on top and can easily be removed.
+ PriorityQueue<WeightedThing<Integer>> bestNeighbors =
+ new PriorityQueue<>(limit, Ordering.natural().reverse());
+ // The resulting list of weighted WeightedVectors (the weight is the distance from the query).
+ List<WeightedThing<Vector>> results =
+ Lists.newArrayListWithCapacity(limit);
+ int rowNumber = 0;
+ for (Vector row : referenceVectors) {
+ double distance = distanceMeasure.distance(query, row);
+ // Only add a new neighbor if the result is better than the worst element
+ // in the queue or the queue isn't full.
+ if (bestNeighbors.size() < limit || bestNeighbors.peek().getWeight() > distance) {
+ bestNeighbors.add(new WeightedThing<>(rowNumber, distance));
+ if (bestNeighbors.size() > limit) {
+ bestNeighbors.poll();
+ } else {
+ // Increase the size of the results list by 1 so we can add elements in the reverse
+ // order from the queue.
+ results.add(null);
+ }
+ }
+ ++rowNumber;
+ }
+ for (int i = limit - 1; i >= 0; --i) {
+ WeightedThing<Integer> neighbor = bestNeighbors.poll();
+ results.set(i, new WeightedThing<>(
+ referenceVectors.get(neighbor.getValue()), neighbor.getWeight()));
+ }
+ return results;
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only one the nearest vector is needed, use this method, NOT search(query, limit) because
+ * it's faster (less overhead).
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different than the query (this
+ * only matters if the query is among the searched vectors), otherwise,
+ * returns the closest vector to the query (even the same vector).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+ double bestDistance = Double.POSITIVE_INFINITY;
+ Vector bestVector = null;
+ for (Vector row : referenceVectors) {
+ double distance = distanceMeasure.distance(query, row);
+ if (distance < bestDistance && (!differentThanQuery || !row.equals(query))) {
+ bestDistance = distance;
+ bestVector = row;
+ }
+ }
+ return new WeightedThing<>(bestVector, bestDistance);
+ }
+
+ /**
+ * Searches with a list full of queries in a threaded fashion.
+ *
+ * @param queries The queries to search for.
+ * @param limit The number of results to return.
+ * @param numThreads Number of threads to use in searching.
+ * @return A list of result lists.
+ */
+ public List<List<WeightedThing<Vector>>> search(Iterable<WeightedVector> queries,
+ final int limit, int numThreads) throws InterruptedException {
+ ExecutorService executor = Executors.newFixedThreadPool(numThreads);
+ List<Callable<Object>> tasks = Lists.newArrayList();
+
+ final List<List<WeightedThing<Vector>>> results = Lists.newArrayList();
+ int i = 0;
+ for (final Vector query : queries) {
+ results.add(null);
+ final int index = i++;
+ tasks.add(new Callable<Object>() {
+ @Override
+ public Object call() throws Exception {
+ results.set(index, BruteSearch.this.search(query, limit));
+ return null;
+ }
+ });
+ }
+
+ executor.invokeAll(tasks);
+ executor.shutdown();
+
+ return results;
+ }
+
+ @Override
+ public Iterator<Vector> iterator() {
+ return referenceVectors.iterator();
+ }
+
+ @Override
+ public boolean remove(Vector query, double epsilon) {
+ int rowNumber = 0;
+ for (Vector row : referenceVectors) {
+ double distance = distanceMeasure.distance(query, row);
+ if (distance < epsilon) {
+ referenceVectors.remove(rowNumber);
+ return true;
+ }
+ rowNumber++;
+ }
+ return false;
+ }
+
+ @Override
+ public void clear() {
+ referenceVectors.clear();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java
new file mode 100644
index 0000000..006f4b6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/FastProjectionSearch.java
@@ -0,0 +1,326 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Does approximate nearest neighbor search by projecting the vectors similar to ProjectionSearch.
+ * The main difference between this class and the ProjectionSearch is the use of sorted arrays
+ * instead of binary search trees to implement the sets of scalar projections.
+ *
+ * Instead of taking log n time to add a vector to each of the vectors, * the pending additions are
+ * kept separate and are searched using a brute search. When there are "enough" pending additions,
+ * they're committed into the main pool of vectors.
+ */
+public class FastProjectionSearch extends UpdatableSearcher {
+ // The list of vectors that have not yet been projected (that are pending).
+ private final List<Vector> pendingAdditions = Lists.newArrayList();
+
+ // The list of basis vectors. Populated when the first vector's dimension is know by calling
+ // initialize once.
+ private Matrix basisMatrix = null;
+
+ // The list of sorted lists of scalar projections. The outer list has one entry for each basis
+ // vector that all the other vectors will be projected on.
+ // For each basis vector, the inner list has an entry for each vector that has been projected.
+ // These entries are WeightedThing<Vector> where the weight is the value of the scalar
+ // projection and the value is the vector begin referred to.
+ private List<List<WeightedThing<Vector>>> scalarProjections;
+
+ // The number of projection used for approximating the distance.
+ private final int numProjections;
+
+ // The number of elements to keep on both sides of the closest estimated distance as possible
+ // candidates for the best actual distance.
+ private final int searchSize;
+
+ // Initially, the dimension of the vectors searched by this searcher is unknown. After adding
+ // the first vector, the basis will be initialized. This marks whether initialization has
+ // happened or not so we only do it once.
+ private boolean initialized = false;
+
+ // Removing vectors from the searcher is done lazily to avoid the linear time cost of removing
+ // elements from an array. This member keeps track of the number of removed vectors (marked as
+ // "impossible" values in the array) so they can be removed when updating the structure.
+ private int numPendingRemovals = 0;
+
+ private static final double ADDITION_THRESHOLD = 0.05;
+ private static final double REMOVAL_THRESHOLD = 0.02;
+
+ public FastProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
+ super(distanceMeasure);
+ Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
+ "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
+ this.numProjections = numProjections;
+ this.searchSize = searchSize;
+ scalarProjections = Lists.newArrayListWithCapacity(numProjections);
+ for (int i = 0; i < numProjections; ++i) {
+ scalarProjections.add(Lists.<WeightedThing<Vector>>newArrayList());
+ }
+ }
+
+ private void initialize(int numDimensions) {
+ if (initialized) {
+ return;
+ }
+ basisMatrix = RandomProjector.generateBasisNormal(numProjections, numDimensions);
+ initialized = true;
+ }
+
+ /**
+ * Add a new Vector to the Searcher that will be checked when getting
+ * the nearest neighbors.
+ * <p/>
+ * The vector IS NOT CLONED. Do not modify the vector externally otherwise the internal
+ * Searcher data structures could be invalidated.
+ */
+ @Override
+ public void add(Vector vector) {
+ initialize(vector.size());
+ pendingAdditions.add(vector);
+ }
+
+ /**
+ * Returns the number of WeightedVectors being searched for nearest neighbors.
+ */
+ @Override
+ public int size() {
+ return pendingAdditions.size() + scalarProjections.get(0).size() - numPendingRemovals;
+ }
+
+ /**
+ * When querying the Searcher for the closest vectors, a list of WeightedThing<Vector>s is
+ * returned. The value of the WeightedThing is the neighbor and the weight is the
+ * the distance (calculated by some metric - see a concrete implementation) between the query
+ * and neighbor.
+ * The actual type of vector in the pair is the same as the vector added to the Searcher.
+ */
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ reindex(false);
+
+ Set<Vector> candidates = Sets.newHashSet();
+ Vector projection = basisMatrix.times(query);
+ for (int i = 0; i < basisMatrix.numRows(); ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ int middle = Collections.binarySearch(currProjections,
+ new WeightedThing<Vector>(projection.get(i)));
+ if (middle < 0) {
+ middle = -(middle + 1);
+ }
+ for (int j = Math.max(0, middle - searchSize);
+ j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) {
+ if (currProjections.get(j).getValue() == null) {
+ continue;
+ }
+ candidates.add(currProjections.get(j).getValue());
+ }
+ }
+
+ List<WeightedThing<Vector>> top =
+ Lists.newArrayListWithCapacity(candidates.size() + pendingAdditions.size());
+ for (Vector candidate : Iterables.concat(candidates, pendingAdditions)) {
+ top.add(new WeightedThing<>(candidate, distanceMeasure.distance(candidate, query)));
+ }
+ Collections.sort(top);
+
+ return top.subList(0, Math.min(top.size(), limit));
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only one the nearest vector is needed, use this method, NOT search(query, limit) because
+ * it's faster (less overhead).
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different than the query (this
+ * only matters if the query is among the searched vectors), otherwise,
+ * returns the closest vector to the query (even the same vector).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+ reindex(false);
+
+ double bestDistance = Double.POSITIVE_INFINITY;
+ Vector bestVector = null;
+
+ Vector projection = basisMatrix.times(query);
+ for (int i = 0; i < basisMatrix.numRows(); ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ int middle = Collections.binarySearch(currProjections,
+ new WeightedThing<Vector>(projection.get(i)));
+ if (middle < 0) {
+ middle = -(middle + 1);
+ }
+ for (int j = Math.max(0, middle - searchSize);
+ j < Math.min(currProjections.size(), middle + searchSize + 1); ++j) {
+ if (currProjections.get(j).getValue() == null) {
+ continue;
+ }
+ Vector vector = currProjections.get(j).getValue();
+ double distance = distanceMeasure.distance(vector, query);
+ if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) {
+ bestDistance = distance;
+ bestVector = vector;
+ }
+ }
+ }
+
+ for (Vector vector : pendingAdditions) {
+ double distance = distanceMeasure.distance(vector, query);
+ if (distance < bestDistance && (!differentThanQuery || !vector.equals(query))) {
+ bestDistance = distance;
+ bestVector = vector;
+ }
+ }
+
+ return new WeightedThing<>(bestVector, bestDistance);
+ }
+
+ @Override
+ public boolean remove(Vector vector, double epsilon) {
+ WeightedThing<Vector> closestPair = searchFirst(vector, false);
+ if (distanceMeasure.distance(closestPair.getValue(), vector) > epsilon) {
+ return false;
+ }
+
+ boolean isProjected = true;
+ Vector projection = basisMatrix.times(vector);
+ for (int i = 0; i < basisMatrix.numRows(); ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ WeightedThing<Vector> searchedThing = new WeightedThing<>(projection.get(i));
+ int middle = Collections.binarySearch(currProjections, searchedThing);
+ if (middle < 0) {
+ isProjected = false;
+ break;
+ }
+ // Elements to be removed are kept in the sorted array until the next reindex, but their inner vector
+ // is set to null.
+ scalarProjections.get(i).set(middle, searchedThing);
+ }
+ if (isProjected) {
+ ++numPendingRemovals;
+ return true;
+ }
+
+ for (int i = 0; i < pendingAdditions.size(); ++i) {
+ if (pendingAdditions.get(i).equals(vector)) {
+ pendingAdditions.remove(i);
+ break;
+ }
+ }
+ return true;
+ }
+
+ private void reindex(boolean force) {
+ int numProjected = scalarProjections.get(0).size();
+ if (force || pendingAdditions.size() > ADDITION_THRESHOLD * numProjected
+ || numPendingRemovals > REMOVAL_THRESHOLD * numProjected) {
+
+ // We only need to copy the first list because when iterating we use only that list for the Vector
+ // references.
+ // see public Iterator<Vector> iterator()
+ List<List<WeightedThing<Vector>>> scalarProjections = Lists.newArrayListWithCapacity(numProjections);
+ for (int i = 0; i < numProjections; ++i) {
+ if (i == 0) {
+ scalarProjections.add(Lists.newArrayList(this.scalarProjections.get(i)));
+ } else {
+ scalarProjections.add(this.scalarProjections.get(i));
+ }
+ }
+
+ // Project every pending vector onto every basis vector.
+ for (Vector pending : pendingAdditions) {
+ Vector projection = basisMatrix.times(pending);
+ for (int i = 0; i < numProjections; ++i) {
+ scalarProjections.get(i).add(new WeightedThing<>(pending, projection.get(i)));
+ }
+ }
+ pendingAdditions.clear();
+ // For each basis vector, sort the resulting list (for binary search) and remove the number
+ // of pending removals (it's the same for every basis vector) at the end (the weights are
+ // set to Double.POSITIVE_INFINITY when removing).
+ for (int i = 0; i < numProjections; ++i) {
+ List<WeightedThing<Vector>> currProjections = scalarProjections.get(i);
+ for (WeightedThing<Vector> v : currProjections) {
+ if (v.getValue() == null) {
+ v.setWeight(Double.POSITIVE_INFINITY);
+ }
+ }
+ Collections.sort(currProjections);
+ for (int j = 0; j < numPendingRemovals; ++j) {
+ currProjections.remove(currProjections.size() - 1);
+ }
+ }
+ numPendingRemovals = 0;
+
+ this.scalarProjections = scalarProjections;
+ }
+ }
+
+ @Override
+ public void clear() {
+ pendingAdditions.clear();
+ for (int i = 0; i < numProjections; ++i) {
+ scalarProjections.get(i).clear();
+ }
+ numPendingRemovals = 0;
+ }
+
+ /**
+ * This iterates on the snapshot of the contents first instantiated regardless of any future modifications.
+ * Changes done after the iterator is created will not be visible to the iterator but will be visible
+ * when searching.
+ * @return iterator through the vectors in this searcher.
+ */
+ @Override
+ public Iterator<Vector> iterator() {
+ reindex(true);
+ return new AbstractIterator<Vector>() {
+ private final Iterator<WeightedThing<Vector>> data = scalarProjections.get(0).iterator();
+ @Override
+ protected Vector computeNext() {
+ do {
+ if (!data.hasNext()) {
+ return endOfData();
+ }
+ WeightedThing<Vector> next = data.next();
+ if (next.getValue() != null) {
+ return next.getValue();
+ }
+ } while (true);
+ }
+ };
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
new file mode 100644
index 0000000..eb91813
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/HashedVector.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+
+/**
+ * Decorates a weighted vector with a locality sensitive hash.
+ *
+ * The LSH function implemented is the random hyperplane based hash function.
+ * See "Similarity Estimation Techniques from Rounding Algorithms" by Moses S. Charikar, section 3.
+ * http://www.cs.princeton.edu/courses/archive/spring04/cos598B/bib/CharikarEstim.pdf
+ */
+public class HashedVector extends WeightedVector {
+ protected static final int INVALID_INDEX = -1;
+
+ /**
+ * Value of the locality sensitive hash. It is 64 bit.
+ */
+ private final long hash;
+
+ public HashedVector(Vector vector, long hash, int index) {
+ super(vector, 1, index);
+ this.hash = hash;
+ }
+
+ public HashedVector(Vector vector, Matrix projection, int index, long mask) {
+ super(vector, 1, index);
+ this.hash = mask & computeHash64(vector, projection);
+ }
+
+ public HashedVector(WeightedVector weightedVector, Matrix projection, long mask) {
+ super(weightedVector.getVector(), weightedVector.getWeight(), weightedVector.getIndex());
+ this.hash = mask & computeHash64(weightedVector, projection);
+ }
+
+ public static long computeHash64(Vector vector, Matrix projection) {
+ long hash = 0;
+ for (Element element : projection.times(vector).nonZeroes()) {
+ if (element.get() > 0) {
+ hash += 1L << element.index();
+ }
+ }
+ return hash;
+ }
+
+ public static HashedVector hash(WeightedVector v, Matrix projection) {
+ return hash(v, projection, 0);
+ }
+
+ public static HashedVector hash(WeightedVector v, Matrix projection, long mask) {
+ return new HashedVector(v, projection, mask);
+ }
+
+ public int hammingDistance(long otherHash) {
+ return Long.bitCount(hash ^ otherHash);
+ }
+
+ public long getHash() {
+ return hash;
+ }
+
+ @Override
+ public String toString() {
+ return String.format("index=%d, hash=%08x, v=%s", getIndex(), hash, getVector());
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof HashedVector)) {
+ return o instanceof Vector && this.minus((Vector) o).norm(1) == 0;
+ }
+ HashedVector v = (HashedVector) o;
+ return v.hash == this.hash && this.minus(v).norm(1) == 0;
+ }
+
+ @Override
+ public int hashCode() {
+ int result = super.hashCode();
+ result = 31 * result + (int) (hash ^ (hash >>> 32));
+ return result;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java b/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
new file mode 100644
index 0000000..aa1f103
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/neighborhood/LocalitySensitiveHashSearch.java
@@ -0,0 +1,295 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.neighborhood;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.RandomProjector;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/**
+ * Implements a Searcher that uses locality sensitivity hash as a first pass approximation
+ * to estimate distance without floating point math. The clever bit about this implementation
+ * is that it does an adaptive cutoff for the cutoff on the bitwise distance. Making this
+ * cutoff adaptive means that we only needs to make a single pass through the data.
+ */
+public class LocalitySensitiveHashSearch extends UpdatableSearcher {
+ /**
+ * Number of bits in the locality sensitive hash. 64 bits fix neatly into a long.
+ */
+ private static final int BITS = 64;
+
+ /**
+ * Bit mask for the computed hash. Currently, it's 0xffffffffffff.
+ */
+ private static final long BIT_MASK = -1L;
+
+ /**
+ * The maximum Hamming distance between two hashes that the hash limit can grow back to.
+ * It starts at BITS and decreases as more points than are needed are added to the candidate priority queue.
+ * But, after the observed distribution of distances becomes too good (we're seeing less than some percentage of the
+ * total number of points; using the hash strategy somewhere less than 25%) the limit is increased to compute
+ * more distances.
+ * This is because
+ */
+ private static final int MAX_HASH_LIMIT = 32;
+
+ /**
+ * Minimum number of points with a given Hamming from the query that must be observed to consider raising the minimum
+ * distance for a candidate.
+ */
+ private static final int MIN_DISTRIBUTION_COUNT = 10;
+
+ private final Multiset<HashedVector> trainingVectors = HashMultiset.create();
+
+ /**
+ * This matrix of BITS random vectors is used to compute the Locality Sensitive Hash
+ * we compute the dot product with these vectors using a matrix multiplication and then use just
+ * sign of each result as one bit in the hash
+ */
+ private Matrix projection;
+
+ /**
+ * The search size determines how many top results we retain. We do this because the hash distance
+ * isn't guaranteed to be entirely monotonic with respect to the real distance. To the extent that
+ * actual distance is well approximated by hash distance, then the searchSize can be decreased to
+ * roughly the number of results that you want.
+ */
+ private int searchSize;
+
+ /**
+ * Controls how the hash limit is raised. 0 means use minimum of distribution, 1 means use first quartile.
+ * Intermediate values indicate an interpolation should be used. Negative values mean to never increase.
+ */
+ private double hashLimitStrategy = 0.9;
+
+ /**
+ * Number of evaluations of the full distance between two points that was required.
+ */
+ private int distanceEvaluations = 0;
+
+ /**
+ * Whether the projection matrix was initialized. This has to be deferred until the size of the vectors is known,
+ * effectively until the first vector is added.
+ */
+ private boolean initialized = false;
+
+ public LocalitySensitiveHashSearch(DistanceMeasure distanceMeasure, int searchSize) {
+ super(distanceMeasure);
+ this.searchSize = searchSize;
+ this.projection = null;
+ }
+
+ private void initialize(int numDimensions) {
+ if (initialized) {
+ return;
+ }
+ initialized = true;
+ projection = RandomProjector.generateBasisNormal(BITS, numDimensions);
+ }
+
+ private PriorityQueue<WeightedThing<Vector>> searchInternal(Vector query) {
+ long queryHash = HashedVector.computeHash64(query, projection);
+
+ // We keep an approximation of the closest vectors here.
+ PriorityQueue<WeightedThing<Vector>> top = Searcher.getCandidateQueue(getSearchSize());
+
+ // We scan the vectors using bit counts as an approximation of the dot product so we can do as few
+ // full distance computations as possible. Our goal is to only do full distance computations for
+ // vectors with hash distance at most as large as the searchSize biggest hash distance seen so far.
+
+ OnlineSummarizer[] distribution = new OnlineSummarizer[BITS + 1];
+ for (int i = 0; i < BITS + 1; i++) {
+ distribution[i] = new OnlineSummarizer();
+ }
+
+ distanceEvaluations = 0;
+
+ // We keep the counts of the hash distances here. This lets us accurately
+ // judge what hash distance cutoff we should use.
+ int[] hashCounts = new int[BITS + 1];
+
+ // Maximum number of different bits to still consider a vector a candidate for nearest neighbor.
+ // Starts at the maximum number of bits, but decreases and can increase.
+ int hashLimit = BITS;
+ int limitCount = 0;
+ double distanceLimit = Double.POSITIVE_INFINITY;
+
+ // In this loop, we have the invariants that:
+ //
+ // limitCount = sum_{i<hashLimit} hashCount[i]
+ // and
+ // limitCount >= searchSize && limitCount - hashCount[hashLimit-1] < searchSize
+ for (HashedVector vector : trainingVectors) {
+ // This computes the Hamming Distance between the vector's hash and the query's hash.
+ // The result is correlated with the angle between the vectors.
+ int bitDot = vector.hammingDistance(queryHash);
+ if (bitDot <= hashLimit) {
+ distanceEvaluations++;
+
+ double distance = distanceMeasure.distance(query, vector);
+ distribution[bitDot].add(distance);
+
+ if (distance < distanceLimit) {
+ top.insertWithOverflow(new WeightedThing<Vector>(vector, distance));
+ if (top.size() == searchSize) {
+ distanceLimit = top.top().getWeight();
+ }
+
+ hashCounts[bitDot]++;
+ limitCount++;
+ while (hashLimit > 0 && limitCount - hashCounts[hashLimit - 1] > searchSize) {
+ hashLimit--;
+ limitCount -= hashCounts[hashLimit];
+ }
+
+ if (hashLimitStrategy >= 0) {
+ while (hashLimit < MAX_HASH_LIMIT && distribution[hashLimit].getCount() > MIN_DISTRIBUTION_COUNT
+ && ((1 - hashLimitStrategy) * distribution[hashLimit].getQuartile(0)
+ + hashLimitStrategy * distribution[hashLimit].getQuartile(1)) < distanceLimit) {
+ limitCount += hashCounts[hashLimit];
+ hashLimit++;
+ }
+ }
+ }
+ }
+ }
+ return top;
+ }
+
+ @Override
+ public List<WeightedThing<Vector>> search(Vector query, int limit) {
+ PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+ List<WeightedThing<Vector>> results = Lists.newArrayListWithExpectedSize(top.size());
+ while (top.size() != 0) {
+ WeightedThing<Vector> wv = top.pop();
+ results.add(new WeightedThing<>(((HashedVector) wv.getValue()).getVector(), wv.getWeight()));
+ }
+ Collections.reverse(results);
+ if (limit < results.size()) {
+ results = results.subList(0, limit);
+ }
+ return results;
+ }
+
+ /**
+ * Returns the closest vector to the query.
+ * When only one the nearest vector is needed, use this method, NOT search(query, limit) because
+ * it's faster (less overhead).
+ * This is nearly the same as search().
+ *
+ * @param query the vector to search for
+ * @param differentThanQuery if true, returns the closest vector different than the query (this
+ * only matters if the query is among the searched vectors), otherwise,
+ * returns the closest vector to the query (even the same vector).
+ * @return the weighted vector closest to the query
+ */
+ @Override
+ public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
+ // We get the top searchSize neighbors.
+ PriorityQueue<WeightedThing<Vector>> top = searchInternal(query);
+ // We then cut the number down to just the best 2.
+ while (top.size() > 2) {
+ top.pop();
+ }
+ // If there are fewer than 2 results, we just return the one we have.
+ if (top.size() < 2) {
+ return removeHash(top.pop());
+ }
+ // There are exactly 2 results.
+ WeightedThing<Vector> secondBest = top.pop();
+ WeightedThing<Vector> best = top.pop();
+ // If the best result is the same as the query, but we don't want to return the query.
+ if (differentThanQuery && best.getValue().equals(query)) {
+ best = secondBest;
+ }
+ return removeHash(best);
+ }
+
+ protected static WeightedThing<Vector> removeHash(WeightedThing<Vector> input) {
+ return new WeightedThing<>(((HashedVector) input.getValue()).getVector(), input.getWeight());
+ }
+
+ @Override
+ public void add(Vector vector) {
+ initialize(vector.size());
+ trainingVectors.add(new HashedVector(vector, projection, HashedVector.INVALID_INDEX, BIT_MASK));
+ }
+
+ @Override
+ public int size() {
+ return trainingVectors.size();
+ }
+
+ public int getSearchSize() {
+ return searchSize;
+ }
+
+ public void setSearchSize(int size) {
+ searchSize = size;
+ }
+
+ public void setRaiseHashLimitStrategy(double strategy) {
+ hashLimitStrategy = strategy;
+ }
+
+ /**
+ * This is only for testing.
+ * @return the number of times the actual distance between two vectors was computed.
+ */
+ public int resetEvaluationCount() {
+ int result = distanceEvaluations;
+ distanceEvaluations = 0;
+ return result;
+ }
+
+ @Override
+ public Iterator<Vector> iterator() {
+ return Iterators.transform(trainingVectors.iterator(), new Function<HashedVector, Vector>() {
+ @Override
+ public Vector apply(org.apache.mahout.math.neighborhood.HashedVector input) {
+ Preconditions.checkNotNull(input);
+ //noinspection ConstantConditions
+ return input.getVector();
+ }
+ });
+ }
+
+ @Override
+ public boolean remove(Vector v, double epsilon) {
+ return trainingVectors.remove(new HashedVector(v, projection, HashedVector.INVALID_INDEX, BIT_MASK));
+ }
+
+ @Override
+ public void clear() {
+ trainingVectors.clear();
+ }
+}
[12/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJobTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJobTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJobTest.java
new file mode 100644
index 0000000..1326777
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/RecommenderJobTest.java
@@ -0,0 +1,928 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Counter;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.cf.taste.hadoop.EntityPrefWritable;
+import org.apache.mahout.cf.taste.hadoop.MutableRecommendedItem;
+import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.hadoop.ToItemPrefsMapper;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.recommender.GenericRecommendedItem;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.FileLineIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.CooccurrenceCountSimilarity;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.TanimotoCoefficientSimilarity;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+import org.easymock.IArgumentMatcher;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+public class RecommenderJobTest extends TasteTestCase {
+
+ /**
+ * tests {@link ItemIDIndexMapper}
+ */
+ @Test
+ public void testItemIDIndexMapper() throws Exception {
+ Mapper<LongWritable,Text, VarIntWritable, VarLongWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(new VarIntWritable(TasteHadoopUtils.idToIndex(789L)), new VarLongWritable(789L));
+ EasyMock.replay(context);
+
+ new ItemIDIndexMapper().map(new LongWritable(123L), new Text("456,789,5.0"), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link ItemIDIndexReducer}
+ */
+ @Test
+ public void testItemIDIndexReducer() throws Exception {
+ Reducer<VarIntWritable, VarLongWritable, VarIntWritable,VarLongWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ context.write(new VarIntWritable(123), new VarLongWritable(45L));
+ EasyMock.replay(context);
+
+ new ItemIDIndexReducer().reduce(new VarIntWritable(123), Arrays.asList(new VarLongWritable(67L),
+ new VarLongWritable(89L), new VarLongWritable(45L)), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link ToItemPrefsMapper}
+ */
+ @Test
+ public void testToItemPrefsMapper() throws Exception {
+ Mapper<LongWritable,Text, VarLongWritable,VarLongWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(new VarLongWritable(12L), new EntityPrefWritable(34L, 1.0f));
+ context.write(new VarLongWritable(56L), new EntityPrefWritable(78L, 2.0f));
+ EasyMock.replay(context);
+
+ ToItemPrefsMapper mapper = new ToItemPrefsMapper();
+ mapper.map(new LongWritable(123L), new Text("12,34,1"), context);
+ mapper.map(new LongWritable(456L), new Text("56,78,2"), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link ToItemPrefsMapper} using boolean data
+ */
+ @Test
+ public void testToItemPrefsMapperBooleanData() throws Exception {
+ Mapper<LongWritable,Text, VarLongWritable,VarLongWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(new VarLongWritable(12L), new VarLongWritable(34L));
+ context.write(new VarLongWritable(56L), new VarLongWritable(78L));
+ EasyMock.replay(context);
+
+ ToItemPrefsMapper mapper = new ToItemPrefsMapper();
+ setField(mapper, "booleanData", true);
+ mapper.map(new LongWritable(123L), new Text("12,34"), context);
+ mapper.map(new LongWritable(456L), new Text("56,78"), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link ToUserVectorsReducer}
+ */
+ @Test
+ public void testToUserVectorReducer() throws Exception {
+ Reducer<VarLongWritable,VarLongWritable,VarLongWritable,VectorWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+ Counter userCounters = EasyMock.createMock(Counter.class);
+
+ EasyMock.expect(context.getCounter(ToUserVectorsReducer.Counters.USERS)).andReturn(userCounters);
+ userCounters.increment(1);
+ context.write(EasyMock.eq(new VarLongWritable(12L)), MathHelper.vectorMatches(
+ MathHelper.elem(TasteHadoopUtils.idToIndex(34L), 1.0), MathHelper.elem(TasteHadoopUtils.idToIndex(56L), 2.0)));
+
+ EasyMock.replay(context, userCounters);
+
+ Collection<VarLongWritable> varLongWritables = Lists.newLinkedList();
+ varLongWritables.add(new EntityPrefWritable(34L, 1.0f));
+ varLongWritables.add(new EntityPrefWritable(56L, 2.0f));
+
+ new ToUserVectorsReducer().reduce(new VarLongWritable(12L), varLongWritables, context);
+
+ EasyMock.verify(context, userCounters);
+ }
+
+ /**
+ * tests {@link ToUserVectorsReducer} using boolean data
+ */
+ @Test
+ public void testToUserVectorReducerWithBooleanData() throws Exception {
+ Reducer<VarLongWritable,VarLongWritable,VarLongWritable,VectorWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+ Counter userCounters = EasyMock.createMock(Counter.class);
+
+ EasyMock.expect(context.getCounter(ToUserVectorsReducer.Counters.USERS)).andReturn(userCounters);
+ userCounters.increment(1);
+ context.write(EasyMock.eq(new VarLongWritable(12L)), MathHelper.vectorMatches(
+ MathHelper.elem(TasteHadoopUtils.idToIndex(34L), 1.0), MathHelper.elem(TasteHadoopUtils.idToIndex(56L), 1.0)));
+
+ EasyMock.replay(context, userCounters);
+
+ new ToUserVectorsReducer().reduce(new VarLongWritable(12L), Arrays.asList(new VarLongWritable(34L),
+ new VarLongWritable(56L)), context);
+
+ EasyMock.verify(context, userCounters);
+ }
+
+ /**
+ * tests {@link SimilarityMatrixRowWrapperMapper}
+ */
+ @Test
+ public void testSimilarityMatrixRowWrapperMapper() throws Exception {
+ Mapper<IntWritable,VectorWritable,VarIntWritable,VectorOrPrefWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(EasyMock.eq(new VarIntWritable(12)), vectorOfVectorOrPrefWritableMatches(MathHelper.elem(34, 0.5),
+ MathHelper.elem(56, 0.7)));
+
+ EasyMock.replay(context);
+
+ RandomAccessSparseVector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ vector.set(12, 1.0);
+ vector.set(34, 0.5);
+ vector.set(56, 0.7);
+
+ new SimilarityMatrixRowWrapperMapper().map(new IntWritable(12), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * verifies the {@link Vector} included in a {@link VectorOrPrefWritable}
+ */
+ private static VectorOrPrefWritable vectorOfVectorOrPrefWritableMatches(final Vector.Element... elements) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof VectorOrPrefWritable) {
+ Vector v = ((VectorOrPrefWritable) argument).getVector();
+ return MathHelper.consistsOf(v, elements);
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ /**
+ * tests {@link UserVectorSplitterMapper}
+ */
+ @Test
+ public void testUserVectorSplitterMapper() throws Exception {
+ Mapper<VarLongWritable,VectorWritable, VarIntWritable,VectorOrPrefWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(EasyMock.eq(new VarIntWritable(34)), prefOfVectorOrPrefWritableMatches(123L, 0.5f));
+ context.write(EasyMock.eq(new VarIntWritable(56)), prefOfVectorOrPrefWritableMatches(123L, 0.7f));
+
+ EasyMock.replay(context);
+
+ UserVectorSplitterMapper mapper = new UserVectorSplitterMapper();
+ setField(mapper, "maxPrefsPerUserConsidered", 10);
+
+ RandomAccessSparseVector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ vector.set(34, 0.5);
+ vector.set(56, 0.7);
+
+ mapper.map(new VarLongWritable(123L), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * verifies a preference in a {@link VectorOrPrefWritable}
+ */
+ private static VectorOrPrefWritable prefOfVectorOrPrefWritableMatches(final long userID, final float prefValue) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof VectorOrPrefWritable) {
+ VectorOrPrefWritable pref = (VectorOrPrefWritable) argument;
+ return pref.getUserID() == userID && pref.getValue() == prefValue;
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ /**
+ * tests {@link UserVectorSplitterMapper} in the special case that some userIDs shall be excluded
+ */
+ @Test
+ public void testUserVectorSplitterMapperUserExclusion() throws Exception {
+ Mapper<VarLongWritable,VectorWritable, VarIntWritable,VectorOrPrefWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(EasyMock.eq(new VarIntWritable(34)), prefOfVectorOrPrefWritableMatches(123L, 0.5f));
+ context.write(EasyMock.eq(new VarIntWritable(56)), prefOfVectorOrPrefWritableMatches(123L, 0.7f));
+
+ EasyMock.replay(context);
+
+ FastIDSet usersToRecommendFor = new FastIDSet();
+ usersToRecommendFor.add(123L);
+
+ UserVectorSplitterMapper mapper = new UserVectorSplitterMapper();
+ setField(mapper, "maxPrefsPerUserConsidered", 10);
+ setField(mapper, "usersToRecommendFor", usersToRecommendFor);
+
+
+ RandomAccessSparseVector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ vector.set(34, 0.5);
+ vector.set(56, 0.7);
+
+ mapper.map(new VarLongWritable(123L), new VectorWritable(vector), context);
+ mapper.map(new VarLongWritable(456L), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link UserVectorSplitterMapper} in the special case that the number of preferences to be considered
+ * is less than the number of available preferences
+ */
+ @Test
+ public void testUserVectorSplitterMapperOnlySomePrefsConsidered() throws Exception {
+ Mapper<VarLongWritable,VectorWritable, VarIntWritable,VectorOrPrefWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(EasyMock.eq(new VarIntWritable(34)), prefOfVectorOrPrefWritableMatchesNaN(123L));
+ context.write(EasyMock.eq(new VarIntWritable(56)), prefOfVectorOrPrefWritableMatches(123L, 0.7f));
+
+ EasyMock.replay(context);
+
+ UserVectorSplitterMapper mapper = new UserVectorSplitterMapper();
+ setField(mapper, "maxPrefsPerUserConsidered", 1);
+
+ RandomAccessSparseVector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ vector.set(34, 0.5);
+ vector.set(56, 0.7);
+
+ mapper.map(new VarLongWritable(123L), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * verifies that a preference value is NaN in a {@link VectorOrPrefWritable}
+ */
+ private static VectorOrPrefWritable prefOfVectorOrPrefWritableMatchesNaN(final long userID) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof VectorOrPrefWritable) {
+ VectorOrPrefWritable pref = (VectorOrPrefWritable) argument;
+ return pref.getUserID() == userID && Float.isNaN(pref.getValue());
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ /**
+ * tests {@link ToVectorAndPrefReducer}
+ */
+ @Test
+ public void testToVectorAndPrefReducer() throws Exception {
+ Reducer<VarIntWritable,VectorOrPrefWritable,VarIntWritable,VectorAndPrefsWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ context.write(EasyMock.eq(new VarIntWritable(1)), vectorAndPrefsWritableMatches(Arrays.asList(123L, 456L),
+ Arrays.asList(1.0f, 2.0f), MathHelper.elem(3, 0.5), MathHelper.elem(7, 0.8)));
+
+ EasyMock.replay(context);
+
+ Vector similarityColumn = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumn.set(3, 0.5);
+ similarityColumn.set(7, 0.8);
+
+ VectorOrPrefWritable itemPref1 = new VectorOrPrefWritable(123L, 1.0f);
+ VectorOrPrefWritable itemPref2 = new VectorOrPrefWritable(456L, 2.0f);
+ VectorOrPrefWritable similarities = new VectorOrPrefWritable(similarityColumn);
+
+ new ToVectorAndPrefReducer().reduce(new VarIntWritable(1), Arrays.asList(itemPref1, itemPref2, similarities),
+ context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * verifies a {@link VectorAndPrefsWritable}
+ */
+ private static VectorAndPrefsWritable vectorAndPrefsWritableMatches(final List<Long> userIDs,
+ final List<Float> prefValues, final Vector.Element... elements) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof VectorAndPrefsWritable) {
+ VectorAndPrefsWritable vectorAndPrefs = (VectorAndPrefsWritable) argument;
+
+ if (!vectorAndPrefs.getUserIDs().equals(userIDs)) {
+ return false;
+ }
+ if (!vectorAndPrefs.getValues().equals(prefValues)) {
+ return false;
+ }
+ return MathHelper.consistsOf(vectorAndPrefs.getVector(), elements);
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ /**
+ * tests {@link ToVectorAndPrefReducer} in the error case that two similarity column vectors a supplied for the same
+ * item (which should never happen)
+ */
+ @Test
+ public void testToVectorAndPrefReducerExceptionOn2Vectors() throws Exception {
+ Reducer<VarIntWritable,VectorOrPrefWritable,VarIntWritable,VectorAndPrefsWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ EasyMock.replay(context);
+
+ Vector similarityColumn1 = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ Vector similarityColumn2 = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+
+ VectorOrPrefWritable similarities1 = new VectorOrPrefWritable(similarityColumn1);
+ VectorOrPrefWritable similarities2 = new VectorOrPrefWritable(similarityColumn2);
+
+ try {
+ new ToVectorAndPrefReducer().reduce(new VarIntWritable(1), Arrays.asList(similarities1, similarities2), context);
+ fail();
+ } catch (IllegalStateException e) {
+ // good
+ }
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link org.apache.mahout.cf.taste.hadoop.item.ItemFilterMapper}
+ */
+ @Test
+ public void testItemFilterMapper() throws Exception {
+
+ Mapper<LongWritable,Text,VarLongWritable,VarLongWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(new VarLongWritable(34L), new VarLongWritable(12L));
+ context.write(new VarLongWritable(78L), new VarLongWritable(56L));
+
+ EasyMock.replay(context);
+
+ ItemFilterMapper mapper = new ItemFilterMapper();
+ mapper.map(null, new Text("12,34"), context);
+ mapper.map(null, new Text("56,78"), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link org.apache.mahout.cf.taste.hadoop.item.ItemFilterAsVectorAndPrefsReducer}
+ */
+ @Test
+ public void testItemFilterAsVectorAndPrefsReducer() throws Exception {
+ Reducer<VarLongWritable,VarLongWritable,VarIntWritable,VectorAndPrefsWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ int itemIDIndex = TasteHadoopUtils.idToIndex(123L);
+ context.write(EasyMock.eq(new VarIntWritable(itemIDIndex)), vectorAndPrefsForFilteringMatches(123L, 456L, 789L));
+
+ EasyMock.replay(context);
+
+ new ItemFilterAsVectorAndPrefsReducer().reduce(new VarLongWritable(123L), Arrays.asList(new VarLongWritable(456L),
+ new VarLongWritable(789L)), context);
+
+ EasyMock.verify(context);
+ }
+
+ static VectorAndPrefsWritable vectorAndPrefsForFilteringMatches(final long itemID, final long... userIDs) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof VectorAndPrefsWritable) {
+ VectorAndPrefsWritable vectorAndPrefs = (VectorAndPrefsWritable) argument;
+ Vector vector = vectorAndPrefs.getVector();
+ if (vector.getNumNondefaultElements() != 1) {
+ return false;
+ }
+ if (!Double.isNaN(vector.get(TasteHadoopUtils.idToIndex(itemID)))) {
+ return false;
+ }
+ if (userIDs.length != vectorAndPrefs.getUserIDs().size()) {
+ return false;
+ }
+ for (long userID : userIDs) {
+ if (!vectorAndPrefs.getUserIDs().contains(userID)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ /**
+ * tests {@link PartialMultiplyMapper}
+ */
+ @Test
+ public void testPartialMultiplyMapper() throws Exception {
+
+ Vector similarityColumn = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumn.set(3, 0.5);
+ similarityColumn.set(7, 0.8);
+
+ Mapper<VarIntWritable,VectorAndPrefsWritable,VarLongWritable,PrefAndSimilarityColumnWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ PrefAndSimilarityColumnWritable one = new PrefAndSimilarityColumnWritable();
+ PrefAndSimilarityColumnWritable two = new PrefAndSimilarityColumnWritable();
+ one.set(1.0f, similarityColumn);
+ two.set(3.0f, similarityColumn);
+
+ context.write(EasyMock.eq(new VarLongWritable(123L)), EasyMock.eq(one));
+ context.write(EasyMock.eq(new VarLongWritable(456L)), EasyMock.eq(two));
+
+ EasyMock.replay(context);
+
+ VectorAndPrefsWritable vectorAndPrefs = new VectorAndPrefsWritable(similarityColumn, Arrays.asList(123L, 456L),
+ Arrays.asList(1.0f, 3.0f));
+
+ new PartialMultiplyMapper().map(new VarIntWritable(1), vectorAndPrefs, context);
+
+ EasyMock.verify(context);
+ }
+
+
+ /**
+ * tests {@link AggregateAndRecommendReducer}
+ */
+ @Test
+ public void testAggregateAndRecommendReducer() throws Exception {
+ Reducer<VarLongWritable,PrefAndSimilarityColumnWritable,VarLongWritable,RecommendedItemsWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ context.write(EasyMock.eq(new VarLongWritable(123L)), recommendationsMatch(new MutableRecommendedItem(1L, 2.8f),
+ new MutableRecommendedItem(2L, 2.0f)));
+
+ EasyMock.replay(context);
+
+ RandomAccessSparseVector similarityColumnOne = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumnOne.set(1, 0.1);
+ similarityColumnOne.set(2, 0.5);
+
+ RandomAccessSparseVector similarityColumnTwo = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumnTwo.set(1, 0.9);
+ similarityColumnTwo.set(2, 0.5);
+
+ List<PrefAndSimilarityColumnWritable> values = Arrays.asList(
+ new PrefAndSimilarityColumnWritable(1.0f, similarityColumnOne),
+ new PrefAndSimilarityColumnWritable(3.0f, similarityColumnTwo));
+
+ OpenIntLongHashMap indexItemIDMap = new OpenIntLongHashMap();
+ indexItemIDMap.put(1, 1L);
+ indexItemIDMap.put(2, 2L);
+
+ AggregateAndRecommendReducer reducer = new AggregateAndRecommendReducer();
+
+ setField(reducer, "indexItemIDMap", indexItemIDMap);
+ setField(reducer, "recommendationsPerUser", 3);
+
+ reducer.reduce(new VarLongWritable(123L), values, context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link AggregateAndRecommendReducer}
+ */
+ @Test
+ public void testAggregateAndRecommendReducerExcludeRecommendationsBasedOnOneItem() throws Exception {
+ Reducer<VarLongWritable,PrefAndSimilarityColumnWritable,VarLongWritable,RecommendedItemsWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ context.write(EasyMock.eq(new VarLongWritable(123L)), recommendationsMatch(new MutableRecommendedItem(1L, 2.8f)));
+
+ EasyMock.replay(context);
+
+ RandomAccessSparseVector similarityColumnOne = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumnOne.set(1, 0.1);
+
+ RandomAccessSparseVector similarityColumnTwo = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumnTwo.set(1, 0.9);
+ similarityColumnTwo.set(2, 0.5);
+
+ List<PrefAndSimilarityColumnWritable> values = Arrays.asList(
+ new PrefAndSimilarityColumnWritable(1.0f, similarityColumnOne),
+ new PrefAndSimilarityColumnWritable(3.0f, similarityColumnTwo));
+
+ OpenIntLongHashMap indexItemIDMap = new OpenIntLongHashMap();
+ indexItemIDMap.put(1, 1L);
+ indexItemIDMap.put(2, 2L);
+
+ AggregateAndRecommendReducer reducer = new AggregateAndRecommendReducer();
+
+ setField(reducer, "indexItemIDMap", indexItemIDMap);
+ setField(reducer, "recommendationsPerUser", 3);
+
+ reducer.reduce(new VarLongWritable(123L), values, context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * tests {@link AggregateAndRecommendReducer} with a limit on the recommendations per user
+ */
+ @Test
+ public void testAggregateAndRecommendReducerLimitNumberOfRecommendations() throws Exception {
+ Reducer<VarLongWritable,PrefAndSimilarityColumnWritable,VarLongWritable,RecommendedItemsWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ context.write(EasyMock.eq(new VarLongWritable(123L)), recommendationsMatch(new MutableRecommendedItem(1L, 2.8f)));
+
+ EasyMock.replay(context);
+
+ RandomAccessSparseVector similarityColumnOne = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumnOne.set(1, 0.1);
+ similarityColumnOne.set(2, 0.5);
+
+ RandomAccessSparseVector similarityColumnTwo = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ similarityColumnTwo.set(1, 0.9);
+ similarityColumnTwo.set(2, 0.5);
+
+ List<PrefAndSimilarityColumnWritable> values = Arrays.asList(
+ new PrefAndSimilarityColumnWritable(1.0f, similarityColumnOne),
+ new PrefAndSimilarityColumnWritable(3.0f, similarityColumnTwo));
+
+ OpenIntLongHashMap indexItemIDMap = new OpenIntLongHashMap();
+ indexItemIDMap.put(1, 1L);
+ indexItemIDMap.put(2, 2L);
+
+ AggregateAndRecommendReducer reducer = new AggregateAndRecommendReducer();
+
+ setField(reducer, "indexItemIDMap", indexItemIDMap);
+ setField(reducer, "recommendationsPerUser", 1);
+
+ reducer.reduce(new VarLongWritable(123L), values, context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * verifies a {@link RecommendedItemsWritable}
+ */
+ static RecommendedItemsWritable recommendationsMatch(final RecommendedItem... items) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof RecommendedItemsWritable) {
+ RecommendedItemsWritable recommendedItemsWritable = (RecommendedItemsWritable) argument;
+ List<RecommendedItem> expectedItems = Arrays.asList(items);
+ return expectedItems.equals(recommendedItemsWritable.getRecommendedItems());
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ /**
+ * small integration test that runs the full job
+ *
+ * As a tribute to http://www.slideshare.net/srowen/collaborative-filtering-at-scale,
+ * we recommend people food to animals in this test :)
+ *
+ * <pre>
+ *
+ * user-item-matrix
+ *
+ * burger hotdog berries icecream
+ * dog 5 5 2 -
+ * rabbit 2 - 3 5
+ * cow - 5 - 3
+ * donkey 3 - - 5
+ *
+ *
+ * item-item-similarity-matrix (tanimoto-coefficient of the item-vectors of the user-item-matrix)
+ *
+ * burger hotdog berries icecream
+ * burger - 0.25 0.66 0.5
+ * hotdog 0.25 - 0.33 0.25
+ * berries 0.66 0.33 - 0.25
+ * icecream 0.5 0.25 0.25 -
+ *
+ *
+ * Prediction(dog, icecream) = (0.5 * 5 + 0.25 * 5 + 0.25 * 2 ) / (0.5 + 0.25 + 0.25) ~ 4.3
+ * Prediction(rabbit, hotdog) = (0.25 * 2 + 0.33 * 3 + 0.25 * 5) / (0.25 + 0.33 + 0.25) ~ 3,3
+ * Prediction(cow, burger) = (0.25 * 5 + 0.5 * 3) / (0.25 + 0.5) ~ 3,7
+ * Prediction(cow, berries) = (0.33 * 5 + 0.25 * 3) / (0.33 + 0.25) ~ 4,1
+ * Prediction(donkey, hotdog) = (0.25 * 3 + 0.25 * 5) / (0.25 + 0.25) ~ 4
+ * Prediction(donkey, berries) = (0.66 * 3 + 0.25 * 5) / (0.66 + 0.25) ~ 3,5
+ *
+ * </pre>
+ */
+ @Test
+ public void testCompleteJob() throws Exception {
+
+ File inputFile = getTestTempFile("prefs.txt");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File similaritiesOutputDir = getTestTempDir("outputSimilarities");
+ similaritiesOutputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ writeLines(inputFile,
+ "1,1,5",
+ "1,2,5",
+ "1,3,2",
+ "2,1,2",
+ "2,3,3",
+ "2,4,5",
+ "3,2,5",
+ "3,4,3",
+ "4,1,3",
+ "4,4,5");
+
+ RecommenderJob recommenderJob = new RecommenderJob();
+
+ Configuration conf = getConfiguration();
+ conf.set("mapred.input.dir", inputFile.getAbsolutePath());
+ conf.set("mapred.output.dir", outputDir.getAbsolutePath());
+ conf.setBoolean("mapred.output.compress", false);
+
+ recommenderJob.setConf(conf);
+
+ recommenderJob.run(new String[] { "--tempDir", tmpDir.getAbsolutePath(), "--similarityClassname",
+ TanimotoCoefficientSimilarity.class.getName(), "--numRecommendations", "4",
+ "--outputPathForSimilarityMatrix", similaritiesOutputDir.getAbsolutePath() });
+
+ Map<Long,List<RecommendedItem>> recommendations = readRecommendations(new File(outputDir, "part-r-00000"));
+ assertEquals(4, recommendations.size());
+
+ for (Entry<Long,List<RecommendedItem>> entry : recommendations.entrySet()) {
+ long userID = entry.getKey();
+ List<RecommendedItem> items = entry.getValue();
+ assertNotNull(items);
+ RecommendedItem item1 = items.get(0);
+
+ if (userID == 1L) {
+ assertEquals(1, items.size());
+ assertEquals(4L, item1.getItemID());
+ assertEquals(4.3, item1.getValue(), 0.05);
+ }
+ if (userID == 2L) {
+ assertEquals(1, items.size());
+ assertEquals(2L, item1.getItemID());
+ assertEquals(3.3, item1.getValue(), 0.05);
+ }
+ if (userID == 3L) {
+ assertEquals(2, items.size());
+ assertEquals(3L, item1.getItemID());
+ assertEquals(4.1, item1.getValue(), 0.05);
+ RecommendedItem item2 = items.get(1);
+ assertEquals(1L, item2.getItemID());
+ assertEquals(3.7, item2.getValue(), 0.05);
+ }
+ if (userID == 4L) {
+ assertEquals(2, items.size());
+ assertEquals(2L, item1.getItemID());
+ assertEquals(4.0, item1.getValue(), 0.05);
+ RecommendedItem item2 = items.get(1);
+ assertEquals(3L, item2.getItemID());
+ assertEquals(3.5, item2.getValue(), 0.05);
+ }
+ }
+
+ Map<Pair<Long, Long>, Double> similarities = readSimilarities(new File(similaritiesOutputDir, "part-r-00000"));
+ assertEquals(6, similarities.size());
+
+ assertEquals(0.25, similarities.get(new Pair<Long, Long>(1L, 2L)), EPSILON);
+ assertEquals(0.6666666666666666, similarities.get(new Pair<Long, Long>(1L, 3L)), EPSILON);
+ assertEquals(0.5, similarities.get(new Pair<Long, Long>(1L, 4L)), EPSILON);
+ assertEquals(0.3333333333333333, similarities.get(new Pair<Long, Long>(2L, 3L)), EPSILON);
+ assertEquals(0.25, similarities.get(new Pair<Long, Long>(2L, 4L)), EPSILON);
+ assertEquals(0.25, similarities.get(new Pair<Long, Long>(3L, 4L)), EPSILON);
+ }
+
+ /**
+ * small integration test for boolean data
+ */
+ @Test
+ public void testCompleteJobBoolean() throws Exception {
+
+ File inputFile = getTestTempFile("prefs.txt");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+ File usersFile = getTestTempFile("users.txt");
+ writeLines(usersFile, "3");
+
+ writeLines(inputFile,
+ "1,1",
+ "1,2",
+ "1,3",
+ "2,1",
+ "2,3",
+ "2,4",
+ "3,2",
+ "3,4",
+ "4,1",
+ "4,4");
+
+ RecommenderJob recommenderJob = new RecommenderJob();
+
+ Configuration conf = getConfiguration();
+ conf.set("mapred.input.dir", inputFile.getAbsolutePath());
+ conf.set("mapred.output.dir", outputDir.getAbsolutePath());
+ conf.setBoolean("mapred.output.compress", false);
+
+ recommenderJob.setConf(conf);
+
+ recommenderJob.run(new String[] { "--tempDir", tmpDir.getAbsolutePath(), "--similarityClassname",
+ CooccurrenceCountSimilarity.class.getName(), "--booleanData", "true",
+ "--usersFile", usersFile.getAbsolutePath() });
+
+ Map<Long,List<RecommendedItem>> recommendations = readRecommendations(new File(outputDir, "part-r-00000"));
+
+ List<RecommendedItem> recommendedToCow = recommendations.get(3L);
+ assertEquals(2, recommendedToCow.size());
+
+ RecommendedItem item1 = recommendedToCow.get(0);
+ RecommendedItem item2 = recommendedToCow.get(1);
+
+ assertEquals(1L, item1.getItemID());
+ assertEquals(3L, item2.getItemID());
+
+ /* predicted pref must be the sum of similarities:
+ * item1: coocc(burger, hotdog) + coocc(burger, icecream) = 3
+ * item2: coocc(berries, hotdog) + coocc(berries, icecream) = 2 */
+ assertEquals(3, item1.getValue(), 0.05);
+ assertEquals(2, item2.getValue(), 0.05);
+ }
+
+ /**
+ * check whether the explicit user/item filter works
+ */
+ @Test
+ public void testCompleteJobWithFiltering() throws Exception {
+
+ File inputFile = getTestTempFile("prefs.txt");
+ File userFile = getTestTempFile("users.txt");
+ File filterFile = getTestTempFile("filter.txt");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ writeLines(inputFile,
+ "1,1,5",
+ "1,2,5",
+ "1,3,2",
+ "2,1,2",
+ "2,3,3",
+ "2,4,5",
+ "3,2,5",
+ "3,4,3",
+ "4,1,3",
+ "4,4,5");
+
+ /* only compute recommendations for the donkey */
+ writeLines(userFile, "4");
+ /* do not recommend the hotdog for the donkey */
+ writeLines(filterFile, "4,2");
+
+ RecommenderJob recommenderJob = new RecommenderJob();
+
+ Configuration conf = getConfiguration();
+ conf.set("mapred.input.dir", inputFile.getAbsolutePath());
+ conf.set("mapred.output.dir", outputDir.getAbsolutePath());
+ conf.setBoolean("mapred.output.compress", false);
+
+ recommenderJob.setConf(conf);
+
+ recommenderJob.run(new String[] { "--tempDir", tmpDir.getAbsolutePath(), "--similarityClassname",
+ TanimotoCoefficientSimilarity.class.getName(), "--numRecommendations", "1",
+ "--usersFile", userFile.getAbsolutePath(), "--filterFile", filterFile.getAbsolutePath() });
+
+ Map<Long,List<RecommendedItem>> recommendations = readRecommendations(new File(outputDir, "part-r-00000"));
+
+ assertEquals(1, recommendations.size());
+ assertTrue(recommendations.containsKey(4L));
+ assertEquals(1, recommendations.get(4L).size());
+
+ /* berries should have been recommended to the donkey */
+ RecommendedItem recommendedItem = recommendations.get(4L).get(0);
+ assertEquals(3L, recommendedItem.getItemID());
+ assertEquals(3.5, recommendedItem.getValue(), 0.05);
+ }
+
+ static Map<Pair<Long,Long>, Double> readSimilarities(File file) throws IOException {
+ Map<Pair<Long,Long>, Double> similarities = Maps.newHashMap();
+ for (String line : new FileLineIterable(file)) {
+ String[] parts = line.split("\t");
+ similarities.put(new Pair<Long,Long>(Long.parseLong(parts[0]), Long.parseLong(parts[1])),
+ Double.parseDouble(parts[2]));
+ }
+ return similarities;
+ }
+
+ static Map<Long,List<RecommendedItem>> readRecommendations(File file) throws IOException {
+ Map<Long,List<RecommendedItem>> recommendations = Maps.newHashMap();
+ for (String line : new FileLineIterable(file)) {
+
+ String[] keyValue = line.split("\t");
+ long userID = Long.parseLong(keyValue[0]);
+ String[] tokens = keyValue[1].replaceAll("\\[", "")
+ .replaceAll("\\]", "").split(",");
+
+ List<RecommendedItem> items = Lists.newLinkedList();
+ for (String token : tokens) {
+ String[] itemTokens = token.split(":");
+ long itemID = Long.parseLong(itemTokens[0]);
+ float value = Float.parseFloat(itemTokens[1]);
+ items.add(new GenericRecommendedItem(itemID, value));
+ }
+ recommendations.put(userID, items);
+ }
+ return recommendations;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducerTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducerTest.java
new file mode 100644
index 0000000..bb22b71
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/ToUserVectorsReducerTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import org.apache.hadoop.mapreduce.Counter;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+/**
+ * tests {@link ToUserVectorsReducer}
+ */
+public class ToUserVectorsReducerTest extends TasteTestCase {
+
+ @Test
+ public void testToUsersReducerMinPreferencesUserIgnored() throws Exception {
+ Reducer<VarLongWritable,VarLongWritable,VarLongWritable,VectorWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ ToUserVectorsReducer reducer = new ToUserVectorsReducer();
+ setField(reducer, "minPreferences", 2);
+
+ EasyMock.replay(context);
+
+ reducer.reduce(new VarLongWritable(123), Collections.singletonList(new VarLongWritable(456)), context);
+
+ EasyMock.verify(context);
+ }
+
+ @Test
+ public void testToUsersReducerMinPreferencesUserPasses() throws Exception {
+ Reducer<VarLongWritable,VarLongWritable,VarLongWritable,VectorWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+ Counter userCounters = EasyMock.createMock(Counter.class);
+
+ ToUserVectorsReducer reducer = new ToUserVectorsReducer();
+ setField(reducer, "minPreferences", 2);
+
+ EasyMock.expect(context.getCounter(ToUserVectorsReducer.Counters.USERS)).andReturn(userCounters);
+ userCounters.increment(1);
+ context.write(EasyMock.eq(new VarLongWritable(123)), MathHelper.vectorMatches(
+ MathHelper.elem(TasteHadoopUtils.idToIndex(456L), 1.0), MathHelper.elem(TasteHadoopUtils.idToIndex(789L), 1.0)));
+
+ EasyMock.replay(context, userCounters);
+
+ reducer.reduce(new VarLongWritable(123), Arrays.asList(new VarLongWritable(456), new VarLongWritable(789)), context);
+
+ EasyMock.verify(context, userCounters);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJobTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJobTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJobTest.java
new file mode 100644
index 0000000..f61b5e6
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/similarity/item/ItemSimilarityJobTest.java
@@ -0,0 +1,269 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.similarity.item;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FilenameFilter;
+import java.util.Arrays;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Files;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.cf.taste.hadoop.EntityEntityWritable;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.CosineSimilarity;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.TanimotoCoefficientSimilarity;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+/**
+ * Unit tests for the mappers and reducers in org.apache.mahout.cf.taste.hadoop.similarity.item
+ * some integration tests with tiny data sets at the end
+ */
+public final class ItemSimilarityJobTest extends TasteTestCase {
+
+ private static final Pattern TAB = Pattern.compile("\t");
+
+ /**
+ * Tests {@link ItemSimilarityJob.MostSimilarItemPairsMapper}
+ */
+ @Test
+ public void testMostSimilarItemsPairsMapper() throws Exception {
+
+ OpenIntLongHashMap indexItemIDMap = new OpenIntLongHashMap();
+ indexItemIDMap.put(12, 12L);
+ indexItemIDMap.put(34, 34L);
+ indexItemIDMap.put(56, 56L);
+
+ Mapper<IntWritable,VectorWritable,EntityEntityWritable,DoubleWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+
+ context.write(new EntityEntityWritable(34L, 56L), new DoubleWritable(0.9));
+
+ EasyMock.replay(context);
+
+ Vector vector = new RandomAccessSparseVector(Integer.MAX_VALUE);
+ vector.set(12, 0.2);
+ vector.set(56, 0.9);
+
+ ItemSimilarityJob.MostSimilarItemPairsMapper mapper = new ItemSimilarityJob.MostSimilarItemPairsMapper();
+ setField(mapper, "indexItemIDMap", indexItemIDMap);
+ setField(mapper, "maxSimilarItemsPerItem", 1);
+
+ mapper.map(new IntWritable(34), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * Tests {@link ItemSimilarityJob.MostSimilarItemPairsReducer}
+ */
+ @Test
+ public void testMostSimilarItemPairsReducer() throws Exception {
+ Reducer<EntityEntityWritable,DoubleWritable,EntityEntityWritable,DoubleWritable>.Context context =
+ EasyMock.createMock(Reducer.Context.class);
+
+ context.write(new EntityEntityWritable(123L, 456L), new DoubleWritable(0.5));
+
+ EasyMock.replay(context);
+
+ new ItemSimilarityJob.MostSimilarItemPairsReducer().reduce(new EntityEntityWritable(123L, 456L),
+ Arrays.asList(new DoubleWritable(0.5), new DoubleWritable(0.5)), context);
+
+ EasyMock.verify(context);
+ }
+
+ /**
+ * Integration test with a tiny data set
+ *
+ * <pre>
+ * user-item-matrix
+ *
+ * Game Mouse PC Disk
+ * Jane - 1 2 -
+ * Paul 1 - 1 -
+ * Fred - - - 1
+ * </pre>
+ */
+ @Test
+ public void testCompleteJob() throws Exception {
+
+ File inputFile = getTestTempFile("prefs.txt");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ writeLines(inputFile,
+ "2,1,1",
+ "1,2,1",
+ "3,4,1",
+ "1,3,2",
+ "2,3,1");
+
+ ItemSimilarityJob similarityJob = new ItemSimilarityJob();
+
+ Configuration conf = getConfiguration();
+ conf.set("mapred.input.dir", inputFile.getAbsolutePath());
+ conf.set("mapred.output.dir", outputDir.getAbsolutePath());
+ conf.setBoolean("mapred.output.compress", false);
+
+ similarityJob.setConf(conf);
+ similarityJob.run(new String[] { "--tempDir", tmpDir.getAbsolutePath(), "--similarityClassname",
+ CosineSimilarity.class.getName() });
+ File outPart = outputDir.listFiles(new FilenameFilter() {
+ @Override
+ public boolean accept(File dir, String name) {
+ return name.startsWith("part-");
+ }
+ })[0];
+ BufferedReader reader = Files.newReader(outPart, Charsets.UTF_8);
+
+ String line;
+ int currentLine = 1;
+ while ( (line = reader.readLine()) != null) {
+
+ String[] tokens = TAB.split(line);
+
+ long itemAID = Long.parseLong(tokens[0]);
+ long itemBID = Long.parseLong(tokens[1]);
+ double similarity = Double.parseDouble(tokens[2]);
+
+ if (currentLine == 1) {
+ assertEquals(1L, itemAID);
+ assertEquals(3L, itemBID);
+ assertEquals(0.45, similarity, 0.01);
+ }
+
+ if (currentLine == 2) {
+ assertEquals(2L, itemAID);
+ assertEquals(3L, itemBID);
+ assertEquals(0.89, similarity, 0.01);
+ }
+
+ currentLine++;
+ }
+
+ int linesWritten = currentLine-1;
+ assertEquals(2, linesWritten);
+ }
+
+ /**
+ * integration test for the limitation of the number of computed similarities
+ *
+ * <pre>
+ * user-item-matrix
+ *
+ * i1 i2 i3
+ * u1 1 0 1
+ * u2 0 1 1
+ * u3 1 1 0
+ * u4 1 1 1
+ * u5 0 1 0
+ * u6 1 1 0
+ *
+ * tanimoto(i1,i2) = 0.5
+ * tanimoto(i2,i3) = 0.333
+ * tanimoto(i3,i1) = 0.4
+ *
+ * When we set maxSimilaritiesPerItem to 1 the following pairs should be found:
+ *
+ * i1 --> i2
+ * i2 --> i1
+ * i3 --> i1
+ * </pre>
+ */
+ @Test
+ public void testMaxSimilaritiesPerItem() throws Exception {
+
+ File inputFile = getTestTempFile("prefsForMaxSimilarities.txt");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ writeLines(inputFile,
+ "1,1,1",
+ "1,3,1",
+ "2,2,1",
+ "2,3,1",
+ "3,1,1",
+ "3,2,1",
+ "4,1,1",
+ "4,2,1",
+ "4,3,1",
+ "5,2,1",
+ "6,1,1",
+ "6,2,1");
+
+ ItemSimilarityJob similarityJob = new ItemSimilarityJob();
+
+ Configuration conf = getConfiguration();
+ conf.set("mapred.input.dir", inputFile.getAbsolutePath());
+ conf.set("mapred.output.dir", outputDir.getAbsolutePath());
+ conf.setBoolean("mapred.output.compress", false);
+
+ similarityJob.setConf(conf);
+ similarityJob.run(new String[] { "--tempDir", tmpDir.getAbsolutePath(), "--similarityClassname",
+ TanimotoCoefficientSimilarity.class.getName(), "--maxSimilaritiesPerItem", "1" });
+ File outPart = outputDir.listFiles(new FilenameFilter() {
+ @Override
+ public boolean accept(File dir, String name) {
+ return name.startsWith("part-");
+ }
+ })[0];
+ BufferedReader reader = Files.newReader(outPart, Charsets.UTF_8);
+
+ String line;
+ int currentLine = 1;
+ while ((line = reader.readLine()) != null) {
+
+ String[] tokens = TAB.split(line);
+
+ long itemAID = Long.parseLong(tokens[0]);
+ long itemBID = Long.parseLong(tokens[1]);
+ double similarity = Double.parseDouble(tokens[2]);
+
+ if (currentLine == 1) {
+ assertEquals(1L, itemAID);
+ assertEquals(2L, itemBID);
+ assertEquals(0.5, similarity, 0.0001);
+ }
+
+ if (currentLine == 2) {
+ assertEquals(1L, itemAID);
+ assertEquals(3L, itemBID);
+ assertEquals(0.4, similarity, 0.0001);
+ }
+
+ currentLine++;
+ }
+
+ int linesWritten = currentLine - 1;
+ assertEquals(2, linesWritten);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/TasteTestCase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/TasteTestCase.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/TasteTestCase.java
new file mode 100644
index 0000000..2f8ca95
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/TasteTestCase.java
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.model.GenericBooleanPrefDataModel;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+import java.util.List;
+
+public abstract class TasteTestCase extends MahoutTestCase {
+
+ public static DataModel getDataModel(long[] userIDs, Double[][] prefValues) {
+ FastByIDMap<PreferenceArray> result = new FastByIDMap<PreferenceArray>();
+ for (int i = 0; i < userIDs.length; i++) {
+ List<Preference> prefsList = Lists.newArrayList();
+ for (int j = 0; j < prefValues[i].length; j++) {
+ if (prefValues[i][j] != null) {
+ prefsList.add(new GenericPreference(userIDs[i], j, prefValues[i][j].floatValue()));
+ }
+ }
+ if (!prefsList.isEmpty()) {
+ result.put(userIDs[i], new GenericUserPreferenceArray(prefsList));
+ }
+ }
+ return new GenericDataModel(result);
+ }
+
+ public static DataModel getBooleanDataModel(long[] userIDs, boolean[][] prefs) {
+ FastByIDMap<FastIDSet> result = new FastByIDMap<FastIDSet>();
+ for (int i = 0; i < userIDs.length; i++) {
+ FastIDSet prefsSet = new FastIDSet();
+ for (int j = 0; j < prefs[i].length; j++) {
+ if (prefs[i][j]) {
+ prefsSet.add(j);
+ }
+ }
+ if (!prefsSet.isEmpty()) {
+ result.put(userIDs[i], prefsSet);
+ }
+ }
+ return new GenericBooleanPrefDataModel(result);
+ }
+
+ protected static DataModel getDataModel() {
+ return getDataModel(
+ new long[] {1, 2, 3, 4},
+ new Double[][] {
+ {0.1, 0.3},
+ {0.2, 0.3, 0.3},
+ {0.4, 0.3, 0.5},
+ {0.7, 0.3, 0.8},
+ });
+ }
+
+ protected static DataModel getBooleanDataModel() {
+ return getBooleanDataModel(new long[] {1, 2, 3, 4},
+ new boolean[][] {
+ {false, true, false},
+ {false, true, true, false},
+ {true, false, false, true},
+ {true, false, true, true},
+ });
+ }
+
+ protected static boolean arrayContains(long[] array, long value) {
+ for (long l : array) {
+ if (l == value) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/BitSetTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/BitSetTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/BitSetTest.java
new file mode 100644
index 0000000..1f7c76b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/BitSetTest.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+public final class BitSetTest extends TasteTestCase {
+
+ private static final int NUM_BITS = 100;
+
+ @Test
+ public void testGetSet() {
+ BitSet bitSet = new BitSet(NUM_BITS);
+ for (int i = 0; i < NUM_BITS; i++) {
+ assertFalse(bitSet.get(i));
+ }
+ bitSet.set(0);
+ bitSet.set(NUM_BITS-1);
+ assertTrue(bitSet.get(0));
+ assertTrue(bitSet.get(NUM_BITS-1));
+ }
+
+ @Test(expected = ArrayIndexOutOfBoundsException.class)
+ public void testBounds1() {
+ BitSet bitSet = new BitSet(NUM_BITS);
+ bitSet.set(1000);
+ }
+
+ @Test(expected = ArrayIndexOutOfBoundsException.class)
+ public void testBounds2() {
+ BitSet bitSet = new BitSet(NUM_BITS);
+ bitSet.set(-1);
+ }
+
+ @Test
+ public void testClear() {
+ BitSet bitSet = new BitSet(NUM_BITS);
+ for (int i = 0; i < NUM_BITS; i++) {
+ bitSet.set(i);
+ }
+ for (int i = 0; i < NUM_BITS; i++) {
+ assertTrue(bitSet.get(i));
+ }
+ bitSet.clear();
+ for (int i = 0; i < NUM_BITS; i++) {
+ assertFalse(bitSet.get(i));
+ }
+ }
+
+ @Test
+ public void testClone() {
+ BitSet bitSet = new BitSet(NUM_BITS);
+ bitSet.set(NUM_BITS-1);
+ bitSet = bitSet.clone();
+ assertTrue(bitSet.get(NUM_BITS-1));
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/CacheTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/CacheTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/CacheTest.java
new file mode 100644
index 0000000..cab1984
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/CacheTest.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.Random;
+
+public final class CacheTest extends TasteTestCase {
+
+ @Test
+ public void testLotsOfGets() throws TasteException {
+ Retriever<Object,Object> retriever = new IdentityRetriever();
+ Cache<Object,Object> cache = new Cache<Object,Object>(retriever, 1000);
+ for (int i = 0; i < 1000000; i++) {
+ assertEquals(i, cache.get(i));
+ }
+ }
+
+ @Test
+ public void testMixedUsage() throws TasteException {
+ Random random = RandomUtils.getRandom();
+ Retriever<Object,Object> retriever = new IdentityRetriever();
+ Cache<Object,Object> cache = new Cache<Object,Object>(retriever, 1000);
+ for (int i = 0; i < 1000000; i++) {
+ double r = random.nextDouble();
+ if (r < 0.01) {
+ cache.clear();
+ } else if (r < 0.1) {
+ cache.remove(r - 100);
+ } else {
+ assertEquals(i, cache.get(i));
+ }
+ }
+ }
+
+ private static class IdentityRetriever implements Retriever<Object,Object> {
+ @Override
+ public Object get(Object key) throws TasteException {
+ return key;
+ }
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java
new file mode 100644
index 0000000..9263ce7
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastByIDMapTest.java
@@ -0,0 +1,147 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import com.google.common.collect.Maps;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.Map;
+import java.util.Random;
+
+/** <p>Tests {@link FastByIDMap}.</p> */
+public final class FastByIDMapTest extends TasteTestCase {
+
+ @Test
+ public void testPutAndGet() {
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ assertNull(map.get(500000L));
+ map.put(500000L, 2L);
+ assertEquals(2L, (long) map.get(500000L));
+ }
+
+ @Test
+ public void testRemove() {
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ map.put(500000L, 2L);
+ map.remove(500000L);
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ assertNull(map.get(500000L));
+ }
+
+ @Test
+ public void testClear() {
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ map.put(500000L, 2L);
+ map.clear();
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ assertNull(map.get(500000L));
+ }
+
+ @Test
+ public void testSizeEmpty() {
+ FastByIDMap<Long> map = new FastByIDMap<Long>();
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ map.put(500000L, 2L);
+ assertEquals(1, map.size());
+ assertFalse(map.isEmpty());
+ map.remove(500000L);
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ }
+
+ @Test
+ public void testContains() {
+ FastByIDMap<String> map = buildTestFastMap();
+ assertTrue(map.containsKey(500000L));
+ assertTrue(map.containsKey(47L));
+ assertTrue(map.containsKey(2L));
+ assertTrue(map.containsValue("alpha"));
+ assertTrue(map.containsValue("bang"));
+ assertTrue(map.containsValue("beta"));
+ assertFalse(map.containsKey(999));
+ assertFalse(map.containsValue("something"));
+ }
+
+ @Test
+ public void testRehash() {
+ FastByIDMap<String> map = buildTestFastMap();
+ map.remove(500000L);
+ map.rehash();
+ assertNull(map.get(500000L));
+ assertEquals("bang", map.get(47L));
+ }
+
+ @Test
+ public void testGrow() {
+ FastByIDMap<String> map = new FastByIDMap<String>(1,1);
+ map.put(500000L, "alpha");
+ map.put(47L, "bang");
+ assertNull(map.get(500000L));
+ assertEquals("bang", map.get(47L));
+ }
+
+ @Test
+ public void testVersusHashMap() {
+ FastByIDMap<String> actual = new FastByIDMap<String>();
+ Map<Long, String> expected = Maps.newHashMapWithExpectedSize(1000000);
+ Random r = RandomUtils.getRandom();
+ for (int i = 0; i < 1000000; i++) {
+ double d = r.nextDouble();
+ Long key = (long) r.nextInt(100);
+ if (d < 0.4) {
+ assertEquals(expected.get(key), actual.get(key));
+ } else {
+ if (d < 0.7) {
+ assertEquals(expected.put(key, "bang"), actual.put(key, "bang"));
+ } else {
+ assertEquals(expected.remove(key), actual.remove(key));
+ }
+ assertEquals(expected.size(), actual.size());
+ assertEquals(expected.isEmpty(), actual.isEmpty());
+ }
+ }
+ }
+
+ @Test
+ public void testMaxSize() {
+ FastByIDMap<String> map = new FastByIDMap<String>();
+ map.put(4, "bang");
+ assertEquals(1, map.size());
+ map.put(47L, "bang");
+ assertEquals(2, map.size());
+ assertNull(map.get(500000L));
+ map.put(47L, "buzz");
+ assertEquals(2, map.size());
+ assertEquals("buzz", map.get(47L));
+ }
+
+
+ private static FastByIDMap<String> buildTestFastMap() {
+ FastByIDMap<String> map = new FastByIDMap<String>();
+ map.put(500000L, "alpha");
+ map.put(47L, "bang");
+ map.put(2L, "beta");
+ return map;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastIDSetTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastIDSetTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastIDSetTest.java
new file mode 100644
index 0000000..aec1738
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastIDSetTest.java
@@ -0,0 +1,162 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import com.google.common.collect.Sets;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.Collection;
+import java.util.Random;
+
+/** <p>Tests {@link FastIDSet}.</p> */
+public final class FastIDSetTest extends TasteTestCase {
+
+ @Test
+ public void testContainsAndAdd() {
+ FastIDSet set = new FastIDSet();
+ assertFalse(set.contains(1));
+ set.add(1);
+ assertTrue(set.contains(1));
+ }
+
+ @Test
+ public void testRemove() {
+ FastIDSet set = new FastIDSet();
+ set.add(1);
+ set.remove(1);
+ assertEquals(0, set.size());
+ assertTrue(set.isEmpty());
+ assertFalse(set.contains(1));
+ }
+
+ @Test
+ public void testClear() {
+ FastIDSet set = new FastIDSet();
+ set.add(1);
+ set.clear();
+ assertEquals(0, set.size());
+ assertTrue(set.isEmpty());
+ assertFalse(set.contains(1));
+ }
+
+ @Test
+ public void testSizeEmpty() {
+ FastIDSet set = new FastIDSet();
+ assertEquals(0, set.size());
+ assertTrue(set.isEmpty());
+ set.add(1);
+ assertEquals(1, set.size());
+ assertFalse(set.isEmpty());
+ set.remove(1);
+ assertEquals(0, set.size());
+ assertTrue(set.isEmpty());
+ }
+
+ @Test
+ public void testContains() {
+ FastIDSet set = buildTestFastSet();
+ assertTrue(set.contains(1));
+ assertTrue(set.contains(2));
+ assertTrue(set.contains(3));
+ assertFalse(set.contains(4));
+ }
+
+ @Test
+ public void testReservedValues() {
+ FastIDSet set = new FastIDSet();
+ try {
+ set.add(Long.MIN_VALUE);
+ fail("Should have thrown IllegalArgumentException");
+ } catch (IllegalArgumentException iae) {
+ // good
+ }
+ assertFalse(set.contains(Long.MIN_VALUE));
+ try {
+ set.add(Long.MAX_VALUE);
+ fail("Should have thrown IllegalArgumentException");
+ } catch (IllegalArgumentException iae) {
+ // good
+ }
+ assertFalse(set.contains(Long.MAX_VALUE));
+ }
+
+ @Test
+ public void testRehash() {
+ FastIDSet set = buildTestFastSet();
+ set.remove(1);
+ set.rehash();
+ assertFalse(set.contains(1));
+ }
+
+ @Test
+ public void testGrow() {
+ FastIDSet set = new FastIDSet(1);
+ set.add(1);
+ set.add(2);
+ assertTrue(set.contains(1));
+ assertTrue(set.contains(2));
+ }
+
+ @Test
+ public void testIterator() {
+ FastIDSet set = buildTestFastSet();
+ Collection<Long> expected = Sets.newHashSetWithExpectedSize(3);
+ expected.add(1L);
+ expected.add(2L);
+ expected.add(3L);
+ LongPrimitiveIterator it = set.iterator();
+ while (it.hasNext()) {
+ expected.remove(it.nextLong());
+ }
+ assertTrue(expected.isEmpty());
+ }
+
+ @Test
+ public void testVersusHashSet() {
+ FastIDSet actual = new FastIDSet(1);
+ Collection<Integer> expected = Sets.newHashSetWithExpectedSize(1000000);
+ Random r = RandomUtils.getRandom();
+ for (int i = 0; i < 1000000; i++) {
+ double d = r.nextDouble();
+ Integer key = r.nextInt(100);
+ if (d < 0.4) {
+ assertEquals(expected.contains(key), actual.contains(key));
+ } else {
+ if (d < 0.7) {
+ assertEquals(expected.add(key), actual.add(key));
+ } else {
+ assertEquals(expected.remove(key), actual.remove(key));
+ }
+ assertEquals(expected.size(), actual.size());
+ assertEquals(expected.isEmpty(), actual.isEmpty());
+ }
+ }
+ }
+
+ private static FastIDSet buildTestFastSet() {
+ FastIDSet set = new FastIDSet();
+ set.add(1);
+ set.add(2);
+ set.add(3);
+ return set;
+ }
+
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastMapTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastMapTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastMapTest.java
new file mode 100644
index 0000000..2f27483
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/FastMapTest.java
@@ -0,0 +1,228 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+/** <p>Tests {@link FastMap}.</p> */
+public final class FastMapTest extends TasteTestCase {
+
+ @Test
+ public void testPutAndGet() {
+ Map<String, String> map = new FastMap<String, String>();
+ assertNull(map.get("foo"));
+ map.put("foo", "bar");
+ assertEquals("bar", map.get("foo"));
+ }
+
+ @Test
+ public void testRemove() {
+ Map<String, String> map = new FastMap<String, String>();
+ map.put("foo", "bar");
+ map.remove("foo");
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ assertNull(map.get("foo"));
+ }
+
+ @Test
+ public void testClear() {
+ Map<String, String> map = new FastMap<String, String>();
+ map.put("foo", "bar");
+ map.clear();
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ assertNull(map.get("foo"));
+ }
+
+ @Test
+ public void testSizeEmpty() {
+ Map<String, String> map = new FastMap<String, String>();
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ map.put("foo", "bar");
+ assertEquals(1, map.size());
+ assertFalse(map.isEmpty());
+ map.remove("foo");
+ assertEquals(0, map.size());
+ assertTrue(map.isEmpty());
+ }
+
+ @Test
+ public void testContains() {
+ FastMap<String, String> map = buildTestFastMap();
+ assertTrue(map.containsKey("foo"));
+ assertTrue(map.containsKey("baz"));
+ assertTrue(map.containsKey("alpha"));
+ assertTrue(map.containsValue("bar"));
+ assertTrue(map.containsValue("bang"));
+ assertTrue(map.containsValue("beta"));
+ assertFalse(map.containsKey("something"));
+ assertFalse(map.containsValue("something"));
+ }
+
+ @Test(expected = NullPointerException.class)
+ public void testNull1() {
+ Map<String, String> map = new FastMap<String, String>();
+ assertNull(map.get(null));
+ map.put(null, "bar");
+ }
+
+ @Test(expected = NullPointerException.class)
+ public void testNull2() {
+ Map<String, String> map = new FastMap<String, String>();
+ map.put("foo", null);
+ }
+
+ @Test
+ public void testRehash() {
+ FastMap<String, String> map = buildTestFastMap();
+ map.remove("foo");
+ map.rehash();
+ assertNull(map.get("foo"));
+ assertEquals("bang", map.get("baz"));
+ }
+
+ @Test
+ public void testGrow() {
+ Map<String, String> map = new FastMap<String, String>(1, FastMap.NO_MAX_SIZE);
+ map.put("foo", "bar");
+ map.put("baz", "bang");
+ assertEquals("bar", map.get("foo"));
+ assertEquals("bang", map.get("baz"));
+ }
+
+ @Test
+ public void testKeySet() {
+ FastMap<String, String> map = buildTestFastMap();
+ Collection<String> expected = Sets.newHashSetWithExpectedSize(3);
+ expected.add("foo");
+ expected.add("baz");
+ expected.add("alpha");
+ Set<String> actual = map.keySet();
+ assertTrue(expected.containsAll(actual));
+ assertTrue(actual.containsAll(expected));
+ Iterator<String> it = actual.iterator();
+ while (it.hasNext()) {
+ String value = it.next();
+ if (!"baz".equals(value)) {
+ it.remove();
+ }
+ }
+ assertTrue(map.containsKey("baz"));
+ assertFalse(map.containsKey("foo"));
+ assertFalse(map.containsKey("alpha"));
+ }
+
+ @Test
+ public void testValues() {
+ FastMap<String, String> map = buildTestFastMap();
+ Collection<String> expected = Sets.newHashSetWithExpectedSize(3);
+ expected.add("bar");
+ expected.add("bang");
+ expected.add("beta");
+ Collection<String> actual = map.values();
+ assertTrue(expected.containsAll(actual));
+ assertTrue(actual.containsAll(expected));
+ Iterator<String> it = actual.iterator();
+ while (it.hasNext()) {
+ String value = it.next();
+ if (!"bang".equals(value)) {
+ it.remove();
+ }
+ }
+ assertTrue(map.containsValue("bang"));
+ assertFalse(map.containsValue("bar"));
+ assertFalse(map.containsValue("beta"));
+ }
+
+ @Test
+ public void testEntrySet() {
+ FastMap<String, String> map = buildTestFastMap();
+ Set<Map.Entry<String, String>> actual = map.entrySet();
+ Collection<String> expectedKeys = Sets.newHashSetWithExpectedSize(3);
+ expectedKeys.add("foo");
+ expectedKeys.add("baz");
+ expectedKeys.add("alpha");
+ Collection<String> expectedValues = Sets.newHashSetWithExpectedSize(3);
+ expectedValues.add("bar");
+ expectedValues.add("bang");
+ expectedValues.add("beta");
+ assertEquals(3, actual.size());
+ for (Map.Entry<String, String> entry : actual) {
+ expectedKeys.remove(entry.getKey());
+ expectedValues.remove(entry.getValue());
+ }
+ assertEquals(0, expectedKeys.size());
+ assertEquals(0, expectedValues.size());
+ }
+
+ @Test
+ public void testVersusHashMap() {
+ Map<Integer, String> actual = new FastMap<Integer, String>(1, 1000000);
+ Map<Integer, String> expected = Maps.newHashMapWithExpectedSize(1000000);
+ Random r = RandomUtils.getRandom();
+ for (int i = 0; i < 1000000; i++) {
+ double d = r.nextDouble();
+ Integer key = r.nextInt(100);
+ if (d < 0.4) {
+ assertEquals(expected.get(key), actual.get(key));
+ } else {
+ if (d < 0.7) {
+ assertEquals(expected.put(key, "foo"), actual.put(key, "foo"));
+ } else {
+ assertEquals(expected.remove(key), actual.remove(key));
+ }
+ assertEquals(expected.size(), actual.size());
+ assertEquals(expected.isEmpty(), actual.isEmpty());
+ }
+ }
+ }
+
+ @Test
+ public void testMaxSize() {
+ Map<String, String> map = new FastMap<String, String>(1, 1);
+ map.put("foo", "bar");
+ assertEquals(1, map.size());
+ map.put("baz", "bang");
+ assertEquals(1, map.size());
+ assertNull(map.get("foo"));
+ map.put("baz", "buzz");
+ assertEquals(1, map.size());
+ assertEquals("buzz", map.get("baz"));
+ }
+
+ private static FastMap<String, String> buildTestFastMap() {
+ FastMap<String, String> map = new FastMap<String, String>();
+ map.put("foo", "bar");
+ map.put("baz", "bang");
+ map.put("alpha", "beta");
+ return map;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageTest.java
new file mode 100644
index 0000000..1fcc800
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageTest.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+public final class InvertedRunningAverageTest extends TasteTestCase {
+
+ @Test
+ public void testAverage() {
+ RunningAverage avg = new FullRunningAverage();
+ RunningAverage inverted = new InvertedRunningAverage(avg);
+ assertEquals(0, inverted.getCount());
+ avg.addDatum(1.0);
+ assertEquals(1, inverted.getCount());
+ assertEquals(-1.0, inverted.getAverage(), EPSILON);
+ avg.addDatum(2.0);
+ assertEquals(2, inverted.getCount());
+ assertEquals(-1.5, inverted.getAverage(), EPSILON);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testUnsupported1() {
+ RunningAverage inverted = new InvertedRunningAverage(new FullRunningAverage());
+ inverted.addDatum(1.0);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testUnsupported2() {
+ RunningAverage inverted = new InvertedRunningAverage(new FullRunningAverage());
+ inverted.changeDatum(1.0);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testUnsupported3() {
+ RunningAverage inverted = new InvertedRunningAverage(new FullRunningAverage());
+ inverted.removeDatum(1.0);
+ }
+
+ @Test
+ public void testAverageAndStdDev() {
+ RunningAverageAndStdDev avg = new FullRunningAverageAndStdDev();
+ RunningAverageAndStdDev inverted = new InvertedRunningAverageAndStdDev(avg);
+ assertEquals(0, inverted.getCount());
+ avg.addDatum(1.0);
+ assertEquals(1, inverted.getCount());
+ assertEquals(-1.0, inverted.getAverage(), EPSILON);
+ avg.addDatum(2.0);
+ assertEquals(2, inverted.getCount());
+ assertEquals(-1.5, inverted.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(2.0)/2.0, inverted.getStandardDeviation(), EPSILON);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testAndStdDevUnsupported1() {
+ RunningAverage inverted = new InvertedRunningAverageAndStdDev(new FullRunningAverageAndStdDev());
+ inverted.addDatum(1.0);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testAndStdDevUnsupported2() {
+ RunningAverage inverted = new InvertedRunningAverageAndStdDev(new FullRunningAverageAndStdDev());
+ inverted.changeDatum(1.0);
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testAndStdDevUnsupported3() {
+ RunningAverage inverted = new InvertedRunningAverageAndStdDev(new FullRunningAverageAndStdDev());
+ inverted.removeDatum(1.0);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIteratorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIteratorTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIteratorTest.java
new file mode 100644
index 0000000..7458df3
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIteratorTest.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+import java.util.NoSuchElementException;
+
+public final class LongPrimitiveArrayIteratorTest extends TasteTestCase {
+
+ @Test(expected = NoSuchElementException.class)
+ public void testEmpty() {
+ LongPrimitiveIterator it = new LongPrimitiveArrayIterator(new long[0]);
+ assertFalse(it.hasNext());
+ it.next();
+ }
+
+ @Test(expected = NoSuchElementException.class)
+ public void testNext() {
+ LongPrimitiveIterator it = new LongPrimitiveArrayIterator(new long[] {3,2,1});
+ assertTrue(it.hasNext());
+ assertEquals(3, (long) it.next());
+ assertTrue(it.hasNext());
+ assertEquals(2, it.nextLong());
+ assertTrue(it.hasNext());
+ assertEquals(1, (long) it.next());
+ assertFalse(it.hasNext());
+ it.nextLong();
+ }
+
+ @Test
+ public void testPeekSkip() {
+ LongPrimitiveIterator it = new LongPrimitiveArrayIterator(new long[] {3,2,1});
+ assertEquals(3, it.peek());
+ it.skip(2);
+ assertEquals(1, it.nextLong());
+ assertFalse(it.hasNext());
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/MockRefreshable.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/MockRefreshable.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/MockRefreshable.java
new file mode 100644
index 0000000..20233a7
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/MockRefreshable.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
+/** A mock {@link Refreshable} which counts the number of times it has been refreshed, for use in tests. */
+final class MockRefreshable implements Refreshable, Callable<Object> {
+
+ private int callCount;
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ call();
+ }
+
+ @Override
+ public Object call() {
+ callCount++;
+ return null;
+ }
+
+ int getCallCount() {
+ return callCount;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RefreshHelperTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RefreshHelperTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RefreshHelperTest.java
new file mode 100644
index 0000000..54c97e3
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RefreshHelperTest.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import com.google.common.collect.Sets;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+import java.util.Collection;
+
+/** Tests {@link RefreshHelper} */
+public final class RefreshHelperTest extends TasteTestCase {
+
+ @Test
+ public void testCallable() {
+ MockRefreshable mock = new MockRefreshable();
+ Refreshable helper = new RefreshHelper(mock);
+ helper.refresh(null);
+ assertEquals(1, mock.getCallCount());
+ }
+
+ @Test
+ public void testNoCallable() {
+ Refreshable helper = new RefreshHelper(null);
+ helper.refresh(null);
+ }
+
+ @Test
+ public void testDependencies() {
+ RefreshHelper helper = new RefreshHelper(null);
+ MockRefreshable mock1 = new MockRefreshable();
+ MockRefreshable mock2 = new MockRefreshable();
+ helper.addDependency(mock1);
+ helper.addDependency(mock2);
+ helper.refresh(null);
+ assertEquals(1, mock1.getCallCount());
+ assertEquals(1, mock2.getCallCount());
+ }
+
+ @Test
+ public void testAlreadyRefreshed() {
+ RefreshHelper helper = new RefreshHelper(null);
+ MockRefreshable mock1 = new MockRefreshable();
+ MockRefreshable mock2 = new MockRefreshable();
+ helper.addDependency(mock1);
+ helper.addDependency(mock2);
+ Collection<Refreshable> alreadyRefreshed = Sets.newHashSetWithExpectedSize(1);
+ alreadyRefreshed.add(mock1);
+ helper.refresh(alreadyRefreshed);
+ assertEquals(0, mock1.getCallCount());
+ assertEquals(1, mock2.getCallCount());
+ }
+
+}
[38/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java b/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
new file mode 100644
index 0000000..efd233f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/AbstractVectorClassifier.java
@@ -0,0 +1,248 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Defines the interface for classifiers that take a vector as input. This is
+ * implemented as an abstract class so that it can implement a number of handy
+ * convenience methods related to classification of vectors.
+ *
+ * <p>
+ * A classifier takes an input vector and calculates the scores (usually
+ * probabilities) that the input vector belongs to one of {@code n}
+ * categories. In {@code AbstractVectorClassifier} each category is denoted
+ * by an integer {@code c} between {@code 0} and {@code n-1}
+ * (inclusive).
+ *
+ * <p>
+ * New users should start by looking at {@link #classifyFull} (not {@link #classify}).
+ *
+ */
+public abstract class AbstractVectorClassifier {
+
+ /** Minimum allowable log likelihood value. */
+ public static final double MIN_LOG_LIKELIHOOD = -100.0;
+
+ /**
+ * Returns the number of categories that a target variable can be assigned to.
+ * A vector classifier will encode it's output as an integer from
+ * {@code 0} to {@code numCategories()-1} (inclusive).
+ *
+ * @return The number of categories.
+ */
+ public abstract int numCategories();
+
+ /**
+ * Compute and return a vector containing {@code n-1} scores, where
+ * {@code n} is equal to {@code numCategories()}, given an input
+ * vector {@code instance}. Higher scores indicate that the input vector
+ * is more likely to belong to that category. The categories are denoted by
+ * the integers {@code 0} through {@code n-1} (inclusive), and the
+ * scores in the returned vector correspond to categories 1 through
+ * {@code n-1} (leaving out category 0). It is assumed that the score for
+ * category 0 is one minus the sum of the scores in the returned vector.
+ *
+ * @param instance A feature vector to be classified.
+ * @return A vector of probabilities in 1 of {@code n-1} encoding.
+ */
+ public abstract Vector classify(Vector instance);
+
+ /**
+ * Compute and return a vector of scores before applying the inverse link
+ * function. For logistic regression and other generalized linear models, this
+ * is just the linear part of the classification.
+ *
+ * <p>
+ * The implementation of this method provided by {@code AbstractVectorClassifier} throws an
+ * {@link UnsupportedOperationException}. Your subclass must explicitly override this method to support
+ * this operation.
+ *
+ * @param features A feature vector to be classified.
+ * @return A vector of scores. If transformed by the link function, these will become probabilities.
+ */
+ public Vector classifyNoLink(Vector features) {
+ throw new UnsupportedOperationException(this.getClass().getName()
+ + " doesn't support classification without a link");
+ }
+
+ /**
+ * Classifies a vector in the special case of a binary classifier where
+ * {@link #classify(Vector)} would return a vector with only one element. As
+ * such, using this method can avoid the allocation of a vector.
+ *
+ * @param instance The feature vector to be classified.
+ * @return The score for category 1.
+ *
+ * @see #classify(Vector)
+ */
+ public abstract double classifyScalar(Vector instance);
+
+ /**
+ * Computes and returns a vector containing {@code n} scores, where
+ * {@code n} is {@code numCategories()}, given an input vector
+ * {@code instance}. Higher scores indicate that the input vector is more
+ * likely to belong to the corresponding category. The categories are denoted
+ * by the integers {@code 0} through {@code n-1} (inclusive).
+ *
+ * <p>
+ * Using this method it is possible to classify an input vector, for example,
+ * by selecting the category with the largest score. If
+ * {@code classifier} is an instance of
+ * {@code AbstractVectorClassifier} and {@code input} is a
+ * {@code Vector} of features describing an element to be classified,
+ * then the following code could be used to classify {@code input}.<br>
+ * {@code
+ * Vector scores = classifier.classifyFull(input);<br>
+ * int assignedCategory = scores.maxValueIndex();<br>
+ * } Here {@code assignedCategory} is the index of the category
+ * with the maximum score.
+ *
+ * <p>
+ * If an {@code n-1} encoding is acceptable, and allocation performance
+ * is an issue, then the {@link #classify(Vector)} method is probably better
+ * to use.
+ *
+ * @see #classify(Vector)
+ * @see #classifyFull(Vector r, Vector instance)
+ *
+ * @param instance A vector of features to be classified.
+ * @return A vector of probabilities, one for each category.
+ */
+ public Vector classifyFull(Vector instance) {
+ return classifyFull(new DenseVector(numCategories()), instance);
+ }
+
+ /**
+ * Computes and returns a vector containing {@code n} scores, where
+ * {@code n} is {@code numCategories()}, given an input vector
+ * {@code instance}. Higher scores indicate that the input vector is more
+ * likely to belong to the corresponding category. The categories are denoted
+ * by the integers {@code 0} through {@code n-1} (inclusive). The
+ * main difference between this method and {@link #classifyFull(Vector)} is
+ * that this method allows a user to provide a previously allocated
+ * {@code Vector r} to store the returned scores.
+ *
+ * <p>
+ * Using this method it is possible to classify an input vector, for example,
+ * by selecting the category with the largest score. If
+ * {@code classifier} is an instance of
+ * {@code AbstractVectorClassifier}, {@code result} is a non-null
+ * {@code Vector}, and {@code input} is a {@code Vector} of
+ * features describing an element to be classified, then the following code
+ * could be used to classify {@code input}.<br>
+ * {@code
+ * Vector scores = classifier.classifyFull(result, input); // Notice that scores == result<br>
+ * int assignedCategory = scores.maxValueIndex();<br>
+ * } Here {@code assignedCategory} is the index of the category
+ * with the maximum score.
+ *
+ * @param r Where to put the results.
+ * @param instance A vector of features to be classified.
+ * @return A vector of scores/probabilities, one for each category.
+ */
+ public Vector classifyFull(Vector r, Vector instance) {
+ r.viewPart(1, numCategories() - 1).assign(classify(instance));
+ r.setQuick(0, 1.0 - r.zSum());
+ return r;
+ }
+
+
+ /**
+ * Returns n-1 probabilities, one for each categories 1 through
+ * {@code n-1}, for each row of a matrix, where {@code n} is equal
+ * to {@code numCategories()}. The probability of the missing 0-th
+ * category is 1 - rowSum(this result).
+ *
+ * @param data The matrix whose rows are the input vectors to classify
+ * @return A matrix of scores, one row per row of the input matrix, one column for each but the last category.
+ */
+ public Matrix classify(Matrix data) {
+ Matrix r = new DenseMatrix(data.numRows(), numCategories() - 1);
+ for (int row = 0; row < data.numRows(); row++) {
+ r.assignRow(row, classify(data.viewRow(row)));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a matrix where the rows of the matrix each contain {@code n} probabilities, one for each category.
+ *
+ * @param data The matrix whose rows are the input vectors to classify
+ * @return A matrix of scores, one row per row of the input matrix, one column for each but the last category.
+ */
+ public Matrix classifyFull(Matrix data) {
+ Matrix r = new DenseMatrix(data.numRows(), numCategories());
+ for (int row = 0; row < data.numRows(); row++) {
+ classifyFull(r.viewRow(row), data.viewRow(row));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a vector of probabilities of category 1, one for each row
+ * of a matrix. This only makes sense if there are exactly two categories, but
+ * calling this method in that case can save a number of vector allocations.
+ *
+ * @param data The matrix whose rows are vectors to classify
+ * @return A vector of scores, with one value per row of the input matrix.
+ */
+ public Vector classifyScalar(Matrix data) {
+ Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
+
+ Vector r = new DenseVector(data.numRows());
+ for (int row = 0; row < data.numRows(); row++) {
+ r.set(row, classifyScalar(data.viewRow(row)));
+ }
+ return r;
+ }
+
+ /**
+ * Returns a measure of how good the classification for a particular example
+ * actually is.
+ *
+ * @param actual The correct category for the example.
+ * @param data The vector to be classified.
+ * @return The log likelihood of the correct answer as estimated by the current model. This will always be <= 0
+ * and larger (closer to 0) indicates better accuracy. In order to simplify code that maintains eunning averages,
+ * we bound this value at -100.
+ */
+ public double logLikelihood(int actual, Vector data) {
+ if (numCategories() == 2) {
+ double p = classifyScalar(data);
+ if (actual > 0) {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p));
+ } else {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p));
+ }
+ } else {
+ Vector p = classify(data);
+ if (actual > 0) {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log(p.get(actual - 1)));
+ } else {
+ return Math.max(MIN_LOG_LIKELIHOOD, Math.log1p(-p.zSum()));
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java b/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
new file mode 100644
index 0000000..29eaa0d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/ClassifierResult.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+/**
+ * Result of a document classification. The label and the associated score (usually probabilty)
+ */
+public class ClassifierResult {
+
+ private String label;
+ private double score;
+ private double logLikelihood = Double.MAX_VALUE;
+
+ public ClassifierResult() { }
+
+ public ClassifierResult(String label, double score) {
+ this.label = label;
+ this.score = score;
+ }
+
+ public ClassifierResult(String label) {
+ this.label = label;
+ }
+
+ public ClassifierResult(String label, double score, double logLikelihood) {
+ this.label = label;
+ this.score = score;
+ this.logLikelihood = logLikelihood;
+ }
+
+ public double getLogLikelihood() {
+ return logLikelihood;
+ }
+
+ public void setLogLikelihood(double logLikelihood) {
+ this.logLikelihood = logLikelihood;
+ }
+
+ public String getLabel() {
+ return label;
+ }
+
+ public double getScore() {
+ return score;
+ }
+
+ public void setLabel(String label) {
+ this.label = label;
+ }
+
+ public void setScore(double score) {
+ this.score = score;
+ }
+
+ @Override
+ public String toString() {
+ return "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}';
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java b/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
new file mode 100644
index 0000000..0baa4bf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/ConfusionMatrix.java
@@ -0,0 +1,444 @@
+/**
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.math3.stat.descriptive.moment.Mean;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
+ *
+ * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
+ *
+ * See http://en.wikipedia.org/wiki/Confusion_matrix for background
+ */
+public class ConfusionMatrix {
+ private static final Logger LOG = LoggerFactory.getLogger(ConfusionMatrix.class);
+ private final Map<String,Integer> labelMap = Maps.newLinkedHashMap();
+ private final int[][] confusionMatrix;
+ private int samples = 0;
+ private String defaultLabel = "unknown";
+
+ public ConfusionMatrix(Collection<String> labels, String defaultLabel) {
+ confusionMatrix = new int[labels.size() + 1][labels.size() + 1];
+ this.defaultLabel = defaultLabel;
+ int i = 0;
+ for (String label : labels) {
+ labelMap.put(label, i++);
+ }
+ labelMap.put(defaultLabel, i);
+ }
+
+ public ConfusionMatrix(Matrix m) {
+ confusionMatrix = new int[m.numRows()][m.numRows()];
+ setMatrix(m);
+ }
+
+ public int[][] getConfusionMatrix() {
+ return confusionMatrix;
+ }
+
+ public Collection<String> getLabels() {
+ return Collections.unmodifiableCollection(labelMap.keySet());
+ }
+
+ private int numLabels() {
+ return labelMap.size();
+ }
+
+ public double getAccuracy(String label) {
+ int labelId = labelMap.get(label);
+ int labelTotal = 0;
+ int correct = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ labelTotal += confusionMatrix[labelId][i];
+ if (i == labelId) {
+ correct += confusionMatrix[labelId][i];
+ }
+ }
+ return 100.0 * correct / labelTotal;
+ }
+
+ // Producer accuracy
+ public double getAccuracy() {
+ int total = 0;
+ int correct = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ for (int j = 0; j < numLabels(); j++) {
+ total += confusionMatrix[i][j];
+ if (i == j) {
+ correct += confusionMatrix[i][j];
+ }
+ }
+ }
+ return 100.0 * correct / total;
+ }
+
+ /** Sum of true positives and false negatives */
+ private int getActualNumberOfTestExamplesForClass(String label) {
+ int labelId = labelMap.get(label);
+ int sum = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ sum += confusionMatrix[labelId][i];
+ }
+ return sum;
+ }
+
+ public double getPrecision(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falsePositives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falsePositives += confusionMatrix[i][labelId];
+ }
+
+ if (truePositives + falsePositives == 0) {
+ return 0;
+ }
+
+ return ((double) truePositives) / (truePositives + falsePositives);
+ }
+
+ public double getWeightedPrecision() {
+ double[] precisions = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ precisions[index] = getPrecision(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(precisions, weights);
+ }
+
+ public double getRecall(String label) {
+ int labelId = labelMap.get(label);
+ int truePositives = confusionMatrix[labelId][labelId];
+ int falseNegatives = 0;
+ for (int i = 0; i < numLabels(); i++) {
+ if (i == labelId) {
+ continue;
+ }
+ falseNegatives += confusionMatrix[labelId][i];
+ }
+ if (truePositives + falseNegatives == 0) {
+ return 0;
+ }
+ return ((double) truePositives) / (truePositives + falseNegatives);
+ }
+
+ public double getWeightedRecall() {
+ double[] recalls = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ recalls[index] = getRecall(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(recalls, weights);
+ }
+
+ public double getF1score(String label) {
+ double precision = getPrecision(label);
+ double recall = getRecall(label);
+ if (precision + recall == 0) {
+ return 0;
+ }
+ return 2 * precision * recall / (precision + recall);
+ }
+
+ public double getWeightedF1score() {
+ double[] f1Scores = new double[numLabels()];
+ double[] weights = new double[numLabels()];
+
+ int index = 0;
+ for (String label : labelMap.keySet()) {
+ f1Scores[index] = getF1score(label);
+ weights[index] = getActualNumberOfTestExamplesForClass(label);
+ index++;
+ }
+ return new Mean().evaluate(f1Scores, weights);
+ }
+
+ // User accuracy
+ public double getReliability() {
+ int count = 0;
+ double accuracy = 0;
+ for (String label: labelMap.keySet()) {
+ if (!label.equals(defaultLabel)) {
+ accuracy += getAccuracy(label);
+ }
+ count++;
+ }
+ return accuracy / count;
+ }
+
+ /**
+ * Accuracy v.s. randomly classifying all samples.
+ * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
+ * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
+ * Educational And Psychological Measurement 20:37-46.
+ *
+ * Formula and variable names from:
+ * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
+ *
+ * @return double
+ */
+ public double getKappa() {
+ double a = 0.0;
+ double b = 0.0;
+ for (int i = 0; i < confusionMatrix.length; i++) {
+ a += confusionMatrix[i][i];
+ double br = 0;
+ for (int j = 0; j < confusionMatrix.length; j++) {
+ br += confusionMatrix[i][j];
+ }
+ double bc = 0;
+ for (int[] vec : confusionMatrix) {
+ bc += vec[i];
+ }
+ b += br * bc;
+ }
+ return (samples * a - b) / (samples * samples - b);
+ }
+
+ /**
+ * Standard deviation of normalized producer accuracy
+ * Not a standard score
+ * @return double
+ */
+ public RunningAverageAndStdDev getNormalizedStats() {
+ RunningAverageAndStdDev summer = new FullRunningAverageAndStdDev();
+ for (int d = 0; d < confusionMatrix.length; d++) {
+ double total = 0;
+ for (int j = 0; j < confusionMatrix.length; j++) {
+ total += confusionMatrix[d][j];
+ }
+ summer.addDatum(confusionMatrix[d][d] / (total + 0.000001));
+ }
+
+ return summer;
+ }
+
+ public int getCorrect(String label) {
+ int labelId = labelMap.get(label);
+ return confusionMatrix[labelId][labelId];
+ }
+
+ public int getTotal(String label) {
+ int labelId = labelMap.get(label);
+ int labelTotal = 0;
+ for (int i = 0; i < labelMap.size(); i++) {
+ labelTotal += confusionMatrix[labelId][i];
+ }
+ return labelTotal;
+ }
+
+ public void addInstance(String correctLabel, ClassifierResult classifiedResult) {
+ samples++;
+ incrementCount(correctLabel, classifiedResult.getLabel());
+ }
+
+ public void addInstance(String correctLabel, String classifiedLabel) {
+ samples++;
+ incrementCount(correctLabel, classifiedLabel);
+ }
+
+ public int getCount(String correctLabel, String classifiedLabel) {
+ if(!labelMap.containsKey(correctLabel)) {
+ LOG.warn("Label {} did not appear in the training examples", correctLabel);
+ return 0;
+ }
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+ int correctId = labelMap.get(correctLabel);
+ int classifiedId = labelMap.get(classifiedLabel);
+ return confusionMatrix[correctId][classifiedId];
+ }
+
+ public void putCount(String correctLabel, String classifiedLabel, int count) {
+ if(!labelMap.containsKey(correctLabel)) {
+ LOG.warn("Label {} did not appear in the training examples", correctLabel);
+ return;
+ }
+ Preconditions.checkArgument(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel);
+ int correctId = labelMap.get(correctLabel);
+ int classifiedId = labelMap.get(classifiedLabel);
+ if (confusionMatrix[correctId][classifiedId] == 0.0 && count != 0) {
+ samples++;
+ }
+ confusionMatrix[correctId][classifiedId] = count;
+ }
+
+ public String getDefaultLabel() {
+ return defaultLabel;
+ }
+
+ public void incrementCount(String correctLabel, String classifiedLabel, int count) {
+ putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel));
+ }
+
+ public void incrementCount(String correctLabel, String classifiedLabel) {
+ incrementCount(correctLabel, classifiedLabel, 1);
+ }
+
+ public ConfusionMatrix merge(ConfusionMatrix b) {
+ Preconditions.checkArgument(labelMap.size() == b.getLabels().size(), "The label sizes do not match");
+ for (String correctLabel : this.labelMap.keySet()) {
+ for (String classifiedLabel : this.labelMap.keySet()) {
+ incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel));
+ }
+ }
+ return this;
+ }
+
+ public Matrix getMatrix() {
+ int length = confusionMatrix.length;
+ Matrix m = new DenseMatrix(length, length);
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ m.set(r, c, confusionMatrix[r][c]);
+ }
+ }
+ Map<String,Integer> labels = Maps.newHashMap();
+ for (Map.Entry<String, Integer> entry : labelMap.entrySet()) {
+ labels.put(entry.getKey(), entry.getValue());
+ }
+ m.setRowLabelBindings(labels);
+ m.setColumnLabelBindings(labels);
+ return m;
+ }
+
+ public void setMatrix(Matrix m) {
+ int length = confusionMatrix.length;
+ if (m.numRows() != m.numCols()) {
+ throw new IllegalArgumentException(
+ "ConfusionMatrix: matrix(" + m.numRows() + ',' + m.numCols() + ") must be square");
+ }
+ for (int r = 0; r < length; r++) {
+ for (int c = 0; c < length; c++) {
+ confusionMatrix[r][c] = (int) Math.round(m.get(r, c));
+ }
+ }
+ Map<String,Integer> labels = m.getRowLabelBindings();
+ if (labels == null) {
+ labels = m.getColumnLabelBindings();
+ }
+ if (labels != null) {
+ String[] sorted = sortLabels(labels);
+ verifyLabels(length, sorted);
+ labelMap.clear();
+ for (int i = 0; i < length; i++) {
+ labelMap.put(sorted[i], i);
+ }
+ }
+ }
+
+ private static String[] sortLabels(Map<String,Integer> labels) {
+ String[] sorted = new String[labels.size()];
+ for (Map.Entry<String,Integer> entry : labels.entrySet()) {
+ sorted[entry.getValue()] = entry.getKey();
+ }
+ return sorted;
+ }
+
+ private static void verifyLabels(int length, String[] sorted) {
+ Preconditions.checkArgument(sorted.length == length, "One label, one row");
+ for (int i = 0; i < length; i++) {
+ if (sorted[i] == null) {
+ Preconditions.checkArgument(false, "One label, one row");
+ }
+ }
+ }
+
+ /**
+ * This is overloaded. toString() is not a formatted report you print for a manager :)
+ * Assume that if there are no default assignments, the default feature was not used
+ */
+ @Override
+ public String toString() {
+ StringBuilder returnString = new StringBuilder(200);
+ returnString.append("=======================================================").append('\n');
+ returnString.append("Confusion Matrix\n");
+ returnString.append("-------------------------------------------------------").append('\n');
+
+ int unclassified = getTotal(defaultLabel);
+ for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+
+ returnString.append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5)).append('\t');
+ }
+
+ returnString.append("<--Classified as").append('\n');
+ for (Map.Entry<String,Integer> entry : this.labelMap.entrySet()) {
+ if (entry.getKey().equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+ String correctLabel = entry.getKey();
+ int labelTotal = 0;
+ for (String classifiedLabel : this.labelMap.keySet()) {
+ if (classifiedLabel.equals(defaultLabel) && unclassified == 0) {
+ continue;
+ }
+ returnString.append(
+ StringUtils.rightPad(Integer.toString(getCount(correctLabel, classifiedLabel)), 5)).append('\t');
+ labelTotal += getCount(correctLabel, classifiedLabel);
+ }
+ returnString.append(" | ").append(StringUtils.rightPad(String.valueOf(labelTotal), 6)).append('\t')
+ .append(StringUtils.rightPad(getSmallLabel(entry.getValue()), 5))
+ .append(" = ").append(correctLabel).append('\n');
+ }
+ if (unclassified > 0) {
+ returnString.append("Default Category: ").append(defaultLabel).append(": ").append(unclassified).append('\n');
+ }
+ returnString.append('\n');
+ return returnString.toString();
+ }
+
+ static String getSmallLabel(int i) {
+ int val = i;
+ StringBuilder returnString = new StringBuilder();
+ do {
+ int n = val % 26;
+ returnString.insert(0, (char) ('a' + n));
+ val /= 26;
+ } while (val > 0);
+ return returnString.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java b/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
new file mode 100644
index 0000000..af1d5e7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/OnlineLearner.java
@@ -0,0 +1,96 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import org.apache.mahout.math.Vector;
+
+import java.io.Closeable;
+
+/**
+ * The simplest interface for online learning algorithms.
+ */
+public interface OnlineLearner extends Closeable {
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may an assumption that if multiple passes through the training data are necessary, then
+ * the training examples will be presented in the same order. This is because the order of
+ * training examples may be used to assign records to different data splits for evaluation by
+ * cross-validation. Without the order invariance, records might be assigned to training and test
+ * splits and error estimates could be seriously affected.
+ * <p/>
+ * If re-ordering is necessary, then using the alternative API which allows a tracking key to be
+ * added to the training example can be used.
+ *
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(int actual, Vector instance);
+
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may an assumption that if multiple passes through the training data are necessary that
+ * the tracking key for a record will be the same for each pass and that there will be a
+ * relatively large number of distinct tracking keys and that the low-order bits of the tracking
+ * keys will not correlate with any of the input variables. This tracking key is used to assign
+ * training examples to different test/training splits.
+ * <p/>
+ * Examples of useful tracking keys include id-numbers for the training records derived from
+ * a database id for the base table from the which the record is derived, or the offset of
+ * the original data record in a data file.
+ *
+ * @param trackingKey The tracking key for this training example.
+ * @param groupKey An optional value that allows examples to be grouped in the computation of
+ * the update to the model.
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(long trackingKey, String groupKey, int actual, Vector instance);
+
+ /**
+ * Updates the model using a particular target variable value and a feature vector.
+ * <p/>
+ * There may an assumption that if multiple passes through the training data are necessary that
+ * the tracking key for a record will be the same for each pass and that there will be a
+ * relatively large number of distinct tracking keys and that the low-order bits of the tracking
+ * keys will not correlate with any of the input variables. This tracking key is used to assign
+ * training examples to different test/training splits.
+ * <p/>
+ * Examples of useful tracking keys include id-numbers for the training records derived from
+ * a database id for the base table from the which the record is derived, or the offset of
+ * the original data record in a data file.
+ *
+ * @param trackingKey The tracking key for this training example.
+ * @param actual The value of the target variable. This value should be in the half-open
+ * interval [0..n) where n is the number of target categories.
+ * @param instance The feature vector for this example.
+ */
+ void train(long trackingKey, int actual, Vector instance);
+
+ /**
+ * Prepares the classifier for classification and deallocates any temporary data structures.
+ *
+ * An online classifier should be able to accept more training after being closed, but
+ * closing the classifier may make classification more efficient.
+ */
+ @Override
+ void close();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java b/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
new file mode 100644
index 0000000..5d8b9ed
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/RegressionResultAnalyzer.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.text.DecimalFormat;
+import java.text.NumberFormat;
+import java.util.List;
+import java.util.Locale;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.lang3.StringUtils;
+
+/**
+ * ResultAnalyzer captures the classification statistics and displays in a tabular manner
+ */
+public class RegressionResultAnalyzer {
+
+ private static class Result {
+ private final double actual;
+ private final double result;
+ Result(double actual, double result) {
+ this.actual = actual;
+ this.result = result;
+ }
+ double getActual() {
+ return actual;
+ }
+ double getResult() {
+ return result;
+ }
+ }
+
+ private List<Result> results;
+
+ /**
+ *
+ * @param actual
+ * The actual answer
+ * @param result
+ * The regression result
+ */
+ public void addInstance(double actual, double result) {
+ if (results == null) {
+ results = Lists.newArrayList();
+ }
+ results.add(new Result(actual, result));
+ }
+
+ /**
+ *
+ * @param results
+ * The results table
+ */
+ public void setInstances(double[][] results) {
+ for (double[] res : results) {
+ addInstance(res[0], res[1]);
+ }
+ }
+
+ @Override
+ public String toString() {
+ double sumActual = 0.0;
+ double sumActualSquared = 0.0;
+ double sumResult = 0.0;
+ double sumResultSquared = 0.0;
+ double sumActualResult = 0.0;
+ double sumAbsolute = 0.0;
+ double sumAbsoluteSquared = 0.0;
+ int predictable = 0;
+ int unpredictable = 0;
+
+ for (Result res : results) {
+ double actual = res.getActual();
+ double result = res.getResult();
+ if (Double.isNaN(result)) {
+ unpredictable++;
+ } else {
+ sumActual += actual;
+ sumActualSquared += actual * actual;
+ sumResult += result;
+ sumResultSquared += result * result;
+ sumActualResult += actual * result;
+ double absolute = Math.abs(actual - result);
+ sumAbsolute += absolute;
+ sumAbsoluteSquared += absolute * absolute;
+ predictable++;
+ }
+ }
+
+ StringBuilder returnString = new StringBuilder();
+
+ returnString.append("=======================================================\n");
+ returnString.append("Summary\n");
+ returnString.append("-------------------------------------------------------\n");
+
+ if (predictable > 0) {
+ double varActual = sumActualSquared - sumActual * sumActual / predictable;
+ double varResult = sumResultSquared - sumResult * sumResult / predictable;
+ double varCo = sumActualResult - sumActual * sumResult / predictable;
+
+ double correlation;
+ if (varActual * varResult <= 0) {
+ correlation = 0.0;
+ } else {
+ correlation = varCo / Math.sqrt(varActual * varResult);
+ }
+
+ Locale.setDefault(Locale.US);
+ NumberFormat decimalFormatter = new DecimalFormat("0.####");
+
+ returnString.append(StringUtils.rightPad("Correlation coefficient", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(correlation), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Mean absolute error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(sumAbsolute / predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Root mean squared error", 40)).append(": ").append(
+ StringUtils.leftPad(decimalFormatter.format(Math.sqrt(sumAbsoluteSquared / predictable)),
+ 10)).append('\n');
+ }
+ returnString.append(StringUtils.rightPad("Predictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(predictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Unpredictable Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(unpredictable), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Total Regressed Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(results.size()), 10)).append('\n');
+ returnString.append('\n');
+
+ return returnString.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java b/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
new file mode 100644
index 0000000..1711f19
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/ResultAnalyzer.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier;
+
+import java.text.DecimalFormat;
+import java.text.NumberFormat;
+import java.util.Collection;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+
+/** ResultAnalyzer captures the classification statistics and displays in a tabular manner */
+public class ResultAnalyzer {
+
+ private final ConfusionMatrix confusionMatrix;
+ private final OnlineSummarizer summarizer;
+ private boolean hasLL;
+
+ /*
+ * === Summary ===
+ *
+ * Correctly Classified Instances 635 92.9722 % Incorrectly Classified Instances 48 7.0278 % Kappa statistic
+ * 0.923 Mean absolute error 0.0096 Root mean squared error 0.0817 Relative absolute error 9.9344 % Root
+ * relative squared error 37.2742 % Total Number of Instances 683
+ */
+ private int correctlyClassified;
+ private int incorrectlyClassified;
+
+ public ResultAnalyzer(Collection<String> labelSet, String defaultLabel) {
+ confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel);
+ summarizer = new OnlineSummarizer();
+ }
+
+ public ConfusionMatrix getConfusionMatrix() {
+ return this.confusionMatrix;
+ }
+
+ /**
+ *
+ * @param correctLabel
+ * The correct label
+ * @param classifiedResult
+ * The classified result
+ * @return whether the instance was correct or not
+ */
+ public boolean addInstance(String correctLabel, ClassifierResult classifiedResult) {
+ boolean result = correctLabel.equals(classifiedResult.getLabel());
+ if (result) {
+ correctlyClassified++;
+ } else {
+ incorrectlyClassified++;
+ }
+ confusionMatrix.addInstance(correctLabel, classifiedResult);
+ if (classifiedResult.getLogLikelihood() != Double.MAX_VALUE) {
+ summarizer.add(classifiedResult.getLogLikelihood());
+ hasLL = true;
+ }
+ return result;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder returnString = new StringBuilder();
+
+ returnString.append('\n');
+ returnString.append("=======================================================\n");
+ returnString.append("Summary\n");
+ returnString.append("-------------------------------------------------------\n");
+ int totalClassified = correctlyClassified + incorrectlyClassified;
+ double percentageCorrect = (double) 100 * correctlyClassified / totalClassified;
+ double percentageIncorrect = (double) 100 * incorrectlyClassified / totalClassified;
+ NumberFormat decimalFormatter = new DecimalFormat("0.####");
+
+ returnString.append(StringUtils.rightPad("Correctly Classified Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(correctlyClassified), 10)).append('\t').append(
+ StringUtils.leftPad(decimalFormatter.format(percentageCorrect), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Incorrectly Classified Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(incorrectlyClassified), 10)).append('\t').append(
+ StringUtils.leftPad(decimalFormatter.format(percentageIncorrect), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Total Classified Instances", 40)).append(": ").append(
+ StringUtils.leftPad(Integer.toString(totalClassified), 10)).append('\n');
+ returnString.append('\n');
+
+ returnString.append(confusionMatrix);
+ returnString.append("=======================================================\n");
+ returnString.append("Statistics\n");
+ returnString.append("-------------------------------------------------------\n");
+
+ RunningAverageAndStdDev normStats = confusionMatrix.getNormalizedStats();
+ returnString.append(StringUtils.rightPad("Kappa", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getKappa()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Accuracy", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getAccuracy()), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Reliability", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(normStats.getAverage() * 100.00000001), 10)).append("%\n");
+ returnString.append(StringUtils.rightPad("Reliability (standard deviation)", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(normStats.getStandardDeviation()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted precision", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedPrecision()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted recall", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedRecall()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("Weighted F1 score", 40)).append(
+ StringUtils.leftPad(decimalFormatter.format(confusionMatrix.getWeightedF1score()), 10)).append('\n');
+
+ if (hasLL) {
+ returnString.append(StringUtils.rightPad("Log-likelihood", 30)).append("mean : ").append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getMean()), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("25%-ile : ", 10)).append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(1)), 10)).append('\n');
+ returnString.append(StringUtils.rightPad("", 30)).append(StringUtils.rightPad("75%-ile : ", 10)).append(
+ StringUtils.leftPad(decimalFormatter.format(summarizer.getQuartile(3)), 10)).append('\n');
+ }
+
+ return returnString.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java b/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
new file mode 100644
index 0000000..0ec5b55
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/Bagging.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.node.Node;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Arrays;
+import java.util.Random;
+
+/**
+ * Builds a tree using bagging
+ */
+public class Bagging {
+
+ private static final Logger log = LoggerFactory.getLogger(Bagging.class);
+
+ private final TreeBuilder treeBuilder;
+
+ private final Data data;
+
+ private final boolean[] sampled;
+
+ public Bagging(TreeBuilder treeBuilder, Data data) {
+ this.treeBuilder = treeBuilder;
+ this.data = data;
+ sampled = new boolean[data.size()];
+ }
+
+ /**
+ * Builds one tree
+ */
+ public Node build(Random rng) {
+ log.debug("Bagging...");
+ Arrays.fill(sampled, false);
+ Data bag = data.bagging(rng, sampled);
+
+ log.debug("Building...");
+ return treeBuilder.build(rng, bag);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java b/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
new file mode 100644
index 0000000..137b174
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/DFUtils.java
@@ -0,0 +1,181 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.charset.Charset;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+
+/**
+ * Utility class that contains various helper methods
+ */
+public final class DFUtils {
+
+ private DFUtils() {}
+
+ /**
+ * Writes an Node[] into a DataOutput
+ * @throws java.io.IOException
+ */
+ public static void writeArray(DataOutput out, Node[] array) throws IOException {
+ out.writeInt(array.length);
+ for (Node w : array) {
+ w.write(out);
+ }
+ }
+
+ /**
+ * Reads a Node[] from a DataInput
+ * @throws java.io.IOException
+ */
+ public static Node[] readNodeArray(DataInput in) throws IOException {
+ int length = in.readInt();
+ Node[] nodes = new Node[length];
+ for (int index = 0; index < length; index++) {
+ nodes[index] = Node.read(in);
+ }
+
+ return nodes;
+ }
+
+ /**
+ * Writes a double[] into a DataOutput
+ * @throws java.io.IOException
+ */
+ public static void writeArray(DataOutput out, double[] array) throws IOException {
+ out.writeInt(array.length);
+ for (double value : array) {
+ out.writeDouble(value);
+ }
+ }
+
+ /**
+ * Reads a double[] from a DataInput
+ * @throws java.io.IOException
+ */
+ public static double[] readDoubleArray(DataInput in) throws IOException {
+ int length = in.readInt();
+ double[] array = new double[length];
+ for (int index = 0; index < length; index++) {
+ array[index] = in.readDouble();
+ }
+
+ return array;
+ }
+
+ /**
+ * Writes an int[] into a DataOutput
+ * @throws java.io.IOException
+ */
+ public static void writeArray(DataOutput out, int[] array) throws IOException {
+ out.writeInt(array.length);
+ for (int value : array) {
+ out.writeInt(value);
+ }
+ }
+
+ /**
+ * Reads an int[] from a DataInput
+ * @throws java.io.IOException
+ */
+ public static int[] readIntArray(DataInput in) throws IOException {
+ int length = in.readInt();
+ int[] array = new int[length];
+ for (int index = 0; index < length; index++) {
+ array[index] = in.readInt();
+ }
+
+ return array;
+ }
+
+ /**
+ * Return a list of all files in the output directory
+ * @throws IOException if no file is found
+ */
+ public static Path[] listOutputFiles(FileSystem fs, Path outputPath) throws IOException {
+ List<Path> outputFiles = Lists.newArrayList();
+ for (FileStatus s : fs.listStatus(outputPath, PathFilters.logsCRCFilter())) {
+ if (!s.isDir() && !s.getPath().getName().startsWith("_")) {
+ outputFiles.add(s.getPath());
+ }
+ }
+ if (outputFiles.isEmpty()) {
+ throw new IOException("No output found !");
+ }
+ return outputFiles.toArray(new Path[outputFiles.size()]);
+ }
+
+ /**
+ * Formats a time interval in milliseconds to a String in the form "hours:minutes:seconds:millis"
+ */
+ public static String elapsedTime(long milli) {
+ long seconds = milli / 1000;
+ milli %= 1000;
+
+ long minutes = seconds / 60;
+ seconds %= 60;
+
+ long hours = minutes / 60;
+ minutes %= 60;
+
+ return hours + "h " + minutes + "m " + seconds + "s " + milli;
+ }
+
+ public static void storeWritable(Configuration conf, Path path, Writable writable) throws IOException {
+ FileSystem fs = path.getFileSystem(conf);
+
+ FSDataOutputStream out = fs.create(path);
+ try {
+ writable.write(out);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ /**
+ * Write a string to a path.
+ * @param conf From which the file system will be picked
+ * @param path Where the string will be written
+ * @param string The string to write
+ * @throws IOException if things go poorly
+ */
+ public static void storeString(Configuration conf, Path path, String string) throws IOException {
+ DataOutputStream out = null;
+ try {
+ out = path.getFileSystem(conf).create(path);
+ out.write(string.getBytes(Charset.defaultCharset()));
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java b/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
new file mode 100644
index 0000000..1b47ec7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/DecisionForest.java
@@ -0,0 +1,244 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataUtils;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.Node;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Represents a forest of decision trees.
+ */
+public class DecisionForest implements Writable {
+
+ private final List<Node> trees;
+
+ private DecisionForest() {
+ trees = Lists.newArrayList();
+ }
+
+ public DecisionForest(List<Node> trees) {
+ Preconditions.checkArgument(trees != null && !trees.isEmpty(), "trees argument must not be null or empty");
+
+ this.trees = trees;
+ }
+
+ List<Node> getTrees() {
+ return trees;
+ }
+
+ /**
+ * Classifies the data and calls callback for each classification
+ */
+ public void classify(Data data, double[][] predictions) {
+ Preconditions.checkArgument(data.size() == predictions.length, "predictions.length must be equal to data.size()");
+
+ if (data.isEmpty()) {
+ return; // nothing to classify
+ }
+
+ int treeId = 0;
+ for (Node tree : trees) {
+ for (int index = 0; index < data.size(); index++) {
+ if (predictions[index] == null) {
+ predictions[index] = new double[trees.size()];
+ }
+ predictions[index][treeId] = tree.classify(data.get(index));
+ }
+ treeId++;
+ }
+ }
+
+ /**
+ * predicts the label for the instance
+ *
+ * @param rng
+ * Random number generator, used to break ties randomly
+ * @return NaN if the label cannot be predicted
+ */
+ public double classify(Dataset dataset, Random rng, Instance instance) {
+ if (dataset.isNumerical(dataset.getLabelId())) {
+ double sum = 0;
+ int cnt = 0;
+ for (Node tree : trees) {
+ double prediction = tree.classify(instance);
+ if (!Double.isNaN(prediction)) {
+ sum += prediction;
+ cnt++;
+ }
+ }
+
+ if (cnt > 0) {
+ return sum / cnt;
+ } else {
+ return Double.NaN;
+ }
+ } else {
+ int[] predictions = new int[dataset.nblabels()];
+ for (Node tree : trees) {
+ double prediction = tree.classify(instance);
+ if (!Double.isNaN(prediction)) {
+ predictions[(int) prediction]++;
+ }
+ }
+
+ if (DataUtils.sum(predictions) == 0) {
+ return Double.NaN; // no prediction available
+ }
+
+ return DataUtils.maxindex(rng, predictions);
+ }
+ }
+
+ /**
+ * @return Mean number of nodes per tree
+ */
+ public long meanNbNodes() {
+ long sum = 0;
+
+ for (Node tree : trees) {
+ sum += tree.nbNodes();
+ }
+
+ return sum / trees.size();
+ }
+
+ /**
+ * @return Total number of nodes in all the trees
+ */
+ public long nbNodes() {
+ long sum = 0;
+
+ for (Node tree : trees) {
+ sum += tree.nbNodes();
+ }
+
+ return sum;
+ }
+
+ /**
+ * @return Mean maximum depth per tree
+ */
+ public long meanMaxDepth() {
+ long sum = 0;
+
+ for (Node tree : trees) {
+ sum += tree.maxDepth();
+ }
+
+ return sum / trees.size();
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (!(obj instanceof DecisionForest)) {
+ return false;
+ }
+
+ DecisionForest rf = (DecisionForest) obj;
+
+ return trees.size() == rf.getTrees().size() && trees.containsAll(rf.getTrees());
+ }
+
+ @Override
+ public int hashCode() {
+ return trees.hashCode();
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) throws IOException {
+ dataOutput.writeInt(trees.size());
+ for (Node tree : trees) {
+ tree.write(dataOutput);
+ }
+ }
+
+ /**
+ * Reads the trees from the input and adds them to the existing trees
+ */
+ @Override
+ public void readFields(DataInput dataInput) throws IOException {
+ int size = dataInput.readInt();
+ for (int i = 0; i < size; i++) {
+ trees.add(Node.read(dataInput));
+ }
+ }
+
+ /**
+ * Read the forest from inputStream
+ * @param dataInput - input forest
+ * @return {@link org.apache.mahout.classifier.df.DecisionForest}
+ * @throws IOException
+ */
+ public static DecisionForest read(DataInput dataInput) throws IOException {
+ DecisionForest forest = new DecisionForest();
+ forest.readFields(dataInput);
+ return forest;
+ }
+
+ /**
+ * Load the forest from a single file or a directory of files
+ * @throws java.io.IOException
+ */
+ public static DecisionForest load(Configuration conf, Path forestPath) throws IOException {
+ FileSystem fs = forestPath.getFileSystem(conf);
+ Path[] files;
+ if (fs.getFileStatus(forestPath).isDir()) {
+ files = DFUtils.listOutputFiles(fs, forestPath);
+ } else {
+ files = new Path[]{forestPath};
+ }
+
+ DecisionForest forest = null;
+ for (Path path : files) {
+ FSDataInputStream dataInput = new FSDataInputStream(fs.open(path));
+ try {
+ if (forest == null) {
+ forest = read(dataInput);
+ } else {
+ forest.readFields(dataInput);
+ }
+ } finally {
+ Closeables.close(dataInput, true);
+ }
+ }
+
+ return forest;
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java b/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
new file mode 100644
index 0000000..2a7facc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/ErrorEstimate.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Various methods to compute from the output of a random forest
+ */
+public final class ErrorEstimate {
+
+ private ErrorEstimate() {
+ }
+
+ public static double errorRate(double[] labels, double[] predictions) {
+ Preconditions.checkArgument(labels.length == predictions.length, "labels.length != predictions.length");
+ double nberrors = 0; // number of instance that got bad predictions
+ double datasize = 0; // number of classified instances
+
+ for (int index = 0; index < labels.length; index++) {
+ if (predictions[index] == -1) {
+ continue; // instance not classified
+ }
+
+ if (predictions[index] != labels[index]) {
+ nberrors++;
+ }
+
+ datasize++;
+ }
+
+ return nberrors / datasize;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
new file mode 100644
index 0000000..895188b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DecisionTreeBuilder.java
@@ -0,0 +1,421 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import com.google.common.collect.Sets;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+import org.apache.mahout.classifier.df.split.IgSplit;
+import org.apache.mahout.classifier.df.split.OptIgSplit;
+import org.apache.mahout.classifier.df.split.RegressionSplit;
+import org.apache.mahout.classifier.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.Random;
+
+/**
+ * Builds a classification tree or regression tree<br>
+ * A classification tree is built when the criterion variable is the categorical attribute.<br>
+ * A regression tree is built when the criterion variable is the numerical attribute.
+ */
+public class DecisionTreeBuilder implements TreeBuilder {
+
+ private static final Logger log = LoggerFactory.getLogger(DecisionTreeBuilder.class);
+
+ private static final int[] NO_ATTRIBUTES = new int[0];
+ private static final double EPSILON = 1.0e-6;
+
+ /**
+ * indicates which CATEGORICAL attributes have already been selected in the parent nodes
+ */
+ private boolean[] selected;
+ /**
+ * number of attributes to select randomly at each node
+ */
+ private int m;
+ /**
+ * IgSplit implementation
+ */
+ private IgSplit igSplit;
+ /**
+ * tree is complemented
+ */
+ private boolean complemented = true;
+ /**
+ * minimum number for split
+ */
+ private double minSplitNum = 2.0;
+ /**
+ * minimum proportion of the total variance for split
+ */
+ private double minVarianceProportion = 1.0e-3;
+ /**
+ * full set data
+ */
+ private Data fullSet;
+ /**
+ * minimum variance for split
+ */
+ private double minVariance = Double.NaN;
+
+ public void setM(int m) {
+ this.m = m;
+ }
+
+ public void setIgSplit(IgSplit igSplit) {
+ this.igSplit = igSplit;
+ }
+
+ public void setComplemented(boolean complemented) {
+ this.complemented = complemented;
+ }
+
+ public void setMinSplitNum(int minSplitNum) {
+ this.minSplitNum = minSplitNum;
+ }
+
+ public void setMinVarianceProportion(double minVarianceProportion) {
+ this.minVarianceProportion = minVarianceProportion;
+ }
+
+ @Override
+ public Node build(Random rng, Data data) {
+ if (selected == null) {
+ selected = new boolean[data.getDataset().nbAttributes()];
+ selected[data.getDataset().getLabelId()] = true; // never select the label
+ }
+ if (m == 0) {
+ // set default m
+ double e = data.getDataset().nbAttributes() - 1;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ m = (int) Math.ceil(e / 3.0);
+ } else {
+ // classification
+ m = (int) Math.ceil(Math.sqrt(e));
+ }
+ }
+
+ if (data.isEmpty()) {
+ return new Leaf(Double.NaN);
+ }
+
+ double sum = 0.0;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ // sum and sum squared of a label is computed
+ double sumSquared = 0.0;
+ for (int i = 0; i < data.size(); i++) {
+ double label = data.getDataset().getLabel(data.get(i));
+ sum += label;
+ sumSquared += label * label;
+ }
+
+ // computes the variance
+ double var = sumSquared - (sum * sum) / data.size();
+
+ // computes the minimum variance
+ if (Double.compare(minVariance, Double.NaN) == 0) {
+ minVariance = var / data.size() * minVarianceProportion;
+ log.debug("minVariance:{}", minVariance);
+ }
+
+ // variance is compared with minimum variance
+ if ((var / data.size()) < minVariance) {
+ log.debug("variance({}) < minVariance({}) Leaf({})", var / data.size(), minVariance, sum / data.size());
+ return new Leaf(sum / data.size());
+ }
+ } else {
+ // classification
+ if (isIdentical(data)) {
+ return new Leaf(data.majorityLabel(rng));
+ }
+ if (data.identicalLabel()) {
+ return new Leaf(data.getDataset().getLabel(data.get(0)));
+ }
+ }
+
+ // store full set data
+ if (fullSet == null) {
+ fullSet = data;
+ }
+
+ int[] attributes = randomAttributes(rng, selected, m);
+ if (attributes == null || attributes.length == 0) {
+ // we tried all the attributes and could not split the data anymore
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ label = sum / data.size();
+ } else {
+ // classification
+ label = data.majorityLabel(rng);
+ }
+ log.warn("attribute which can be selected is not found Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ if (igSplit == null) {
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ // regression
+ igSplit = new RegressionSplit();
+ } else {
+ // classification
+ igSplit = new OptIgSplit();
+ }
+ }
+
+ // find the best split
+ Split best = null;
+ for (int attr : attributes) {
+ Split split = igSplit.computeSplit(data, attr);
+ if (best == null || best.getIg() < split.getIg()) {
+ best = split;
+ }
+ }
+
+ // information gain is near to zero.
+ if (best.getIg() < EPSILON) {
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("ig is near to zero Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ log.debug("best split attr:{}, split:{}, ig:{}", best.getAttr(), best.getSplit(), best.getIg());
+
+ boolean alreadySelected = selected[best.getAttr()];
+ if (alreadySelected) {
+ // attribute already selected
+ log.warn("attribute {} already selected in a parent node", best.getAttr());
+ }
+
+ Node childNode;
+ if (data.getDataset().isNumerical(best.getAttr())) {
+ boolean[] temp = null;
+
+ Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
+ Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
+
+ if (loSubset.isEmpty() || hiSubset.isEmpty()) {
+ // the selected attribute did not change the data, avoid using it in the child notes
+ selected[best.getAttr()] = true;
+ } else {
+ // the data changed, so we can unselect all previousely selected NUMERICAL attributes
+ temp = selected;
+ selected = cloneCategoricalAttributes(data.getDataset(), selected);
+ }
+
+ // size of the subset is less than the minSpitNum
+ if (loSubset.size() < minSplitNum || hiSubset.size() < minSplitNum) {
+ // branch is not split
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("branch is not split Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ Node loChild = build(rng, loSubset);
+ Node hiChild = build(rng, hiSubset);
+
+ // restore the selection state of the attributes
+ if (temp != null) {
+ selected = temp;
+ } else {
+ selected[best.getAttr()] = alreadySelected;
+ }
+
+ childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
+ } else { // CATEGORICAL attribute
+ double[] values = data.values(best.getAttr());
+
+ // tree is complemented
+ Collection<Double> subsetValues = null;
+ if (complemented) {
+ subsetValues = Sets.newHashSet();
+ for (double value : values) {
+ subsetValues.add(value);
+ }
+ values = fullSet.values(best.getAttr());
+ }
+
+ int cnt = 0;
+ Data[] subsets = new Data[values.length];
+ for (int index = 0; index < values.length; index++) {
+ if (complemented && !subsetValues.contains(values[index])) {
+ continue;
+ }
+ subsets[index] = data.subset(Condition.equals(best.getAttr(), values[index]));
+ if (subsets[index].size() >= minSplitNum) {
+ cnt++;
+ }
+ }
+
+ // size of the subset is less than the minSpitNum
+ if (cnt < 2) {
+ // branch is not split
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("branch is not split Leaf({})", label);
+ return new Leaf(label);
+ }
+
+ selected[best.getAttr()] = true;
+
+ Node[] children = new Node[values.length];
+ for (int index = 0; index < values.length; index++) {
+ if (complemented && (subsetValues == null || !subsetValues.contains(values[index]))) {
+ // tree is complemented
+ double label;
+ if (data.getDataset().isNumerical(data.getDataset().getLabelId())) {
+ label = sum / data.size();
+ } else {
+ label = data.majorityLabel(rng);
+ }
+ log.debug("complemented Leaf({})", label);
+ children[index] = new Leaf(label);
+ continue;
+ }
+ children[index] = build(rng, subsets[index]);
+ }
+
+ selected[best.getAttr()] = alreadySelected;
+
+ childNode = new CategoricalNode(best.getAttr(), values, children);
+ }
+
+ return childNode;
+ }
+
+ /**
+ * checks if all the vectors have identical attribute values. Ignore selected attributes.
+ *
+ * @return true is all the vectors are identical or the data is empty<br>
+ * false otherwise
+ */
+ private boolean isIdentical(Data data) {
+ if (data.isEmpty()) {
+ return true;
+ }
+
+ Instance instance = data.get(0);
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (selected[attr]) {
+ continue;
+ }
+
+ for (int index = 1; index < data.size(); index++) {
+ if (data.get(index).get(attr) != instance.get(attr)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+ /**
+ * Make a copy of the selection state of the attributes, unselect all numerical attributes
+ *
+ * @param selected selection state to clone
+ * @return cloned selection state
+ */
+ private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+ boolean[] cloned = new boolean[selected.length];
+
+ for (int i = 0; i < selected.length; i++) {
+ cloned[i] = !dataset.isNumerical(i) && selected[i];
+ }
+ cloned[dataset.getLabelId()] = true;
+
+ return cloned;
+ }
+
+ /**
+ * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
+ *
+ * @param rng random-numbers generator
+ * @param selected attributes' state (selected or not)
+ * @param m number of attributes to choose
+ * @return list of selected attributes' indices, or null if all attributes have already been selected
+ */
+ private static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+ int nbNonSelected = 0; // number of non selected attributes
+ for (boolean sel : selected) {
+ if (!sel) {
+ nbNonSelected++;
+ }
+ }
+
+ if (nbNonSelected == 0) {
+ log.warn("All attributes are selected !");
+ return NO_ATTRIBUTES;
+ }
+
+ int[] result;
+ if (nbNonSelected <= m) {
+ // return all non selected attributes
+ result = new int[nbNonSelected];
+ int index = 0;
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (!selected[attr]) {
+ result[index++] = attr;
+ }
+ }
+ } else {
+ result = new int[m];
+ for (int index = 0; index < m; index++) {
+ // randomly choose a "non selected" attribute
+ int rind;
+ do {
+ rind = rng.nextInt(selected.length);
+ } while (selected[rind]);
+
+ result[index] = rind;
+ selected[rind] = true; // temporarily set the chosen attribute to be selected
+ }
+
+ // the chosen attributes are not yet selected
+ for (int attr : result) {
+ selected[attr] = false;
+ }
+ }
+
+ return result;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
new file mode 100644
index 0000000..f03698d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/df/builder/DefaultTreeBuilder.java
@@ -0,0 +1,252 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.builder;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+import org.apache.mahout.classifier.df.split.IgSplit;
+import org.apache.mahout.classifier.df.split.OptIgSplit;
+import org.apache.mahout.classifier.df.split.Split;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Random;
+
+/**
+ * Builds a Decision Tree <br>
+ * Based on the algorithm described in the "Decision Trees" tutorials by Andrew W. Moore, available at:<br>
+ * <br>
+ * http://www.cs.cmu.edu/~awm/tutorials
+ * <br><br>
+ * This class can be used when the criterion variable is the categorical attribute.
+ */
+public class DefaultTreeBuilder implements TreeBuilder {
+
+ private static final Logger log = LoggerFactory.getLogger(DefaultTreeBuilder.class);
+
+ private static final int[] NO_ATTRIBUTES = new int[0];
+
+ /**
+ * indicates which CATEGORICAL attributes have already been selected in the parent nodes
+ */
+ private boolean[] selected;
+ /**
+ * number of attributes to select randomly at each node
+ */
+ private int m = 1;
+ /**
+ * IgSplit implementation
+ */
+ private final IgSplit igSplit;
+
+ public DefaultTreeBuilder() {
+ igSplit = new OptIgSplit();
+ }
+
+ public void setM(int m) {
+ this.m = m;
+ }
+
+ @Override
+ public Node build(Random rng, Data data) {
+
+ if (selected == null) {
+ selected = new boolean[data.getDataset().nbAttributes()];
+ selected[data.getDataset().getLabelId()] = true; // never select the label
+ }
+
+ if (data.isEmpty()) {
+ return new Leaf(-1);
+ }
+ if (isIdentical(data)) {
+ return new Leaf(data.majorityLabel(rng));
+ }
+ if (data.identicalLabel()) {
+ return new Leaf(data.getDataset().getLabel(data.get(0)));
+ }
+
+ int[] attributes = randomAttributes(rng, selected, m);
+ if (attributes == null || attributes.length == 0) {
+ // we tried all the attributes and could not split the data anymore
+ return new Leaf(data.majorityLabel(rng));
+ }
+
+ // find the best split
+ Split best = null;
+ for (int attr : attributes) {
+ Split split = igSplit.computeSplit(data, attr);
+ if (best == null || best.getIg() < split.getIg()) {
+ best = split;
+ }
+ }
+
+ boolean alreadySelected = selected[best.getAttr()];
+ if (alreadySelected) {
+ // attribute already selected
+ log.warn("attribute {} already selected in a parent node", best.getAttr());
+ }
+
+ Node childNode;
+ if (data.getDataset().isNumerical(best.getAttr())) {
+ boolean[] temp = null;
+
+ Data loSubset = data.subset(Condition.lesser(best.getAttr(), best.getSplit()));
+ Data hiSubset = data.subset(Condition.greaterOrEquals(best.getAttr(), best.getSplit()));
+
+ if (loSubset.isEmpty() || hiSubset.isEmpty()) {
+ // the selected attribute did not change the data, avoid using it in the child notes
+ selected[best.getAttr()] = true;
+ } else {
+ // the data changed, so we can unselect all previousely selected NUMERICAL attributes
+ temp = selected;
+ selected = cloneCategoricalAttributes(data.getDataset(), selected);
+ }
+
+ Node loChild = build(rng, loSubset);
+ Node hiChild = build(rng, hiSubset);
+
+ // restore the selection state of the attributes
+ if (temp != null) {
+ selected = temp;
+ } else {
+ selected[best.getAttr()] = alreadySelected;
+ }
+
+ childNode = new NumericalNode(best.getAttr(), best.getSplit(), loChild, hiChild);
+ } else { // CATEGORICAL attribute
+ selected[best.getAttr()] = true;
+
+ double[] values = data.values(best.getAttr());
+ Node[] children = new Node[values.length];
+
+ for (int index = 0; index < values.length; index++) {
+ Data subset = data.subset(Condition.equals(best.getAttr(), values[index]));
+ children[index] = build(rng, subset);
+ }
+
+ selected[best.getAttr()] = alreadySelected;
+
+ childNode = new CategoricalNode(best.getAttr(), values, children);
+ }
+
+ return childNode;
+ }
+
+ /**
+ * checks if all the vectors have identical attribute values. Ignore selected attributes.
+ *
+ * @return true is all the vectors are identical or the data is empty<br>
+ * false otherwise
+ */
+ private boolean isIdentical(Data data) {
+ if (data.isEmpty()) {
+ return true;
+ }
+
+ Instance instance = data.get(0);
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (selected[attr]) {
+ continue;
+ }
+
+ for (int index = 1; index < data.size(); index++) {
+ if (data.get(index).get(attr) != instance.get(attr)) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+ }
+
+
+ /**
+ * Make a copy of the selection state of the attributes, unselect all numerical attributes
+ *
+ * @param selected selection state to clone
+ * @return cloned selection state
+ */
+ private static boolean[] cloneCategoricalAttributes(Dataset dataset, boolean[] selected) {
+ boolean[] cloned = new boolean[selected.length];
+
+ for (int i = 0; i < selected.length; i++) {
+ cloned[i] = !dataset.isNumerical(i) && selected[i];
+ }
+
+ return cloned;
+ }
+
+ /**
+ * Randomly selects m attributes to consider for split, excludes IGNORED and LABEL attributes
+ *
+ * @param rng random-numbers generator
+ * @param selected attributes' state (selected or not)
+ * @param m number of attributes to choose
+ * @return list of selected attributes' indices, or null if all attributes have already been selected
+ */
+ protected static int[] randomAttributes(Random rng, boolean[] selected, int m) {
+ int nbNonSelected = 0; // number of non selected attributes
+ for (boolean sel : selected) {
+ if (!sel) {
+ nbNonSelected++;
+ }
+ }
+
+ if (nbNonSelected == 0) {
+ log.warn("All attributes are selected !");
+ return NO_ATTRIBUTES;
+ }
+
+ int[] result;
+ if (nbNonSelected <= m) {
+ // return all non selected attributes
+ result = new int[nbNonSelected];
+ int index = 0;
+ for (int attr = 0; attr < selected.length; attr++) {
+ if (!selected[attr]) {
+ result[index++] = attr;
+ }
+ }
+ } else {
+ result = new int[m];
+ for (int index = 0; index < m; index++) {
+ // randomly choose a "non selected" attribute
+ int rind;
+ do {
+ rind = rng.nextInt(selected.length);
+ } while (selected[rind]);
+
+ result[index] = rind;
+ selected[rind] = true; // temporarily set the chosen attribute to be selected
+ }
+
+ // the chosen attributes are not yet selected
+ for (int attr : result) {
+ selected[attr] = false;
+ }
+ }
+
+ return result;
+ }
+}
[23/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java
new file mode 100644
index 0000000..a8fa091
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/MahalanobisDistanceMeasure.java
@@ -0,0 +1,204 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.parameters.ClassParameter;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.common.parameters.PathParameter;
+import org.apache.mahout.math.Algebra;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.SingularValueDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+//See http://en.wikipedia.org/wiki/Mahalanobis_distance for details
+public class MahalanobisDistanceMeasure implements DistanceMeasure {
+
+ private Matrix inverseCovarianceMatrix;
+ private Vector meanVector;
+
+ private ClassParameter vectorClass;
+ private ClassParameter matrixClass;
+ private List<Parameter<?>> parameters;
+ private Parameter<Path> inverseCovarianceFile;
+ private Parameter<Path> meanVectorFile;
+
+ /*public MahalanobisDistanceMeasure(Vector meanVector,Matrix inputMatrix, boolean inversionNeeded)
+ {
+ this.meanVector=meanVector;
+ if (inversionNeeded)
+ setCovarianceMatrix(inputMatrix);
+ else
+ setInverseCovarianceMatrix(inputMatrix);
+ }*/
+
+ @Override
+ public void configure(Configuration jobConf) {
+ if (parameters == null) {
+ ParameteredGeneralizations.configureParameters(this, jobConf);
+ }
+ try {
+ if (inverseCovarianceFile.get() != null) {
+ FileSystem fs = FileSystem.get(inverseCovarianceFile.get().toUri(), jobConf);
+ MatrixWritable inverseCovarianceMatrix =
+ ClassUtils.instantiateAs((Class<? extends MatrixWritable>) matrixClass.get(), MatrixWritable.class);
+ if (!fs.exists(inverseCovarianceFile.get())) {
+ throw new FileNotFoundException(inverseCovarianceFile.get().toString());
+ }
+ DataInputStream in = fs.open(inverseCovarianceFile.get());
+ try {
+ inverseCovarianceMatrix.readFields(in);
+ } finally {
+ Closeables.close(in, true);
+ }
+ this.inverseCovarianceMatrix = inverseCovarianceMatrix.get();
+ Preconditions.checkArgument(this.inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized");
+ }
+
+ if (meanVectorFile.get() != null) {
+ FileSystem fs = FileSystem.get(meanVectorFile.get().toUri(), jobConf);
+ VectorWritable meanVector =
+ ClassUtils.instantiateAs((Class<? extends VectorWritable>) vectorClass.get(), VectorWritable.class);
+ if (!fs.exists(meanVectorFile.get())) {
+ throw new FileNotFoundException(meanVectorFile.get().toString());
+ }
+ DataInputStream in = fs.open(meanVectorFile.get());
+ try {
+ meanVector.readFields(in);
+ } finally {
+ Closeables.close(in, true);
+ }
+ this.meanVector = meanVector.get();
+ Preconditions.checkArgument(this.meanVector != null, "meanVector not initialized");
+ }
+
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return parameters;
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ parameters = Lists.newArrayList();
+ inverseCovarianceFile = new PathParameter(prefix, "inverseCovarianceFile", jobConf, null,
+ "Path on DFS to a file containing the inverse covariance matrix.");
+ parameters.add(inverseCovarianceFile);
+
+ matrixClass = new ClassParameter(prefix, "maxtrixClass", jobConf, DenseMatrix.class,
+ "Class<Matix> file specified in parameter inverseCovarianceFile has been serialized with.");
+ parameters.add(matrixClass);
+
+ meanVectorFile = new PathParameter(prefix, "meanVectorFile", jobConf, null,
+ "Path on DFS to a file containing the mean Vector.");
+ parameters.add(meanVectorFile);
+
+ vectorClass = new ClassParameter(prefix, "vectorClass", jobConf, DenseVector.class,
+ "Class file specified in parameter meanVectorFile has been serialized with.");
+ parameters.add(vectorClass);
+ }
+
+ /**
+ * @param v The vector to compute the distance to
+ * @return Mahalanobis distance of a multivariate vector
+ */
+ public double distance(Vector v) {
+ return Math.sqrt(v.minus(meanVector).dot(Algebra.mult(inverseCovarianceMatrix, v.minus(meanVector))));
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.size() != v2.size()) {
+ throw new CardinalityException(v1.size(), v2.size());
+ }
+ return Math.sqrt(v1.minus(v2).dot(Algebra.mult(inverseCovarianceMatrix, v1.minus(v2))));
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+ public void setInverseCovarianceMatrix(Matrix inverseCovarianceMatrix) {
+ Preconditions.checkArgument(inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized");
+ this.inverseCovarianceMatrix = inverseCovarianceMatrix;
+ }
+
+
+ /**
+ * Computes the inverse covariance from the input covariance matrix given in input.
+ *
+ * @param m A covariance matrix.
+ * @throws IllegalArgumentException if <tt>eigen values equal to 0 found</tt>.
+ */
+ public void setCovarianceMatrix(Matrix m) {
+ if (m.numRows() != m.numCols()) {
+ throw new CardinalityException(m.numRows(), m.numCols());
+ }
+ // See http://www.mlahanas.de/Math/svd.htm for details,
+ // which specifically details the case of covariance matrix inversion
+ // Complexity: O(min(nm2,mn2))
+ SingularValueDecomposition svd = new SingularValueDecomposition(m);
+ Matrix sInv = svd.getS();
+ // Inverse Diagonal Elems
+ for (int i = 0; i < sInv.numRows(); i++) {
+ double diagElem = sInv.get(i, i);
+ if (diagElem > 0.0) {
+ sInv.set(i, i, 1 / diagElem);
+ } else {
+ throw new IllegalStateException("Eigen Value equals to 0 found.");
+ }
+ }
+ inverseCovarianceMatrix = svd.getU().times(sInv.times(svd.getU().transpose()));
+ Preconditions.checkArgument(inverseCovarianceMatrix != null, "inverseCovarianceMatrix not initialized");
+ }
+
+ public Matrix getInverseCovarianceMatrix() {
+ return inverseCovarianceMatrix;
+ }
+
+ public void setMeanVector(Vector meanVector) {
+ Preconditions.checkArgument(meanVector != null, "meanVector not initialized");
+ this.meanVector = meanVector;
+ }
+
+ public Vector getMeanVector() {
+ return meanVector;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java
new file mode 100644
index 0000000..5c32fcf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/ManhattanDistanceMeasure.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.CardinalityException;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * This class implements a "manhattan distance" metric by summing the absolute values of the difference
+ * between each coordinate
+ */
+public class ManhattanDistanceMeasure implements DistanceMeasure {
+
+ public static double distance(double[] p1, double[] p2) {
+ double result = 0.0;
+ for (int i = 0; i < p1.length; i++) {
+ result += Math.abs(p2[i] - p1[i]);
+ }
+ return result;
+ }
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ if (v1.size() != v2.size()) {
+ throw new CardinalityException(v1.size(), v2.size());
+ }
+ return v1.aggregate(v2, Functions.PLUS, Functions.MINUS_ABS);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java
new file mode 100644
index 0000000..3a57f2f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/MinkowskiDistanceMeasure.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.DoubleParameter;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Implement Minkowski distance, a real-valued generalization of the
+ * integral L(n) distances: Manhattan = L1, Euclidean = L2.
+ * For high numbers of dimensions, very high exponents give more useful distances.
+ *
+ * Note: Math.pow is clever about integer-valued doubles.
+ **/
+public class MinkowskiDistanceMeasure implements DistanceMeasure {
+
+ private static final double EXPONENT = 3.0;
+
+ private List<Parameter<?>> parameters;
+ private double exponent = EXPONENT;
+
+ public MinkowskiDistanceMeasure() {
+ }
+
+ public MinkowskiDistanceMeasure(double exponent) {
+ this.exponent = exponent;
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration conf) {
+ parameters = Lists.newArrayList();
+ Parameter<?> param =
+ new DoubleParameter(prefix, "exponent", conf, EXPONENT, "Exponent for Fractional Lagrange distance");
+ parameters.add(param);
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return parameters;
+ }
+
+ @Override
+ public void configure(Configuration jobConf) {
+ if (parameters == null) {
+ ParameteredGeneralizations.configureParameters(this, jobConf);
+ }
+ }
+
+ public double getExponent() {
+ return exponent;
+ }
+
+ public void setExponent(double exponent) {
+ this.exponent = exponent;
+ }
+
+ /**
+ * Math.pow is clever about integer-valued doubles
+ */
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ return Math.pow(v1.aggregate(v2, Functions.PLUS, Functions.minusAbsPow(exponent)), 1.0 / exponent);
+ }
+
+ // TODO: how?
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO - can this use centroidLengthSquare somehow?
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java
new file mode 100644
index 0000000..66da121
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/SquaredEuclideanDistanceMeasure.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Like {@link EuclideanDistanceMeasure} but it does not take the square root.
+ * <p/>
+ * Thus, it is not actually the Euclidean Distance, but it is saves on computation when you only need the
+ * distance for comparison and don't care about the actual value as a distance.
+ */
+public class SquaredEuclideanDistanceMeasure implements DistanceMeasure {
+
+ @Override
+ public void configure(Configuration job) {
+ // nothing to do
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public double distance(Vector v1, Vector v2) {
+ return v2.getDistanceSquared(v1);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return centroidLengthSquare - 2 * v.dot(centroid) + v.getLengthSquared();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java
new file mode 100644
index 0000000..cfeb119
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/TanimotoDistanceMeasure.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Tanimoto coefficient implementation.
+ *
+ * http://en.wikipedia.org/wiki/Jaccard_index
+ */
+public class TanimotoDistanceMeasure extends WeightedDistanceMeasure {
+
+ /**
+ * Calculates the distance between two vectors.
+ *
+ * The coefficient (a measure of similarity) is: T(a, b) = a.b / (|a|^2 + |b|^2 - a.b)
+ *
+ * The distance d(a,b) = 1 - T(a,b)
+ *
+ * @return 0 for perfect match, > 0 for greater distance
+ */
+ @Override
+ public double distance(Vector a, Vector b) {
+ double ab;
+ double denominator;
+ if (getWeights() != null) {
+ ab = a.times(b).aggregate(getWeights(), Functions.PLUS, Functions.MULT);
+ denominator = a.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT)
+ + b.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT)
+ - ab;
+ } else {
+ ab = b.dot(a); // b is SequentialAccess
+ denominator = a.getLengthSquared() + b.getLengthSquared() - ab;
+ }
+
+ if (denominator < ab) { // correct for fp round-off: distance >= 0
+ denominator = ab;
+ }
+ if (denominator > 0) {
+ // denominator == 0 only when dot(a,a) == dot(b,b) == dot(a,b) == 0
+ return 1.0 - ab / denominator;
+ } else {
+ return 0.0;
+ }
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java
new file mode 100644
index 0000000..0c1d2cd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/WeightedDistanceMeasure.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import java.io.DataInputStream;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.parameters.ClassParameter;
+import org.apache.mahout.common.parameters.Parameter;
+import org.apache.mahout.common.parameters.PathParameter;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/** Abstract implementation of DistanceMeasure with support for weights. */
+public abstract class WeightedDistanceMeasure implements DistanceMeasure {
+
+ private List<Parameter<?>> parameters;
+ private Parameter<Path> weightsFile;
+ private ClassParameter vectorClass;
+ private Vector weights;
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) {
+ parameters = Lists.newArrayList();
+ weightsFile = new PathParameter(prefix, "weightsFile", jobConf, null,
+ "Path on DFS to a file containing the weights.");
+ parameters.add(weightsFile);
+ vectorClass = new ClassParameter(prefix, "vectorClass", jobConf, DenseVector.class,
+ "Class<Vector> file specified in parameter weightsFile has been serialized with.");
+ parameters.add(vectorClass);
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return parameters;
+ }
+
+ @Override
+ public void configure(Configuration jobConf) {
+ if (parameters == null) {
+ ParameteredGeneralizations.configureParameters(this, jobConf);
+ }
+ try {
+ if (weightsFile.get() != null) {
+ FileSystem fs = FileSystem.get(weightsFile.get().toUri(), jobConf);
+ VectorWritable weights =
+ ClassUtils.instantiateAs((Class<? extends VectorWritable>) vectorClass.get(), VectorWritable.class);
+ if (!fs.exists(weightsFile.get())) {
+ throw new FileNotFoundException(weightsFile.get().toString());
+ }
+ DataInputStream in = fs.open(weightsFile.get());
+ try {
+ weights.readFields(in);
+ } finally {
+ Closeables.close(in, true);
+ }
+ this.weights = weights.get();
+ }
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ public Vector getWeights() {
+ return weights;
+ }
+
+ public void setWeights(Vector weights) {
+ this.weights = weights;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java
new file mode 100644
index 0000000..c6889e2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/WeightedEuclideanDistanceMeasure.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * This class implements a Euclidean distance metric by summing the square root of the squared differences
+ * between each coordinate, optionally adding weights.
+ */
+public class WeightedEuclideanDistanceMeasure extends WeightedDistanceMeasure {
+
+ @Override
+ public double distance(Vector p1, Vector p2) {
+ double result = 0;
+ Vector res = p2.minus(p1);
+ Vector theWeights = getWeights();
+ if (theWeights == null) {
+ for (Element elt : res.nonZeroes()) {
+ result += elt.get() * elt.get();
+ }
+ } else {
+ for (Element elt : res.nonZeroes()) {
+ result += elt.get() * elt.get() * theWeights.get(elt.index());
+ }
+ }
+ return Math.sqrt(result);
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java b/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java
new file mode 100644
index 0000000..2c280e2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/distance/WeightedManhattanDistanceMeasure.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.distance;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * This class implements a "Manhattan distance" metric by summing the absolute values of the difference
+ * between each coordinate, optionally with weights.
+ */
+public class WeightedManhattanDistanceMeasure extends WeightedDistanceMeasure {
+
+ @Override
+ public double distance(Vector p1, Vector p2) {
+ double result = 0;
+
+ Vector res = p2.minus(p1);
+ if (getWeights() == null) {
+ for (Element elt : res.nonZeroes()) {
+ result += Math.abs(elt.get());
+ }
+
+ } else {
+ for (Element elt : res.nonZeroes()) {
+ result += Math.abs(elt.get() * getWeights().get(elt.index()));
+ }
+ }
+
+ return result;
+ }
+
+ @Override
+ public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
+ return distance(centroid, v); // TODO
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java
new file mode 100644
index 0000000..73cc821
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/CopyConstructorIterator.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Iterator;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+
+/**
+ * An iterator that copies the values in an underlying iterator by finding an appropriate copy constructor.
+ */
+public final class CopyConstructorIterator<T> extends ForwardingIterator<T> {
+
+ private final Iterator<T> delegate;
+ private Constructor<T> constructor;
+
+ public CopyConstructorIterator(Iterator<? extends T> copyFrom) {
+ this.delegate = Iterators.transform(
+ copyFrom,
+ new Function<T,T>() {
+ @Override
+ public T apply(T from) {
+ if (constructor == null) {
+ Class<T> elementClass = (Class<T>) from.getClass();
+ try {
+ constructor = elementClass.getConstructor(elementClass);
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ try {
+ return constructor.newInstance(from);
+ } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<T> delegate() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java
new file mode 100644
index 0000000..658c1f1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/CountingIterator.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import com.google.common.collect.AbstractIterator;
+
+/**
+ * Iterates over the integers from 0 through {@code to-1}.
+ */
+public final class CountingIterator extends AbstractIterator<Integer> {
+
+ private int count;
+ private final int to;
+
+ public CountingIterator(int to) {
+ this.to = to;
+ }
+
+ @Override
+ protected Integer computeNext() {
+ if (count < to) {
+ return count++;
+ } else {
+ return endOfData();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java b/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java
new file mode 100644
index 0000000..cfc18d6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterable.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.Charset;
+import java.util.Iterator;
+
+import com.google.common.base.Charsets;
+
+/**
+ * Iterable representing the lines of a text file. It can produce an {@link Iterator} over those lines. This
+ * assumes the text file's lines are delimited in a manner consistent with how {@link java.io.BufferedReader}
+ * defines lines.
+ *
+ */
+public final class FileLineIterable implements Iterable<String> {
+
+ private final InputStream is;
+ private final Charset encoding;
+ private final boolean skipFirstLine;
+ private final String origFilename;
+
+ /** Creates a {@link FileLineIterable} over a given file, assuming a UTF-8 encoding. */
+ public FileLineIterable(File file) throws IOException {
+ this(file, Charsets.UTF_8, false);
+ }
+
+ /** Creates a {@link FileLineIterable} over a given file, assuming a UTF-8 encoding. */
+ public FileLineIterable(File file, boolean skipFirstLine) throws IOException {
+ this(file, Charsets.UTF_8, skipFirstLine);
+ }
+
+ /** Creates a {@link FileLineIterable} over a given file, using the given encoding. */
+ public FileLineIterable(File file, Charset encoding, boolean skipFirstLine) throws IOException {
+ this(FileLineIterator.getFileInputStream(file), encoding, skipFirstLine);
+ }
+
+ public FileLineIterable(InputStream is) {
+ this(is, Charsets.UTF_8, false);
+ }
+
+ public FileLineIterable(InputStream is, boolean skipFirstLine) {
+ this(is, Charsets.UTF_8, skipFirstLine);
+ }
+
+ public FileLineIterable(InputStream is, Charset encoding, boolean skipFirstLine) {
+ this.is = is;
+ this.encoding = encoding;
+ this.skipFirstLine = skipFirstLine;
+ this.origFilename = "";
+ }
+
+ public FileLineIterable(InputStream is, Charset encoding, boolean skipFirstLine, String filename) {
+ this.is = is;
+ this.encoding = encoding;
+ this.skipFirstLine = skipFirstLine;
+ this.origFilename = filename;
+ }
+
+
+ @Override
+ public Iterator<String> iterator() {
+ try {
+ return new FileLineIterator(is, encoding, skipFirstLine, this.origFilename);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java
new file mode 100644
index 0000000..b7cc51e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/FileLineIterator.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.io.BufferedReader;
+import java.io.Closeable;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.charset.Charset;
+import java.util.zip.GZIPInputStream;
+import java.util.zip.ZipInputStream;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.AbstractIterator;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Iterates over the lines of a text file. This assumes the text file's lines are delimited in a manner
+ * consistent with how {@link BufferedReader} defines lines.
+ * <p/>
+ * This class will uncompress files that end in .zip or .gz accordingly, too.
+ */
+public final class FileLineIterator extends AbstractIterator<String> implements SkippingIterator<String>, Closeable {
+
+ private final BufferedReader reader;
+
+ private static final Logger log = LoggerFactory.getLogger(FileLineIterator.class);
+
+ /**
+ * Creates a {@link FileLineIterator} over a given file, assuming a UTF-8 encoding.
+ *
+ * @throws java.io.FileNotFoundException if the file does not exist
+ * @throws IOException
+ * if the file cannot be read
+ */
+ public FileLineIterator(File file) throws IOException {
+ this(file, Charsets.UTF_8, false);
+ }
+
+ /**
+ * Creates a {@link FileLineIterator} over a given file, assuming a UTF-8 encoding.
+ *
+ * @throws java.io.FileNotFoundException if the file does not exist
+ * @throws IOException if the file cannot be read
+ */
+ public FileLineIterator(File file, boolean skipFirstLine) throws IOException {
+ this(file, Charsets.UTF_8, skipFirstLine);
+ }
+
+ /**
+ * Creates a {@link FileLineIterator} over a given file, using the given encoding.
+ *
+ * @throws java.io.FileNotFoundException if the file does not exist
+ * @throws IOException if the file cannot be read
+ */
+ public FileLineIterator(File file, Charset encoding, boolean skipFirstLine) throws IOException {
+ this(getFileInputStream(file), encoding, skipFirstLine);
+ }
+
+ public FileLineIterator(InputStream is) throws IOException {
+ this(is, Charsets.UTF_8, false);
+ }
+
+ public FileLineIterator(InputStream is, boolean skipFirstLine) throws IOException {
+ this(is, Charsets.UTF_8, skipFirstLine);
+ }
+
+ public FileLineIterator(InputStream is, Charset encoding, boolean skipFirstLine) throws IOException {
+ reader = new BufferedReader(new InputStreamReader(is, encoding));
+ if (skipFirstLine) {
+ reader.readLine();
+ }
+ }
+
+ public FileLineIterator(InputStream is, Charset encoding, boolean skipFirstLine, String filename)
+ throws IOException {
+ InputStream compressedInputStream;
+
+ if ("gz".equalsIgnoreCase(Files.getFileExtension(filename.toLowerCase()))) {
+ compressedInputStream = new GZIPInputStream(is);
+ } else if ("zip".equalsIgnoreCase(Files.getFileExtension(filename.toLowerCase()))) {
+ compressedInputStream = new ZipInputStream(is);
+ } else {
+ compressedInputStream = is;
+ }
+
+ reader = new BufferedReader(new InputStreamReader(compressedInputStream, encoding));
+ if (skipFirstLine) {
+ reader.readLine();
+ }
+ }
+
+ static InputStream getFileInputStream(File file) throws IOException {
+ InputStream is = new FileInputStream(file);
+ String name = file.getName();
+ if ("gz".equalsIgnoreCase(Files.getFileExtension(name.toLowerCase()))) {
+ return new GZIPInputStream(is);
+ } else if ("zip".equalsIgnoreCase(Files.getFileExtension(name.toLowerCase()))) {
+ return new ZipInputStream(is);
+ } else {
+ return is;
+ }
+ }
+
+ @Override
+ protected String computeNext() {
+ String line;
+ try {
+ line = reader.readLine();
+ } catch (IOException ioe) {
+ try {
+ close();
+ } catch (IOException e) {
+ log.error(e.getMessage(), e);
+ }
+ throw new IllegalStateException(ioe);
+ }
+ return line == null ? endOfData() : line;
+ }
+
+
+ @Override
+ public void skip(int n) {
+ try {
+ for (int i = 0; i < n; i++) {
+ if (reader.readLine() == null) {
+ break;
+ }
+ }
+ } catch (IOException ioe) {
+ try {
+ close();
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ endOfData();
+ Closeables.close(reader, true);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java
new file mode 100644
index 0000000..1905654
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/FixedSizeSamplingIterator.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Sample a fixed number of elements from an Iterator. The results can appear in any order.
+ */
+public final class FixedSizeSamplingIterator<T> extends ForwardingIterator<T> {
+
+ private final Iterator<T> delegate;
+
+ public FixedSizeSamplingIterator(int size, Iterator<T> source) {
+ List<T> buf = Lists.newArrayListWithCapacity(size);
+ int sofar = 0;
+ Random random = RandomUtils.getRandom();
+ while (source.hasNext()) {
+ T v = source.next();
+ sofar++;
+ if (buf.size() < size) {
+ buf.add(v);
+ } else {
+ int position = random.nextInt(sofar);
+ if (position < buf.size()) {
+ buf.set(position, v);
+ }
+ }
+ }
+ delegate = buf.iterator();
+ }
+
+ @Override
+ protected Iterator<T> delegate() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java b/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java
new file mode 100644
index 0000000..46ef411
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterable.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+/**
+ * Wraps an {@link Iterable} whose {@link Iterable#iterator()} returns only some subset of the elements that
+ * it would, as determined by a iterator rate parameter.
+ */
+public final class SamplingIterable<T> implements Iterable<T> {
+
+ private final Iterable<? extends T> delegate;
+ private final double samplingRate;
+
+ public SamplingIterable(Iterable<? extends T> delegate, double samplingRate) {
+ this.delegate = delegate;
+ this.samplingRate = samplingRate;
+ }
+
+ @Override
+ public Iterator<T> iterator() {
+ return new SamplingIterator<T>(delegate.iterator(), samplingRate);
+ }
+
+ public static <T> Iterable<T> maybeWrapIterable(Iterable<T> delegate, double samplingRate) {
+ return samplingRate >= 1.0 ? delegate : new SamplingIterable<T>(delegate, samplingRate);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
new file mode 100644
index 0000000..2ba46fd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/SamplingIterator.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.AbstractIterator;
+import org.apache.commons.math3.distribution.PascalDistribution;
+import org.apache.mahout.cf.taste.impl.common.SkippingIterator;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+/**
+ * Wraps an {@link Iterator} and returns only some subset of the elements that it would, as determined by a
+ * iterator rate parameter.
+ */
+public final class SamplingIterator<T> extends AbstractIterator<T> {
+
+ private final PascalDistribution geometricDistribution;
+ private final Iterator<? extends T> delegate;
+
+ public SamplingIterator(Iterator<? extends T> delegate, double samplingRate) {
+ this(RandomUtils.getRandom(), delegate, samplingRate);
+ }
+
+ public SamplingIterator(RandomWrapper random, Iterator<? extends T> delegate, double samplingRate) {
+ Preconditions.checkNotNull(delegate);
+ Preconditions.checkArgument(samplingRate > 0.0 && samplingRate <= 1.0,
+ "Must be: 0.0 < samplingRate <= 1.0. But samplingRate = " + samplingRate);
+ // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+ geometricDistribution = new PascalDistribution(random.getRandomGenerator(), 1, samplingRate);
+ this.delegate = delegate;
+ }
+
+ @Override
+ protected T computeNext() {
+ int toSkip = geometricDistribution.sample();
+ if (delegate instanceof SkippingIterator<?>) {
+ SkippingIterator<? extends T> skippingDelegate = (SkippingIterator<? extends T>) delegate;
+ skippingDelegate.skip(toSkip);
+ if (skippingDelegate.hasNext()) {
+ return skippingDelegate.next();
+ }
+ } else {
+ for (int i = 0; i < toSkip && delegate.hasNext(); i++) {
+ delegate.next();
+ }
+ if (delegate.hasNext()) {
+ return delegate.next();
+ }
+ }
+ return endOfData();
+ }
+
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java
new file mode 100644
index 0000000..c4ddf7b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/StableFixedSizeSamplingIterator.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Sample a fixed number of elements from an Iterator. The results will appear in the original order at some
+ * cost in time and memory relative to a FixedSizeSampler.
+ */
+public class StableFixedSizeSamplingIterator<T> extends ForwardingIterator<T> {
+
+ private final Iterator<T> delegate;
+
+ public StableFixedSizeSamplingIterator(int size, Iterator<T> source) {
+ List<Pair<Integer,T>> buf = Lists.newArrayListWithCapacity(size);
+ int sofar = 0;
+ Random random = RandomUtils.getRandom();
+ while (source.hasNext()) {
+ T v = source.next();
+ sofar++;
+ if (buf.size() < size) {
+ buf.add(new Pair<>(sofar, v));
+ } else {
+ int position = random.nextInt(sofar);
+ if (position < buf.size()) {
+ buf.set(position, new Pair<>(sofar, v));
+ }
+ }
+ }
+
+ Collections.sort(buf);
+ delegate = Iterators.transform(buf.iterator(),
+ new Function<Pair<Integer,T>,T>() {
+ @Override
+ public T apply(Pair<Integer,T> from) {
+ return from.getSecond();
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<T> delegate() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java
new file mode 100644
index 0000000..73b841e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/StringRecordIterator.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+import java.util.regex.Pattern;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+import org.apache.mahout.common.Pair;
+
+public class StringRecordIterator extends ForwardingIterator<Pair<List<String>,Long>> {
+
+ private static final Long ONE = 1L;
+
+ private final Pattern splitter;
+ private final Iterator<Pair<List<String>,Long>> delegate;
+
+ public StringRecordIterator(Iterable<String> stringIterator, String pattern) {
+ this.splitter = Pattern.compile(pattern);
+ delegate = Iterators.transform(
+ stringIterator.iterator(),
+ new Function<String,Pair<List<String>,Long>>() {
+ @Override
+ public Pair<List<String>,Long> apply(String from) {
+ String[] items = splitter.split(from);
+ return new Pair<>(Arrays.asList(items), ONE);
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<Pair<List<String>,Long>> delegate() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java
new file mode 100644
index 0000000..19f78b5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathFilters.java
@@ -0,0 +1,81 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+
+/**
+ * Supplies some useful and repeatedly-used instances of {@link PathFilter}.
+ */
+public final class PathFilters {
+
+ private static final PathFilter PART_FILE_INSTANCE = new PathFilter() {
+ @Override
+ public boolean accept(Path path) {
+ String name = path.getName();
+ return name.startsWith("part-") && !name.endsWith(".crc");
+ }
+ };
+
+ /**
+ * Pathfilter to read the final clustering file.
+ */
+ private static final PathFilter CLUSTER_FINAL = new PathFilter() {
+ @Override
+ public boolean accept(Path path) {
+ String name = path.getName();
+ return name.startsWith("clusters-") && name.endsWith("-final");
+ }
+ };
+
+ private static final PathFilter LOGS_CRC_INSTANCE = new PathFilter() {
+ @Override
+ public boolean accept(Path path) {
+ String name = path.getName();
+ return !(name.endsWith(".crc") || name.startsWith(".") || name.startsWith("_"));
+ }
+ };
+
+ private PathFilters() {
+ }
+
+ /**
+ * @return {@link PathFilter} that accepts paths whose file name starts with "part-". Excludes
+ * ".crc" files.
+ */
+ public static PathFilter partFilter() {
+ return PART_FILE_INSTANCE;
+ }
+
+ /**
+ * @return {@link PathFilter} that accepts paths whose file name starts with "part-" and ends with "-final".
+ */
+ public static PathFilter finalPartFilter() {
+ return CLUSTER_FINAL;
+ }
+
+ /**
+ * @return {@link PathFilter} that rejects paths whose file name starts with "_" (e.g. Cloudera
+ * _SUCCESS files or Hadoop _logs), or "." (e.g. local hidden files), or ends with ".crc"
+ */
+ public static PathFilter logsCRCFilter() {
+ return LOGS_CRC_INSTANCE;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathType.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathType.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathType.java
new file mode 100644
index 0000000..7ea713e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/PathType.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+/**
+ * Used by {@link SequenceFileDirIterable} and the like to select whether the input path specifies a
+ * directory to list, or a glob pattern.
+ */
+public enum PathType {
+ GLOB,
+ LIST,
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterable.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterable.java
new file mode 100644
index 0000000..ca4d6b8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterable.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.Pair;
+
+/**
+ * <p>{@link Iterable} counterpart to {@link SequenceFileDirIterator}.</p>
+ */
+public final class SequenceFileDirIterable<K extends Writable,V extends Writable> implements Iterable<Pair<K,V>> {
+
+ private final Path path;
+ private final PathType pathType;
+ private final PathFilter filter;
+ private final Comparator<FileStatus> ordering;
+ private final boolean reuseKeyValueInstances;
+ private final Configuration conf;
+
+ public SequenceFileDirIterable(Path path, PathType pathType, Configuration conf) {
+ this(path, pathType, null, conf);
+ }
+
+ public SequenceFileDirIterable(Path path, PathType pathType, PathFilter filter, Configuration conf) {
+ this(path, pathType, filter, null, false, conf);
+ }
+
+ /**
+ * @param path file to iterate over
+ * @param pathType whether or not to treat path as a directory ({@link PathType#LIST}) or
+ * glob pattern ({@link PathType#GLOB})
+ * @param filter if not null, specifies sequence files to be ignored by the iteration
+ * @param ordering if not null, specifies the order in which to iterate over matching sequence files
+ * @param reuseKeyValueInstances if true, reuses instances of the value object instead of creating a new
+ * one for each read from the file
+ */
+ public SequenceFileDirIterable(Path path,
+ PathType pathType,
+ PathFilter filter,
+ Comparator<FileStatus> ordering,
+ boolean reuseKeyValueInstances,
+ Configuration conf) {
+ this.path = path;
+ this.pathType = pathType;
+ this.filter = filter;
+ this.ordering = ordering;
+ this.reuseKeyValueInstances = reuseKeyValueInstances;
+ this.conf = conf;
+ }
+
+ @Override
+ public Iterator<Pair<K,V>> iterator() {
+ try {
+ return new SequenceFileDirIterator<>(path, pathType, filter, ordering, reuseKeyValueInstances, conf);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(path.toString(), ioe);
+ }
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterator.java
new file mode 100644
index 0000000..cf6a871
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirIterator.java
@@ -0,0 +1,136 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.Pair;
+
+/**
+ * Like {@link SequenceFileIterator}, but iterates not just over one sequence file, but many. The input path
+ * may be specified as a directory of files to read, or as a glob pattern. The set of files may be optionally
+ * restricted with a {@link PathFilter}.
+ */
+public final class SequenceFileDirIterator<K extends Writable,V extends Writable>
+ extends ForwardingIterator<Pair<K,V>> implements Closeable {
+
+ private static final FileStatus[] NO_STATUSES = new FileStatus[0];
+
+ private Iterator<Pair<K,V>> delegate;
+ private final List<SequenceFileIterator<K,V>> iterators;
+
+ /**
+ * Multifile sequence file iterator where files are specified explicitly by
+ * path parameters.
+ */
+ public SequenceFileDirIterator(Path[] path,
+ boolean reuseKeyValueInstances,
+ Configuration conf) throws IOException {
+
+ iterators = Lists.newArrayList();
+ // we assume all files should exist, otherwise we will bail out.
+ FileSystem fs = FileSystem.get(path[0].toUri(), conf);
+ FileStatus[] statuses = new FileStatus[path.length];
+ for (int i = 0; i < statuses.length; i++) {
+ statuses[i] = fs.getFileStatus(path[i]);
+ }
+ init(statuses, reuseKeyValueInstances, conf);
+ }
+
+ /**
+ * Constructor that uses either {@link FileSystem#listStatus(Path)} or
+ * {@link FileSystem#globStatus(Path)} to obtain list of files to iterate over
+ * (depending on pathType parameter).
+ */
+ public SequenceFileDirIterator(Path path,
+ PathType pathType,
+ PathFilter filter,
+ Comparator<FileStatus> ordering,
+ boolean reuseKeyValueInstances,
+ Configuration conf) throws IOException {
+
+ FileStatus[] statuses = HadoopUtil.getFileStatus(path, pathType, filter, ordering, conf);
+ iterators = Lists.newArrayList();
+ init(statuses, reuseKeyValueInstances, conf);
+ }
+
+ private void init(FileStatus[] statuses,
+ final boolean reuseKeyValueInstances,
+ final Configuration conf) {
+
+ /*
+ * prevent NPEs. Unfortunately, Hadoop would return null for list if nothing
+ * was qualified. In this case, which is a corner case, we should assume an
+ * empty iterator, not an NPE.
+ */
+ if (statuses == null) {
+ statuses = NO_STATUSES;
+ }
+
+ Iterator<FileStatus> fileStatusIterator = Iterators.forArray(statuses);
+
+ Iterator<Iterator<Pair<K, V>>> fsIterators =
+ Iterators.transform(fileStatusIterator,
+ new Function<FileStatus, Iterator<Pair<K, V>>>() {
+ @Override
+ public Iterator<Pair<K, V>> apply(FileStatus from) {
+ try {
+ SequenceFileIterator<K, V> iterator = new SequenceFileIterator<>(from.getPath(),
+ reuseKeyValueInstances, conf);
+ iterators.add(iterator);
+ return iterator;
+ } catch (IOException ioe) {
+ throw new IllegalStateException(from.getPath().toString(), ioe);
+ }
+ }
+ });
+
+ Collections.reverse(iterators); // close later in reverse order
+
+ delegate = Iterators.concat(fsIterators);
+ }
+
+ @Override
+ protected Iterator<Pair<K,V>> delegate() {
+ return delegate;
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(iterators);
+ iterators.clear();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterable.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterable.java
new file mode 100644
index 0000000..1cb4ebc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterable.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.IOException;
+import java.util.Comparator;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * <p>{@link Iterable} counterpart to {@link SequenceFileDirValueIterator}.</p>
+ */
+public final class SequenceFileDirValueIterable<V extends Writable> implements Iterable<V> {
+
+ private final Path path;
+ private final PathType pathType;
+ private final PathFilter filter;
+ private final Comparator<FileStatus> ordering;
+ private final boolean reuseKeyValueInstances;
+ private final Configuration conf;
+
+ public SequenceFileDirValueIterable(Path path, PathType pathType, Configuration conf) {
+ this(path, pathType, null, conf);
+ }
+
+ public SequenceFileDirValueIterable(Path path, PathType pathType, PathFilter filter, Configuration conf) {
+ this(path, pathType, filter, null, false, conf);
+ }
+
+ /**
+ * @param path file to iterate over
+ * @param pathType whether or not to treat path as a directory ({@link PathType#LIST}) or
+ * glob pattern ({@link PathType#GLOB})
+ * @param filter if not null, specifies sequence files to be ignored by the iteration
+ * @param ordering if not null, specifies the order in which to iterate over matching sequence files
+ * @param reuseKeyValueInstances if true, reuses instances of the value object instead of creating a new
+ * one for each read from the file
+ */
+ public SequenceFileDirValueIterable(Path path,
+ PathType pathType,
+ PathFilter filter,
+ Comparator<FileStatus> ordering,
+ boolean reuseKeyValueInstances,
+ Configuration conf) {
+ this.path = path;
+ this.pathType = pathType;
+ this.filter = filter;
+ this.ordering = ordering;
+ this.reuseKeyValueInstances = reuseKeyValueInstances;
+ this.conf = conf;
+ }
+
+ @Override
+ public Iterator<V> iterator() {
+ try {
+ return new SequenceFileDirValueIterator<>(path, pathType, filter, ordering, reuseKeyValueInstances, conf);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(path.toString(), ioe);
+ }
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterator.java
new file mode 100644
index 0000000..908c8bb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileDirValueIterator.java
@@ -0,0 +1,159 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.PathFilter;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.IOUtils;
+
+/**
+ * Like {@link SequenceFileValueIterator}, but iterates not just over one
+ * sequence file, but many. The input path may be specified as a directory of
+ * files to read, or as a glob pattern. The set of files may be optionally
+ * restricted with a {@link PathFilter}.
+ */
+public final class SequenceFileDirValueIterator<V extends Writable> extends
+ ForwardingIterator<V> implements Closeable {
+
+ private static final FileStatus[] NO_STATUSES = new FileStatus[0];
+
+ private Iterator<V> delegate;
+ private final List<SequenceFileValueIterator<V>> iterators;
+
+ /**
+ * Constructor that uses either {@link FileSystem#listStatus(Path)} or
+ * {@link FileSystem#globStatus(Path)} to obtain list of files to iterate over
+ * (depending on pathType parameter).
+ */
+ public SequenceFileDirValueIterator(Path path,
+ PathType pathType,
+ PathFilter filter,
+ Comparator<FileStatus> ordering,
+ boolean reuseKeyValueInstances,
+ Configuration conf) throws IOException {
+ FileStatus[] statuses;
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ if (filter == null) {
+ statuses = pathType == PathType.GLOB ? fs.globStatus(path) : fs.listStatus(path);
+ } else {
+ statuses = pathType == PathType.GLOB ? fs.globStatus(path, filter) : fs.listStatus(path, filter);
+ }
+ iterators = Lists.newArrayList();
+ init(statuses, ordering, reuseKeyValueInstances, conf);
+ }
+
+ /**
+ * Multifile sequence file iterator where files are specified explicitly by
+ * path parameters.
+ */
+ public SequenceFileDirValueIterator(Path[] path,
+ Comparator<FileStatus> ordering,
+ boolean reuseKeyValueInstances,
+ Configuration conf) throws IOException {
+
+ iterators = Lists.newArrayList();
+ /*
+ * we assume all files should exist, otherwise we will bail out.
+ */
+ FileSystem fs = FileSystem.get(path[0].toUri(), conf);
+ FileStatus[] statuses = new FileStatus[path.length];
+ for (int i = 0; i < statuses.length; i++) {
+ statuses[i] = fs.getFileStatus(path[i]);
+ }
+ init(statuses, ordering, reuseKeyValueInstances, conf);
+ }
+
+ private void init(FileStatus[] statuses,
+ Comparator<FileStatus> ordering,
+ final boolean reuseKeyValueInstances,
+ final Configuration conf) throws IOException {
+
+ /*
+ * prevent NPEs. Unfortunately, Hadoop would return null for list if nothing
+ * was qualified. In this case, which is a corner case, we should assume an
+ * empty iterator, not an NPE.
+ */
+ if (statuses == null) {
+ statuses = NO_STATUSES;
+ }
+
+ if (ordering != null) {
+ Arrays.sort(statuses, ordering);
+ }
+ Iterator<FileStatus> fileStatusIterator = Iterators.forArray(statuses);
+
+ try {
+
+ Iterator<Iterator<V>> fsIterators =
+ Iterators.transform(fileStatusIterator,
+ new Function<FileStatus, Iterator<V>>() {
+ @Override
+ public Iterator<V> apply(FileStatus from) {
+ try {
+ SequenceFileValueIterator<V> iterator = new SequenceFileValueIterator<>(from.getPath(),
+ reuseKeyValueInstances, conf);
+ iterators.add(iterator);
+ return iterator;
+ } catch (IOException ioe) {
+ throw new IllegalStateException(from.getPath().toString(), ioe);
+ }
+ }
+ });
+
+ Collections.reverse(iterators); // close later in reverse order
+
+ delegate = Iterators.concat(fsIterators);
+
+ } finally {
+ /*
+ * prevent file handle leaks in case one of handles fails to open. If some
+ * of the files fail to open, constructor will fail and close() will never
+ * be called. Thus, those handles that did open in constructor, would leak
+ * out, unless we specifically handle it here.
+ */
+ IOUtils.close(iterators);
+ }
+ }
+
+ @Override
+ protected Iterator<V> delegate() {
+ return delegate;
+ }
+
+ @Override
+ public void close() throws IOException {
+ IOUtils.close(iterators);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterable.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterable.java
new file mode 100644
index 0000000..f17c2a1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterable.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.Pair;
+
+/**
+ * <p>{@link Iterable} counterpart to {@link SequenceFileIterator}.</p>
+ */
+public final class SequenceFileIterable<K extends Writable,V extends Writable> implements Iterable<Pair<K,V>> {
+
+ private final Path path;
+ private final boolean reuseKeyValueInstances;
+ private final Configuration conf;
+
+ /**
+ * Like {@link #SequenceFileIterable(Path, boolean, Configuration)} but key and value instances are not reused
+ * by default.
+ *
+ * @param path file to iterate over
+ */
+ public SequenceFileIterable(Path path, Configuration conf) {
+ this(path, false, conf);
+ }
+
+ /**
+ * @param path file to iterate over
+ * @param reuseKeyValueInstances if true, reuses instances of the key and value object instead of creating a new
+ * one for each read from the file
+ */
+ public SequenceFileIterable(Path path, boolean reuseKeyValueInstances, Configuration conf) {
+ this.path = path;
+ this.reuseKeyValueInstances = reuseKeyValueInstances;
+ this.conf = conf;
+ }
+
+ @Override
+ public Iterator<Pair<K, V>> iterator() {
+ try {
+ return new SequenceFileIterator<>(path, reuseKeyValueInstances, conf);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(path.toString(), ioe);
+ }
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterator.java
new file mode 100644
index 0000000..bc5c549
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileIterator.java
@@ -0,0 +1,118 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.apache.mahout.common.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>{@link java.util.Iterator} over a {@link SequenceFile}'s keys and values, as a {@link Pair}
+ * containing key and value.</p>
+ */
+public final class SequenceFileIterator<K extends Writable,V extends Writable>
+ extends AbstractIterator<Pair<K,V>> implements Closeable {
+
+ private final SequenceFile.Reader reader;
+ private final Configuration conf;
+ private final Class<K> keyClass;
+ private final Class<V> valueClass;
+ private final boolean noValue;
+ private K key;
+ private V value;
+ private final boolean reuseKeyValueInstances;
+
+ private static final Logger log = LoggerFactory.getLogger(SequenceFileIterator.class);
+
+ /**
+ * @throws IOException if path can't be read, or its key or value class can't be instantiated
+ */
+
+ public SequenceFileIterator(Path path, boolean reuseKeyValueInstances, Configuration conf) throws IOException {
+ key = null;
+ value = null;
+ FileSystem fs = path.getFileSystem(conf);
+ path = path.makeQualified(fs);
+ reader = new SequenceFile.Reader(fs, path, conf);
+ this.conf = conf;
+ keyClass = (Class<K>) reader.getKeyClass();
+ valueClass = (Class<V>) reader.getValueClass();
+ noValue = NullWritable.class.equals(valueClass);
+ this.reuseKeyValueInstances = reuseKeyValueInstances;
+ }
+
+ public Class<K> getKeyClass() {
+ return keyClass;
+ }
+
+ public Class<V> getValueClass() {
+ return valueClass;
+ }
+
+ @Override
+ public void close() throws IOException {
+ key = null;
+ value = null;
+ Closeables.close(reader, true);
+
+ endOfData();
+ }
+
+ @Override
+ protected Pair<K,V> computeNext() {
+ if (!reuseKeyValueInstances || value == null) {
+ key = ReflectionUtils.newInstance(keyClass, conf);
+ if (!noValue) {
+ value = ReflectionUtils.newInstance(valueClass, conf);
+ }
+ }
+ try {
+ boolean available;
+ if (noValue) {
+ available = reader.next(key);
+ } else {
+ available = reader.next(key, value);
+ }
+ if (!available) {
+ close();
+ return null;
+ }
+ return new Pair<>(key, value);
+ } catch (IOException ioe) {
+ try {
+ close();
+ } catch (IOException e) {
+ log.error(e.getMessage(), e);
+ }
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+}
[29/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
new file mode 100644
index 0000000..c6c8427
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java
@@ -0,0 +1,324 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.fuzzykmeans;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassificationDriver;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.ClusteringPolicy;
+import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy;
+import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
+import org.apache.mahout.clustering.topdown.PathDirectory;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class FuzzyKMeansDriver extends AbstractJob {
+
+ public static final String M_OPTION = "m";
+
+ private static final Logger log = LoggerFactory.getLogger(FuzzyKMeansDriver.class);
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new FuzzyKMeansDriver(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+ addOption(DefaultOptionCreator.clustersInOption()
+ .withDescription("The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. "
+ + "If k is also specified, then a random set of vectors will be selected"
+ + " and written out to this path first")
+ .create());
+ addOption(DefaultOptionCreator.numClustersOption()
+ .withDescription("The k in k-Means. If specified, then a random selection of k Vectors will be chosen"
+ + " as the Centroid and written to the clusters input path.").create());
+ addOption(DefaultOptionCreator.convergenceOption().create());
+ addOption(DefaultOptionCreator.maxIterationsOption().create());
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption(M_OPTION, M_OPTION, "coefficient normalization factor, must be greater than 1", true);
+ addOption(DefaultOptionCreator.clusteringOption().create());
+ addOption(DefaultOptionCreator.emitMostLikelyOption().create());
+ addOption(DefaultOptionCreator.thresholdOption().create());
+ addOption(DefaultOptionCreator.methodOption().create());
+ addOption(DefaultOptionCreator.useSetRandomSeedOption().create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path clusters = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+ Path output = getOutputPath();
+ String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ if (measureClass == null) {
+ measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+ }
+ double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
+ float fuzziness = Float.parseFloat(getOption(M_OPTION));
+
+ int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ boolean emitMostLikely = Boolean.parseBoolean(getOption(DefaultOptionCreator.EMIT_MOST_LIKELY_OPTION));
+ double threshold = Double.parseDouble(getOption(DefaultOptionCreator.THRESHOLD_OPTION));
+ DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
+
+ if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
+ int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
+
+ Long seed = null;
+ if (hasOption(DefaultOptionCreator.RANDOM_SEED)) {
+ seed = Long.parseLong(getOption(DefaultOptionCreator.RANDOM_SEED));
+ }
+
+ clusters = RandomSeedGenerator.buildRandom(getConf(), input, clusters, numClusters, measure, seed);
+ }
+
+ boolean runClustering = hasOption(DefaultOptionCreator.CLUSTERING_OPTION);
+ boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+ DefaultOptionCreator.SEQUENTIAL_METHOD);
+
+ run(getConf(),
+ input,
+ clusters,
+ output,
+ convergenceDelta,
+ maxIterations,
+ fuzziness,
+ runClustering,
+ emitMostLikely,
+ threshold,
+ runSequential);
+ return 0;
+ }
+
+ /**
+ * Iterate over the input vectors to produce clusters and, if requested, use the
+ * results of the final iteration to cluster the input vectors.
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for initial & computed clusters
+ * @param output
+ * the directory pathname for output points
+ * @param convergenceDelta
+* the convergence delta value
+ * @param maxIterations
+* the maximum number of iterations
+ * @param m
+* the fuzzification factor, see
+* http://en.wikipedia.org/wiki/Data_clustering#Fuzzy_c-means_clustering
+ * @param runClustering
+* true if points are to be clustered after iterations complete
+ * @param emitMostLikely
+* a boolean if true emit only most likely cluster for each point
+ * @param threshold
+* a double threshold value emits all clusters having greater pdf (emitMostLikely = false)
+ * @param runSequential if true run in sequential execution mode
+ */
+ public static void run(Path input,
+ Path clustersIn,
+ Path output,
+ double convergenceDelta,
+ int maxIterations,
+ float m,
+ boolean runClustering,
+ boolean emitMostLikely,
+ double threshold,
+ boolean runSequential) throws IOException, ClassNotFoundException, InterruptedException {
+ Configuration conf = new Configuration();
+ Path clustersOut = buildClusters(conf,
+ input,
+ clustersIn,
+ output,
+ convergenceDelta,
+ maxIterations,
+ m,
+ runSequential);
+ if (runClustering) {
+ log.info("Clustering ");
+ clusterData(conf, input,
+ clustersOut,
+ output,
+ convergenceDelta,
+ m,
+ emitMostLikely,
+ threshold,
+ runSequential);
+ }
+ }
+
+ /**
+ * Iterate over the input vectors to produce clusters and, if requested, use the
+ * results of the final iteration to cluster the input vectors.
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for initial & computed clusters
+ * @param output
+ * the directory pathname for output points
+ * @param convergenceDelta
+* the convergence delta value
+ * @param maxIterations
+* the maximum number of iterations
+ * @param m
+* the fuzzification factor, see
+* http://en.wikipedia.org/wiki/Data_clustering#Fuzzy_c-means_clustering
+ * @param runClustering
+* true if points are to be clustered after iterations complete
+ * @param emitMostLikely
+* a boolean if true emit only most likely cluster for each point
+ * @param threshold
+* a double threshold value emits all clusters having greater pdf (emitMostLikely = false)
+ * @param runSequential if true run in sequential execution mode
+ */
+ public static void run(Configuration conf,
+ Path input,
+ Path clustersIn,
+ Path output,
+ double convergenceDelta,
+ int maxIterations,
+ float m,
+ boolean runClustering,
+ boolean emitMostLikely,
+ double threshold,
+ boolean runSequential)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ Path clustersOut =
+ buildClusters(conf, input, clustersIn, output, convergenceDelta, maxIterations, m, runSequential);
+ if (runClustering) {
+ log.info("Clustering");
+ clusterData(conf,
+ input,
+ clustersOut,
+ output,
+ convergenceDelta,
+ m,
+ emitMostLikely,
+ threshold,
+ runSequential);
+ }
+ }
+
+ /**
+ * Iterate over the input vectors to produce cluster directories for each iteration
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the file pathname for initial cluster centers
+ * @param output
+ * the directory pathname for output points
+ * @param convergenceDelta
+ * the convergence delta value
+ * @param maxIterations
+ * the maximum number of iterations
+ * @param m
+ * the fuzzification factor, see
+ * http://en.wikipedia.org/wiki/Data_clustering#Fuzzy_c-means_clustering
+ * @param runSequential if true run in sequential execution mode
+ *
+ * @return the Path of the final clusters directory
+ */
+ public static Path buildClusters(Configuration conf,
+ Path input,
+ Path clustersIn,
+ Path output,
+ double convergenceDelta,
+ int maxIterations,
+ float m,
+ boolean runSequential)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ List<Cluster> clusters = Lists.newArrayList();
+ FuzzyKMeansUtil.configureWithClusterInfo(conf, clustersIn, clusters);
+
+ if (conf == null) {
+ conf = new Configuration();
+ }
+
+ if (clusters.isEmpty()) {
+ throw new IllegalStateException("No input clusters found in " + clustersIn + ". Check your -c argument.");
+ }
+
+ Path priorClustersPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+ ClusteringPolicy policy = new FuzzyKMeansClusteringPolicy(m, convergenceDelta);
+ ClusterClassifier prior = new ClusterClassifier(clusters, policy);
+ prior.writeToSeqFiles(priorClustersPath);
+
+ if (runSequential) {
+ ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations);
+ } else {
+ ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations);
+ }
+ return output;
+ }
+
+ /**
+ * Run the job using supplied arguments
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for input clusters
+ * @param output
+ * the directory pathname for output points
+ * @param convergenceDelta
+* the convergence delta value
+ * @param emitMostLikely
+* a boolean if true emit only most likely cluster for each point
+ * @param threshold
+* a double threshold value emits all clusters having greater pdf (emitMostLikely = false)
+ * @param runSequential if true run in sequential execution mode
+ */
+ public static void clusterData(Configuration conf,
+ Path input,
+ Path clustersIn,
+ Path output,
+ double convergenceDelta,
+ float m,
+ boolean emitMostLikely,
+ double threshold,
+ boolean runSequential)
+ throws IOException, ClassNotFoundException, InterruptedException {
+
+ ClusterClassifier.writePolicy(new FuzzyKMeansClusteringPolicy(m, convergenceDelta), clustersIn);
+ ClusterClassificationDriver.run(conf, input, output, new Path(output, PathDirectory.CLUSTERED_POINTS_DIRECTORY),
+ threshold, emitMostLikely, runSequential);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
new file mode 100644
index 0000000..25621bb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java
@@ -0,0 +1,76 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.fuzzykmeans;
+
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.kmeans.Kluster;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+
+final class FuzzyKMeansUtil {
+
+ private FuzzyKMeansUtil() {}
+
+ /**
+ * Create a list of SoftClusters from whatever type is passed in as the prior
+ *
+ * @param conf
+ * the Configuration
+ * @param clusterPath
+ * the path to the prior Clusters
+ * @param clusters
+ * a List<Cluster> to put values into
+ */
+ public static void configureWithClusterInfo(Configuration conf, Path clusterPath, List<Cluster> clusters) {
+ for (Writable value : new SequenceFileDirValueIterable<>(clusterPath, PathType.LIST,
+ PathFilters.partFilter(), conf)) {
+ Class<? extends Writable> valueClass = value.getClass();
+
+ if (valueClass.equals(ClusterWritable.class)) {
+ ClusterWritable clusterWritable = (ClusterWritable) value;
+ value = clusterWritable.getValue();
+ valueClass = value.getClass();
+ }
+
+ if (valueClass.equals(Kluster.class)) {
+ // get the cluster info
+ Kluster cluster = (Kluster) value;
+ clusters.add(new SoftCluster(cluster.getCenter(), cluster.getId(), cluster.getMeasure()));
+ } else if (valueClass.equals(SoftCluster.class)) {
+ // get the cluster info
+ clusters.add((SoftCluster) value);
+ } else if (valueClass.equals(Canopy.class)) {
+ // get the cluster info
+ Canopy canopy = (Canopy) value;
+ clusters.add(new SoftCluster(canopy.getCenter(), canopy.getId(), canopy.getMeasure()));
+ } else {
+ throw new IllegalStateException("Bad value class: " + valueClass);
+ }
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
new file mode 100644
index 0000000..52fd764
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.fuzzykmeans;
+
+import org.apache.mahout.clustering.kmeans.Kluster;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public class SoftCluster extends Kluster {
+
+ // For Writable
+ public SoftCluster() {}
+
+ /**
+ * Construct a new SoftCluster with the given point as its center
+ *
+ * @param center
+ * the center point
+ * @param measure
+ * the DistanceMeasure
+ */
+ public SoftCluster(Vector center, int clusterId, DistanceMeasure measure) {
+ super(center, clusterId, measure);
+ }
+
+ @Override
+ public String asFormatString() {
+ return this.getIdentifier() + ": "
+ + this.computeCentroid().asFormatString();
+ }
+
+ @Override
+ public String getIdentifier() {
+ return (isConverged() ? "SV-" : "SC-") + getId();
+ }
+
+ @Override
+ public double pdf(VectorWritable vw) {
+ // SoftCluster pdf cannot be calculated out of context. See
+ // FuzzyKMeansClusterer
+ throw new UnsupportedOperationException(
+ "SoftCluster pdf cannot be calculated out of context. See FuzzyKMeansClusterer");
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java
new file mode 100644
index 0000000..07cc7e3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.TimesFunction;
+
+public abstract class AbstractClusteringPolicy implements ClusteringPolicy {
+
+ @Override
+ public abstract void write(DataOutput out) throws IOException;
+
+ @Override
+ public abstract void readFields(DataInput in) throws IOException;
+
+ @Override
+ public Vector select(Vector probabilities) {
+ int maxValueIndex = probabilities.maxValueIndex();
+ Vector weights = new SequentialAccessSparseVector(probabilities.size());
+ weights.set(maxValueIndex, 1.0);
+ return weights;
+ }
+
+ @Override
+ public void update(ClusterClassifier posterior) {
+ // nothing to do in general here
+ }
+
+ @Override
+ public Vector classify(Vector data, ClusterClassifier prior) {
+ List<Cluster> models = prior.getModels();
+ int i = 0;
+ Vector pdfs = new DenseVector(models.size());
+ for (Cluster model : models) {
+ pdfs.set(i++, model.pdf(new VectorWritable(data)));
+ }
+ return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
+ }
+
+ @Override
+ public void close(ClusterClassifier posterior) {
+ for (Cluster cluster : posterior.getModels()) {
+ cluster.computeParameters();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java
new file mode 100644
index 0000000..fb2db49
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.iterator;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+
+public class CIMapper extends Mapper<WritableComparable<?>,VectorWritable,IntWritable,ClusterWritable> {
+
+ private ClusterClassifier classifier;
+ private ClusteringPolicy policy;
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ String priorClustersPath = conf.get(ClusterIterator.PRIOR_PATH_KEY);
+ classifier = new ClusterClassifier();
+ classifier.readFromSeqFiles(conf, new Path(priorClustersPath));
+ policy = classifier.getPolicy();
+ policy.update(classifier);
+ super.setup(context);
+ }
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context) throws IOException,
+ InterruptedException {
+ Vector probabilities = classifier.classify(value.get());
+ Vector selections = policy.select(probabilities);
+ for (Element el : selections.nonZeroes()) {
+ classifier.train(el.index(), value.get(), el.get());
+ }
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException, InterruptedException {
+ List<Cluster> clusters = classifier.getModels();
+ ClusterWritable cw = new ClusterWritable();
+ for (int index = 0; index < clusters.size(); index++) {
+ cw.setValue(clusters.get(index));
+ context.write(new IntWritable(index), cw);
+ }
+ super.cleanup(context);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java
new file mode 100644
index 0000000..bf42eb1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.iterator;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+
+public class CIReducer extends Reducer<IntWritable,ClusterWritable,IntWritable,ClusterWritable> {
+
+ private ClusterClassifier classifier;
+ private ClusteringPolicy policy;
+
+ @Override
+ protected void reduce(IntWritable key, Iterable<ClusterWritable> values, Context context) throws IOException,
+ InterruptedException {
+ Iterator<ClusterWritable> iter = values.iterator();
+ Cluster first = iter.next().getValue(); // there must always be at least one
+ while (iter.hasNext()) {
+ Cluster cluster = iter.next().getValue();
+ first.observe(cluster);
+ }
+ List<Cluster> models = Lists.newArrayList();
+ models.add(first);
+ classifier = new ClusterClassifier(models, policy);
+ classifier.close();
+ context.write(key, new ClusterWritable(first));
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ Configuration conf = context.getConfiguration();
+ String priorClustersPath = conf.get(ClusterIterator.PRIOR_PATH_KEY);
+ classifier = new ClusterClassifier();
+ classifier.readFromSeqFiles(conf, new Path(priorClustersPath));
+ policy = classifier.getPolicy();
+ policy.update(classifier);
+ super.setup(context);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java
new file mode 100644
index 0000000..c9a0940
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+
+@Deprecated
+public class CanopyClusteringPolicy extends AbstractClusteringPolicy {
+
+ private double t1;
+ private double t2;
+
+ @Override
+ public Vector select(Vector probabilities) {
+ int maxValueIndex = probabilities.maxValueIndex();
+ Vector weights = new SequentialAccessSparseVector(probabilities.size());
+ weights.set(maxValueIndex, 1.0);
+ return weights;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(t1);
+ out.writeDouble(t2);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.t1 = in.readDouble();
+ this.t2 = in.readDouble();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java
new file mode 100644
index 0000000..516177f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java
@@ -0,0 +1,219 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.io.Closeables;
+
+/**
+ * This is a clustering iterator which works with a set of Vector data and a prior ClusterClassifier which has been
+ * initialized with a set of models. Its implementation is algorithm-neutral and works for any iterative clustering
+ * algorithm (currently k-means and fuzzy-k-means) that processes all the input vectors in each iteration.
+ * The cluster classifier is configured with a ClusteringPolicy to select the desired clustering algorithm.
+ */
+public final class ClusterIterator {
+
+ public static final String PRIOR_PATH_KEY = "org.apache.mahout.clustering.prior.path";
+
+ private ClusterIterator() {
+ }
+
+ /**
+ * Iterate over data using a prior-trained ClusterClassifier, for a number of iterations
+ *
+ * @param data
+ * a {@code List<Vector>} of input vectors
+ * @param classifier
+ * a prior ClusterClassifier
+ * @param numIterations
+ * the int number of iterations to perform
+ *
+ * @return the posterior ClusterClassifier
+ */
+ public static ClusterClassifier iterate(Iterable<Vector> data, ClusterClassifier classifier, int numIterations) {
+ ClusteringPolicy policy = classifier.getPolicy();
+ for (int iteration = 1; iteration <= numIterations; iteration++) {
+ for (Vector vector : data) {
+ // update the policy based upon the prior
+ policy.update(classifier);
+ // classification yields probabilities
+ Vector probabilities = classifier.classify(vector);
+ // policy selects weights for models given those probabilities
+ Vector weights = policy.select(probabilities);
+ // training causes all models to observe data
+ for (Vector.Element e : weights.nonZeroes()) {
+ int index = e.index();
+ classifier.train(index, vector, weights.get(index));
+ }
+ }
+ // compute the posterior models
+ classifier.close();
+ }
+ return classifier;
+ }
+
+ /**
+ * Iterate over data using a prior-trained ClusterClassifier, for a number of iterations using a sequential
+ * implementation
+ *
+ * @param conf
+ * the Configuration
+ * @param inPath
+ * a Path to input VectorWritables
+ * @param priorPath
+ * a Path to the prior classifier
+ * @param outPath
+ * a Path of output directory
+ * @param numIterations
+ * the int number of iterations to perform
+ */
+ public static void iterateSeq(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations)
+ throws IOException {
+ ClusterClassifier classifier = new ClusterClassifier();
+ classifier.readFromSeqFiles(conf, priorPath);
+ Path clustersOut = null;
+ int iteration = 1;
+ while (iteration <= numIterations) {
+ for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(inPath, PathType.LIST,
+ PathFilters.logsCRCFilter(), conf)) {
+ Vector vector = vw.get();
+ // classification yields probabilities
+ Vector probabilities = classifier.classify(vector);
+ // policy selects weights for models given those probabilities
+ Vector weights = classifier.getPolicy().select(probabilities);
+ // training causes all models to observe data
+ for (Vector.Element e : weights.nonZeroes()) {
+ int index = e.index();
+ classifier.train(index, vector, weights.get(index));
+ }
+ }
+ // compute the posterior models
+ classifier.close();
+ // update the policy
+ classifier.getPolicy().update(classifier);
+ // output the classifier
+ clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);
+ classifier.writeToSeqFiles(clustersOut);
+ FileSystem fs = FileSystem.get(outPath.toUri(), conf);
+ iteration++;
+ if (isConverged(clustersOut, conf, fs)) {
+ break;
+ }
+ }
+ Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX);
+ FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn);
+ }
+
+ /**
+ * Iterate over data using a prior-trained ClusterClassifier, for a number of iterations using a mapreduce
+ * implementation
+ *
+ * @param conf
+ * the Configuration
+ * @param inPath
+ * a Path to input VectorWritables
+ * @param priorPath
+ * a Path to the prior classifier
+ * @param outPath
+ * a Path of output directory
+ * @param numIterations
+ * the int number of iterations to perform
+ */
+ public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusteringPolicy policy = ClusterClassifier.readPolicy(priorPath);
+ Path clustersOut = null;
+ int iteration = 1;
+ while (iteration <= numIterations) {
+ conf.set(PRIOR_PATH_KEY, priorPath.toString());
+
+ String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath;
+ Job job = new Job(conf, jobName);
+ job.setMapOutputKeyClass(IntWritable.class);
+ job.setMapOutputValueClass(ClusterWritable.class);
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(ClusterWritable.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setMapperClass(CIMapper.class);
+ job.setReducerClass(CIReducer.class);
+
+ FileInputFormat.addInputPath(job, inPath);
+ clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration);
+ priorPath = clustersOut;
+ FileOutputFormat.setOutputPath(job, clustersOut);
+
+ job.setJarByClass(ClusterIterator.class);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath);
+ }
+ ClusterClassifier.writePolicy(policy, clustersOut);
+ FileSystem fs = FileSystem.get(outPath.toUri(), conf);
+ iteration++;
+ if (isConverged(clustersOut, conf, fs)) {
+ break;
+ }
+ }
+ Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX);
+ FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn);
+ }
+
+ /**
+ * Return if all of the Clusters in the parts in the filePath have converged or not
+ *
+ * @param filePath
+ * the file path to the single file containing the clusters
+ * @return true if all Clusters are converged
+ * @throws IOException
+ * if there was an IO error
+ */
+ private static boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException {
+ for (FileStatus part : fs.listStatus(filePath, PathFilters.partFilter())) {
+ SequenceFileValueIterator<ClusterWritable> iterator = new SequenceFileValueIterator<>(
+ part.getPath(), true, conf);
+ while (iterator.hasNext()) {
+ ClusterWritable value = iterator.next();
+ if (!value.getValue().isConverged()) {
+ Closeables.close(iterator, true);
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java
new file mode 100644
index 0000000..855685f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+import org.apache.mahout.clustering.Cluster;
+
+public class ClusterWritable implements Writable {
+
+ private Cluster value;
+
+ public ClusterWritable(Cluster first) {
+ value = first;
+ }
+
+ public ClusterWritable() {
+ }
+
+ public Cluster getValue() {
+ return value;
+ }
+
+ public void setValue(Cluster value) {
+ this.value = value;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ PolymorphicWritable.write(out, value);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ value = PolymorphicWritable.read(in, Cluster.class);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java
new file mode 100644
index 0000000..6e15838
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A ClusteringPolicy captures the semantics of assignment of points to clusters
+ *
+ */
+public interface ClusteringPolicy extends Writable {
+
+ /**
+ * Classify the data vector given the classifier's models
+ *
+ * @param data
+ * a data Vector
+ * @param prior
+ * a prior ClusterClassifier
+ * @return a Vector of probabilities that the data is described by each of the
+ * models
+ */
+ Vector classify(Vector data, ClusterClassifier prior);
+
+ /**
+ * Return a vector of weights for each of the models given those probabilities
+ *
+ * @param probabilities
+ * a Vector of pdfs
+ * @return a Vector of weights
+ */
+ Vector select(Vector probabilities);
+
+ /**
+ * Update the policy with the given classifier
+ *
+ * @param posterior
+ * a ClusterClassifier
+ */
+ void update(ClusterClassifier posterior);
+
+ /**
+ * Close the policy using the classifier's models
+ *
+ * @param posterior
+ * a posterior ClusterClassifier
+ */
+ void close(ClusterClassifier posterior);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java
new file mode 100644
index 0000000..f69442d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+
+public class ClusteringPolicyWritable implements Writable {
+
+ private ClusteringPolicy value;
+
+ public ClusteringPolicyWritable(ClusteringPolicy policy) {
+ this.value = policy;
+ }
+
+ public ClusteringPolicyWritable() {
+ }
+
+ public ClusteringPolicy getValue() {
+ return value;
+ }
+
+ public void setValue(ClusteringPolicy value) {
+ this.value = value;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ PolymorphicWritable.write(out, value);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ value = PolymorphicWritable.read(in, ClusteringPolicy.class);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java
new file mode 100644
index 0000000..f61aa27
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java
@@ -0,0 +1,91 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.iterator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.clustering.AbstractCluster;
+import org.apache.mahout.clustering.Model;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public class DistanceMeasureCluster extends AbstractCluster {
+
+ private DistanceMeasure measure;
+
+ public DistanceMeasureCluster(Vector point, int id, DistanceMeasure measure) {
+ super(point, id);
+ this.measure = measure;
+ }
+
+ public DistanceMeasureCluster() {
+ }
+
+ @Override
+ public void configure(Configuration job) {
+ if (measure != null) {
+ measure.configure(job);
+ }
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ String dm = in.readUTF();
+ this.measure = ClassUtils.instantiateAs(dm, DistanceMeasure.class);
+ super.readFields(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeUTF(measure.getClass().getName());
+ super.write(out);
+ }
+
+ @Override
+ public double pdf(VectorWritable vw) {
+ return 1 / (1 + measure.distance(vw.get(), getCenter()));
+ }
+
+ @Override
+ public Model<VectorWritable> sampleFromPosterior() {
+ return new DistanceMeasureCluster(getCenter(), getId(), measure);
+ }
+
+ public DistanceMeasure getMeasure() {
+ return measure;
+ }
+
+ /**
+ * @param measure
+ * the measure to set
+ */
+ public void setMeasure(DistanceMeasure measure) {
+ this.measure = measure;
+ }
+
+ @Override
+ public String getIdentifier() {
+ return "DMC:" + getId();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java
new file mode 100644
index 0000000..bc91f24
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java
@@ -0,0 +1,91 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.math.Vector;
+
+import com.google.common.collect.Lists;
+
+/**
+ * This is a probability-weighted clustering policy, suitable for fuzzy k-means
+ * clustering
+ *
+ */
+public class FuzzyKMeansClusteringPolicy extends AbstractClusteringPolicy {
+
+ private double m = 2;
+ private double convergenceDelta = 0.05;
+
+ public FuzzyKMeansClusteringPolicy() {
+ }
+
+ public FuzzyKMeansClusteringPolicy(double m, double convergenceDelta) {
+ this.m = m;
+ this.convergenceDelta = convergenceDelta;
+ }
+
+ @Override
+ public Vector select(Vector probabilities) {
+ return probabilities;
+ }
+
+ @Override
+ public Vector classify(Vector data, ClusterClassifier prior) {
+ Collection<SoftCluster> clusters = Lists.newArrayList();
+ List<Double> distances = Lists.newArrayList();
+ for (Cluster model : prior.getModels()) {
+ SoftCluster sc = (SoftCluster) model;
+ clusters.add(sc);
+ distances.add(sc.getMeasure().distance(data, sc.getCenter()));
+ }
+ FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer();
+ fuzzyKMeansClusterer.setM(m);
+ return fuzzyKMeansClusterer.computePi(clusters, distances);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(m);
+ out.writeDouble(convergenceDelta);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.m = in.readDouble();
+ this.convergenceDelta = in.readDouble();
+ }
+
+ @Override
+ public void close(ClusterClassifier posterior) {
+ for (Cluster cluster : posterior.getModels()) {
+ ((org.apache.mahout.clustering.kmeans.Kluster) cluster).calculateConvergence(convergenceDelta);
+ cluster.computeParameters();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java b/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java
new file mode 100644
index 0000000..1cc9faf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.iterator;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+
+/**
+ * This is a simple maximum likelihood clustering policy, suitable for k-means
+ * clustering
+ *
+ */
+public class KMeansClusteringPolicy extends AbstractClusteringPolicy {
+
+ public KMeansClusteringPolicy() {
+ }
+
+ public KMeansClusteringPolicy(double convergenceDelta) {
+ this.convergenceDelta = convergenceDelta;
+ }
+
+ private double convergenceDelta = 0.001;
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(convergenceDelta);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.convergenceDelta = in.readDouble();
+ }
+
+ @Override
+ public void close(ClusterClassifier posterior) {
+ boolean allConverged = true;
+ for (Cluster cluster : posterior.getModels()) {
+ org.apache.mahout.clustering.kmeans.Kluster kluster = (org.apache.mahout.clustering.kmeans.Kluster) cluster;
+ boolean converged = kluster.calculateConvergence(convergenceDelta);
+ allConverged = allConverged && converged;
+ cluster.computeParameters();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java b/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java
new file mode 100644
index 0000000..96c4082
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.kernel;
+
+public interface IKernelProfile {
+
+ /**
+ * @return the calculated dervative value of the kernel
+ */
+ double calculateDerivativeValue(double distance, double h);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java b/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java
new file mode 100644
index 0000000..46909bb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.kernel;
+
+public class TriangularKernelProfile implements IKernelProfile {
+
+ @Override
+ public double calculateDerivativeValue(double distance, double h) {
+ return distance < h ? 1.0 : 0.0;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java b/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
new file mode 100644
index 0000000..13f6b46
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
@@ -0,0 +1,257 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.kmeans;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.classify.ClusterClassificationDriver;
+import org.apache.mahout.clustering.classify.ClusterClassifier;
+import org.apache.mahout.clustering.iterator.ClusterIterator;
+import org.apache.mahout.clustering.iterator.ClusteringPolicy;
+import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
+import org.apache.mahout.clustering.topdown.PathDirectory;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class KMeansDriver extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class);
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new KMeansDriver(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+ addOption(DefaultOptionCreator
+ .clustersInOption()
+ .withDescription(
+ "The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. "
+ + "If k is also specified, then a random set of vectors will be selected"
+ + " and written out to this path first").create());
+ addOption(DefaultOptionCreator
+ .numClustersOption()
+ .withDescription(
+ "The k in k-Means. If specified, then a random selection of k Vectors will be chosen"
+ + " as the Centroid and written to the clusters input path.").create());
+ addOption(DefaultOptionCreator.useSetRandomSeedOption().create());
+ addOption(DefaultOptionCreator.convergenceOption().create());
+ addOption(DefaultOptionCreator.maxIterationsOption().create());
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption(DefaultOptionCreator.clusteringOption().create());
+ addOption(DefaultOptionCreator.methodOption().create());
+ addOption(DefaultOptionCreator.outlierThresholdOption().create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path clusters = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+ Path output = getOutputPath();
+ String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ if (measureClass == null) {
+ measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+ }
+ double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION));
+ int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION));
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
+
+ if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) {
+ int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
+
+ Long seed = null;
+ if (hasOption(DefaultOptionCreator.RANDOM_SEED)) {
+ seed = Long.parseLong(getOption(DefaultOptionCreator.RANDOM_SEED));
+ }
+
+ clusters = RandomSeedGenerator.buildRandom(getConf(), input, clusters, numClusters, measure, seed);
+ }
+ boolean runClustering = hasOption(DefaultOptionCreator.CLUSTERING_OPTION);
+ boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+ DefaultOptionCreator.SEQUENTIAL_METHOD);
+ double clusterClassificationThreshold = 0.0;
+ if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
+ clusterClassificationThreshold = Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+ }
+ run(getConf(), input, clusters, output, convergenceDelta, maxIterations, runClustering,
+ clusterClassificationThreshold, runSequential);
+ return 0;
+ }
+
+ /**
+ * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to
+ * cluster the input vectors.
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for initial & computed clusters
+ * @param output
+ * the directory pathname for output points
+ * @param convergenceDelta
+ * the convergence delta value
+ * @param maxIterations
+ * the maximum number of iterations
+ * @param runClustering
+ * true if points are to be clustered after iterations are completed
+ * @param clusterClassificationThreshold
+ * Is a clustering strictness / outlier removal parameter. Its value should be between 0 and 1. Vectors
+ * having pdf below this value will not be clustered.
+ * @param runSequential
+ * if true execute sequential algorithm
+ */
+ public static void run(Configuration conf, Path input, Path clustersIn, Path output,
+ double convergenceDelta, int maxIterations, boolean runClustering, double clusterClassificationThreshold,
+ boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException {
+
+ // iterate until the clusters converge
+ String delta = Double.toString(convergenceDelta);
+ if (log.isInfoEnabled()) {
+ log.info("Input: {} Clusters In: {} Out: {}", input, clustersIn, output);
+ log.info("convergence: {} max Iterations: {}", convergenceDelta, maxIterations);
+ }
+ Path clustersOut = buildClusters(conf, input, clustersIn, output, maxIterations, delta, runSequential);
+ if (runClustering) {
+ log.info("Clustering data");
+ clusterData(conf, input, clustersOut, output, clusterClassificationThreshold, runSequential);
+ }
+ }
+
+ /**
+ * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to
+ * cluster the input vectors.
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for initial & computed clusters
+ * @param output
+ * the directory pathname for output points
+ * @param convergenceDelta
+ * the convergence delta value
+ * @param maxIterations
+ * the maximum number of iterations
+ * @param runClustering
+ * true if points are to be clustered after iterations are completed
+ * @param clusterClassificationThreshold
+ * Is a clustering strictness / outlier removal parameter. Its value should be between 0 and 1. Vectors
+ * having pdf below this value will not be clustered.
+ * @param runSequential
+ * if true execute sequential algorithm
+ */
+ public static void run(Path input, Path clustersIn, Path output, double convergenceDelta,
+ int maxIterations, boolean runClustering, double clusterClassificationThreshold, boolean runSequential)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ run(new Configuration(), input, clustersIn, output, convergenceDelta, maxIterations, runClustering,
+ clusterClassificationThreshold, runSequential);
+ }
+
+ /**
+ * Iterate over the input vectors to produce cluster directories for each iteration
+ *
+ *
+ * @param conf
+ * the Configuration to use
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for initial & computed clusters
+ * @param output
+ * the directory pathname for output points
+ * @param maxIterations
+ * the maximum number of iterations
+ * @param delta
+ * the convergence delta value
+ * @param runSequential
+ * if true execute sequential algorithm
+ *
+ * @return the Path of the final clusters directory
+ */
+ public static Path buildClusters(Configuration conf, Path input, Path clustersIn, Path output,
+ int maxIterations, String delta, boolean runSequential) throws IOException,
+ InterruptedException, ClassNotFoundException {
+
+ double convergenceDelta = Double.parseDouble(delta);
+ List<Cluster> clusters = Lists.newArrayList();
+ KMeansUtil.configureWithClusterInfo(conf, clustersIn, clusters);
+
+ if (clusters.isEmpty()) {
+ throw new IllegalStateException("No input clusters found in " + clustersIn + ". Check your -c argument.");
+ }
+
+ Path priorClustersPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR);
+ ClusteringPolicy policy = new KMeansClusteringPolicy(convergenceDelta);
+ ClusterClassifier prior = new ClusterClassifier(clusters, policy);
+ prior.writeToSeqFiles(priorClustersPath);
+
+ if (runSequential) {
+ ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations);
+ } else {
+ ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations);
+ }
+ return output;
+ }
+
+ /**
+ * Run the job using supplied arguments
+ *
+ * @param input
+ * the directory pathname for input points
+ * @param clustersIn
+ * the directory pathname for input clusters
+ * @param output
+ * the directory pathname for output points
+ * @param clusterClassificationThreshold
+ * Is a clustering strictness / outlier removal parameter. Its value should be between 0 and 1. Vectors
+ * having pdf below this value will not be clustered.
+ * @param runSequential
+ * if true execute sequential algorithm
+ */
+ public static void clusterData(Configuration conf, Path input, Path clustersIn, Path output,
+ double clusterClassificationThreshold, boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
+
+ if (log.isInfoEnabled()) {
+ log.info("Running Clustering");
+ log.info("Input: {} Clusters In: {} Out: {}", input, clustersIn, output);
+ }
+ ClusterClassifier.writePolicy(new KMeansClusteringPolicy(), clustersIn);
+ ClusterClassificationDriver.run(conf, input, output, new Path(output, PathDirectory.CLUSTERED_POINTS_DIRECTORY),
+ clusterClassificationThreshold, true, runSequential);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java b/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
new file mode 100644
index 0000000..3365f70
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java
@@ -0,0 +1,74 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.kmeans;
+
+import java.util.Collection;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+final class KMeansUtil {
+
+ private static final Logger log = LoggerFactory.getLogger(KMeansUtil.class);
+
+ private KMeansUtil() {}
+
+ /**
+ * Create a list of Klusters from whatever Cluster type is passed in as the prior
+ *
+ * @param conf
+ * the Configuration
+ * @param clusterPath
+ * the path to the prior Clusters
+ * @param clusters
+ * a List<Cluster> to put values into
+ */
+ public static void configureWithClusterInfo(Configuration conf, Path clusterPath, Collection<Cluster> clusters) {
+ for (Writable value : new SequenceFileDirValueIterable<>(clusterPath, PathType.LIST,
+ PathFilters.partFilter(), conf)) {
+ Class<? extends Writable> valueClass = value.getClass();
+ if (valueClass.equals(ClusterWritable.class)) {
+ ClusterWritable clusterWritable = (ClusterWritable) value;
+ value = clusterWritable.getValue();
+ valueClass = value.getClass();
+ }
+ log.debug("Read 1 Cluster from {}", clusterPath);
+
+ if (valueClass.equals(Kluster.class)) {
+ // get the cluster info
+ clusters.add((Kluster) value);
+ } else if (valueClass.equals(Canopy.class)) {
+ // get the cluster info
+ Canopy canopy = (Canopy) value;
+ clusters.add(new Kluster(canopy.getCenter(), canopy.getId(), canopy.getMeasure()));
+ } else {
+ throw new IllegalStateException("Bad value class: " + valueClass);
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java b/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java
new file mode 100644
index 0000000..15daec5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java
@@ -0,0 +1,117 @@
+/* Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.kmeans;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.clustering.iterator.DistanceMeasureCluster;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Vector;
+
+public class Kluster extends DistanceMeasureCluster {
+
+ /** Has the centroid converged with the center? */
+ private boolean converged;
+
+ /** For (de)serialization as a Writable */
+ public Kluster() {
+ }
+
+ /**
+ * Construct a new cluster with the given point as its center
+ *
+ * @param center
+ * the Vector center
+ * @param clusterId
+ * the int cluster id
+ * @param measure
+ * a DistanceMeasure
+ */
+ public Kluster(Vector center, int clusterId, DistanceMeasure measure) {
+ super(center, clusterId, measure);
+ }
+
+ /**
+ * Format the cluster for output
+ *
+ * @param cluster
+ * the Cluster
+ * @return the String representation of the Cluster
+ */
+ public static String formatCluster(Kluster cluster) {
+ return cluster.getIdentifier() + ": " + cluster.computeCentroid().asFormatString();
+ }
+
+ public String asFormatString() {
+ return formatCluster(this);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ super.write(out);
+ out.writeBoolean(converged);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ super.readFields(in);
+ this.converged = in.readBoolean();
+ }
+
+ @Override
+ public String toString() {
+ return asFormatString(null);
+ }
+
+ @Override
+ public String getIdentifier() {
+ return (converged ? "VL-" : "CL-") + getId();
+ }
+
+ /**
+ * Return if the cluster is converged by comparing its center and centroid.
+ *
+ * @param measure
+ * The distance measure to use for cluster-point comparisons.
+ * @param convergenceDelta
+ * the convergence delta to use for stopping.
+ * @return if the cluster is converged
+ */
+ public boolean computeConvergence(DistanceMeasure measure, double convergenceDelta) {
+ Vector centroid = computeCentroid();
+ converged = measure.distance(centroid.getLengthSquared(), centroid, getCenter()) <= convergenceDelta;
+ return converged;
+ }
+
+ @Override
+ public boolean isConverged() {
+ return converged;
+ }
+
+ protected void setConverged(boolean converged) {
+ this.converged = converged;
+ }
+
+ public boolean calculateConvergence(double convergenceDelta) {
+ Vector centroid = computeCentroid();
+ converged = getMeasure().distance(centroid.getLengthSquared(), centroid, getCenter()) <= convergenceDelta;
+ return converged;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java b/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
new file mode 100644
index 0000000..cc9e4cd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.kmeans;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, randomly select k vectors and
+ * write them to the output file as a {@link org.apache.mahout.clustering.kmeans.Kluster} representing the
+ * initial centroid to use.
+ *
+ * This implementation uses reservoir sampling as described in http://en.wikipedia.org/wiki/Reservoir_sampling
+ */
+public final class RandomSeedGenerator {
+
+ private static final Logger log = LoggerFactory.getLogger(RandomSeedGenerator.class);
+
+ public static final String K = "k";
+
+ private RandomSeedGenerator() {}
+
+ public static Path buildRandom(Configuration conf, Path input, Path output, int k, DistanceMeasure measure)
+ throws IOException {
+ return buildRandom(conf, input, output, k, measure, null);
+ }
+
+ public static Path buildRandom(Configuration conf,
+ Path input,
+ Path output,
+ int k,
+ DistanceMeasure measure,
+ Long seed) throws IOException {
+
+ Preconditions.checkArgument(k > 0, "Must be: k > 0, but k = " + k);
+ // delete the output directory
+ FileSystem fs = FileSystem.get(output.toUri(), conf);
+ HadoopUtil.delete(conf, output);
+ Path outFile = new Path(output, "part-randomSeed");
+ boolean newFile = fs.createNewFile(outFile);
+ if (newFile) {
+ Path inputPathPattern;
+
+ if (fs.getFileStatus(input).isDir()) {
+ inputPathPattern = new Path(input, "*");
+ } else {
+ inputPathPattern = input;
+ }
+
+ FileStatus[] inputFiles = fs.globStatus(inputPathPattern, PathFilters.logsCRCFilter());
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outFile, Text.class, ClusterWritable.class);
+
+ Random random = (seed != null) ? RandomUtils.getRandom(seed) : RandomUtils.getRandom();
+
+ List<Text> chosenTexts = Lists.newArrayListWithCapacity(k);
+ List<ClusterWritable> chosenClusters = Lists.newArrayListWithCapacity(k);
+ int nextClusterId = 0;
+
+ int index = 0;
+ for (FileStatus fileStatus : inputFiles) {
+ if (!fileStatus.isDir()) {
+ for (Pair<Writable, VectorWritable> record
+ : new SequenceFileIterable<Writable, VectorWritable>(fileStatus.getPath(), true, conf)) {
+ Writable key = record.getFirst();
+ VectorWritable value = record.getSecond();
+ Kluster newCluster = new Kluster(value.get(), nextClusterId++, measure);
+ newCluster.observe(value.get(), 1);
+ Text newText = new Text(key.toString());
+ int currentSize = chosenTexts.size();
+ if (currentSize < k) {
+ chosenTexts.add(newText);
+ ClusterWritable clusterWritable = new ClusterWritable();
+ clusterWritable.setValue(newCluster);
+ chosenClusters.add(clusterWritable);
+ } else {
+ int j = random.nextInt(index);
+ if (j < k) {
+ chosenTexts.set(j, newText);
+ ClusterWritable clusterWritable = new ClusterWritable();
+ clusterWritable.setValue(newCluster);
+ chosenClusters.set(j, clusterWritable);
+ }
+ }
+ index++;
+ }
+ }
+ }
+
+ try {
+ for (int i = 0; i < chosenTexts.size(); i++) {
+ writer.append(chosenTexts.get(i), chosenClusters.get(i));
+ }
+ log.info("Wrote {} Klusters to {}", k, outFile);
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ return outFile;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java b/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java
new file mode 100644
index 0000000..d6921b6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java
@@ -0,0 +1,5 @@
+/**
+ * This package provides an implementation of the <a href="http://en.wikipedia.org/wiki/Kmeans">k-means</a> clustering
+ * algorithm.
+ */
+package org.apache.mahout.clustering.kmeans;
[26/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java
new file mode 100644
index 0000000..25a4022
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java
@@ -0,0 +1,456 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.cluster;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.Multinomial;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Implements a ball k-means algorithm for weighted vectors with probabilistic seeding similar to k-means++.
+ * The idea is that k-means++ gives good starting clusters and ball k-means can tune up the final result very nicely
+ * in only a few passes (or even in a single iteration for well-clusterable data).
+ *
+ * A good reference for this class of algorithms is "The Effectiveness of Lloyd-Type Methods for the k-Means Problem"
+ * by Rafail Ostrovsky, Yuval Rabani, Leonard J. Schulman and Chaitanya Swamy. The code here uses the seeding strategy
+ * as described in section 4.1.1 of that paper and the ball k-means step as described in section 4.2. We support
+ * multiple iterations in contrast to the algorithm described in the paper.
+ */
+public class BallKMeans implements Iterable<Centroid> {
+ /**
+ * The searcher containing the centroids.
+ */
+ private final UpdatableSearcher centroids;
+
+ /**
+ * The number of clusters to cluster the data into.
+ */
+ private final int numClusters;
+
+ /**
+ * The maximum number of iterations of the algorithm to run waiting for the cluster assignments
+ * to stabilize. If there are no changes in cluster assignment earlier, we can finish early.
+ */
+ private final int maxNumIterations;
+
+ /**
+ * When deciding which points to include in the new centroid calculation,
+ * it's preferable to exclude outliers since it increases the rate of convergence.
+ * So, we calculate the distance from each cluster to its closest neighboring cluster. When
+ * evaluating the points assigned to a cluster, we compare the distance between the centroid to
+ * the point with the distance between the centroid and its closest centroid neighbor
+ * multiplied by this trimFraction. If the distance between the centroid and the point is
+ * greater, we consider it an outlier and we don't use it.
+ */
+ private final double trimFraction;
+
+ /**
+ * Selecting the initial centroids is the most important part of the ball k-means clustering. Poor choices, like two
+ * centroids in the same actual cluster result in a low-quality final result.
+ * k-means++ initialization yields good quality clusters, especially when using BallKMeans after StreamingKMeans as
+ * the points have weights.
+ * Simple, random selection of the points based on their weights is faster but sometimes fails to produce the
+ * desired number of clusters.
+ * This field is true if the initialization should be done with k-means++.
+ */
+ private final boolean kMeansPlusPlusInit;
+
+ /**
+ * When using trimFraction, the weight of each centroid will not be the sum of the weights of
+ * the vectors assigned to that cluster because outliers are not used to compute the updated
+ * centroid.
+ * So, the total weight is probably wrong. This can be fixed by doing another pass over the
+ * data points and adjusting the weights of each centroid. This doesn't update the coordinates
+ * of the centroids, but is useful if the weights matter.
+ */
+ private final boolean correctWeights;
+
+ /**
+ * When running multiple ball k-means passes to get the one with the smallest total cost, can compute the
+ * overall cost, using all the points for clustering, or reserve a fraction of them, testProbability in a test set.
+ * The cost is the sum of the distances between each point and its corresponding centroid.
+ * We then use this set of points to compute the total cost on. We're therefore trying to select the clustering
+ * that best describes the underlying distribution of the clusters.
+ * This field is the probability of assigning a given point to the test set. If this is 0, the cost will be computed
+ * on the entire set of points.
+ */
+ private final double testProbability;
+
+ /**
+ * Whether or not testProbability > 0, i.e., there exists a non-empty 'test' set.
+ */
+ private final boolean splitTrainTest;
+
+ /**
+ * How many k-means runs to have. If there's more than one run, we compute the cost of each clustering as described
+ * above and select the clustering that minimizes the cost.
+ * Multiple runs are a lot more useful when using the random initialization. With kmeans++, 1-2 runs are enough and
+ * more runs are not likely to help quality much.
+ */
+ private final int numRuns;
+
+ /**
+ * Random object to sample values from.
+ */
+ private final Random random;
+
+ public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations) {
+ // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end,
+ // there will be 0 points in the test set and 1 run.
+ this(searcher, numClusters, maxNumIterations, 0.9, true, true, 0.0, 1);
+ }
+
+ public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations,
+ boolean kMeansPlusPlusInit, int numRuns) {
+ // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end,
+ // there will be 10% points of in the test set.
+ this(searcher, numClusters, maxNumIterations, 0.9, kMeansPlusPlusInit, true, 0.1, numRuns);
+ }
+
+ public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations,
+ double trimFraction, boolean kMeansPlusPlusInit, boolean correctWeights,
+ double testProbability, int numRuns) {
+ Preconditions.checkArgument(searcher.size() == 0, "Searcher must be empty initially to populate with centroids");
+ Preconditions.checkArgument(numClusters > 0, "The requested number of clusters must be positive");
+ Preconditions.checkArgument(maxNumIterations > 0, "The maximum number of iterations must be positive");
+ Preconditions.checkArgument(trimFraction > 0, "The trim fraction must be positive");
+ Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "The testProbability must be in [0, 1)");
+ Preconditions.checkArgument(numRuns > 0, "There has to be at least one run");
+
+ this.centroids = searcher;
+ this.numClusters = numClusters;
+ this.maxNumIterations = maxNumIterations;
+
+ this.trimFraction = trimFraction;
+ this.kMeansPlusPlusInit = kMeansPlusPlusInit;
+ this.correctWeights = correctWeights;
+
+ this.testProbability = testProbability;
+ this.splitTrainTest = testProbability > 0;
+ this.numRuns = numRuns;
+
+ this.random = RandomUtils.getRandom();
+ }
+
+ public Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> splitTrainTest(
+ List<? extends WeightedVector> datapoints) {
+ // If there will be no points assigned to the test set, return now.
+ if (testProbability == 0) {
+ return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(datapoints,
+ Lists.<WeightedVector>newArrayList());
+ }
+
+ int numTest = (int) (testProbability * datapoints.size());
+ Preconditions.checkArgument(numTest > 0 && numTest < datapoints.size(),
+ "Must have nonzero number of training and test vectors. Asked for %.1f %% of %d vectors for test",
+ testProbability * 100, datapoints.size());
+
+ Collections.shuffle(datapoints);
+ return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(
+ datapoints.subList(numTest, datapoints.size()), datapoints.subList(0, numTest));
+ }
+
+ /**
+ * Clusters the datapoints in the list doing either random seeding of the centroids or k-means++.
+ *
+ * @param datapoints the points to be clustered.
+ * @return an UpdatableSearcher with the resulting clusters.
+ */
+ public UpdatableSearcher cluster(List<? extends WeightedVector> datapoints) {
+ Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> trainTestSplit = splitTrainTest(datapoints);
+ List<Vector> bestCentroids = Lists.newArrayList();
+ double cost = Double.POSITIVE_INFINITY;
+ double bestCost = Double.POSITIVE_INFINITY;
+ for (int i = 0; i < numRuns; ++i) {
+ centroids.clear();
+ if (kMeansPlusPlusInit) {
+ // Use k-means++ to set initial centroids.
+ initializeSeedsKMeansPlusPlus(trainTestSplit.getFirst());
+ } else {
+ // Randomly select the initial centroids.
+ initializeSeedsRandomly(trainTestSplit.getFirst());
+ }
+ // Do k-means iterations with trimmed mean computation (aka ball k-means).
+ if (numRuns > 1) {
+ // If the clustering is successful (there are no zero-weight centroids).
+ iterativeAssignment(trainTestSplit.getFirst());
+ // Compute the cost of the clustering and possibly save the centroids.
+ cost = ClusteringUtils.totalClusterCost(
+ splitTrainTest ? datapoints : trainTestSplit.getSecond(), centroids);
+ if (cost < bestCost) {
+ bestCost = cost;
+ bestCentroids.clear();
+ Iterables.addAll(bestCentroids, centroids);
+ }
+ } else {
+ // If there is only going to be one run, the cost doesn't need to be computed, so we just return the clustering.
+ iterativeAssignment(datapoints);
+ return centroids;
+ }
+ }
+ if (bestCost == Double.POSITIVE_INFINITY) {
+ throw new RuntimeException("No valid clustering was found");
+ }
+ if (cost != bestCost) {
+ centroids.clear();
+ centroids.addAll(bestCentroids);
+ }
+ if (correctWeights) {
+ for (WeightedVector testDatapoint : trainTestSplit.getSecond()) {
+ WeightedVector closest = (WeightedVector) centroids.searchFirst(testDatapoint, false).getValue();
+ closest.setWeight(closest.getWeight() + testDatapoint.getWeight());
+ }
+ }
+ return centroids;
+ }
+
+ /**
+ * Selects some of the original points randomly with probability proportional to their weights. This is much
+ * less sophisticated than the kmeans++ approach, however it is faster and coupled with
+ *
+ * The side effect of this method is to fill the centroids structure itself.
+ *
+ * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind.
+ */
+ private void initializeSeedsRandomly(List<? extends WeightedVector> datapoints) {
+ int numDatapoints = datapoints.size();
+ double totalWeight = 0;
+ for (WeightedVector datapoint : datapoints) {
+ totalWeight += datapoint.getWeight();
+ }
+ Multinomial<Integer> seedSelector = new Multinomial<>();
+ for (int i = 0; i < numDatapoints; ++i) {
+ seedSelector.add(i, datapoints.get(i).getWeight() / totalWeight);
+ }
+ for (int i = 0; i < numClusters; ++i) {
+ int sample = seedSelector.sample();
+ seedSelector.delete(sample);
+ Centroid centroid = new Centroid(datapoints.get(sample));
+ centroid.setIndex(i);
+ centroids.add(centroid);
+ }
+ }
+
+ /**
+ * Selects some of the original points according to the k-means++ algorithm. The basic idea is that
+ * points are selected with probability proportional to their distance from any selected point. In
+ * this version, points have weights which multiply their likelihood of being selected. This is the
+ * same as if there were as many copies of the same point as indicated by the weight.
+ *
+ * This is pretty expensive, but it vastly improves the quality and convergences of the k-means algorithm.
+ * The basic idea can be made much faster by only processing a random subset of the original points.
+ * In the context of streaming k-means, the total number of possible seeds will be about k log n so this
+ * selection will cost O(k^2 (log n)^2) which isn't much worse than the random sampling idea. At
+ * n = 10^9, the cost of this initialization will be about 10x worse than a reasonable random sampling
+ * implementation.
+ *
+ * The side effect of this method is to fill the centroids structure itself.
+ *
+ * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind.
+ */
+ private void initializeSeedsKMeansPlusPlus(List<? extends WeightedVector> datapoints) {
+ Preconditions.checkArgument(datapoints.size() > 1, "Must have at least two datapoints points to cluster " +
+ "sensibly");
+ Preconditions.checkArgument(datapoints.size() >= numClusters,
+ String.format("Must have more datapoints [%d] than clusters [%d]", datapoints.size(), numClusters));
+ // Compute the centroid of all of the datapoints. This is then used to compute the squared radius of the datapoints.
+ Centroid center = new Centroid(datapoints.iterator().next());
+ for (WeightedVector row : Iterables.skip(datapoints, 1)) {
+ center.update(row);
+ }
+
+ // Given the centroid, we can compute \Delta_1^2(X), the total squared distance for the datapoints
+ // this accelerates seed selection.
+ double deltaX = 0;
+ DistanceMeasure distanceMeasure = centroids.getDistanceMeasure();
+ for (WeightedVector row : datapoints) {
+ deltaX += distanceMeasure.distance(row, center);
+ }
+
+ // Find the first seed c_1 (and conceptually the second, c_2) as might be done in the 2-means clustering so that
+ // the probability of selecting c_1 and c_2 is proportional to || c_1 - c_2 ||^2. This is done
+ // by first selecting c_1 with probability:
+ //
+ // p(c_1) = sum_{c_1} || c_1 - c_2 ||^2 \over sum_{c_1, c_2} || c_1 - c_2 ||^2
+ //
+ // This can be simplified to:
+ //
+ // p(c_1) = \Delta_1^2(X) + n || c_1 - c ||^2 / (2 n \Delta_1^2(X))
+ //
+ // where c = \sum x / n and \Delta_1^2(X) = sum || x - c ||^2
+ //
+ // All subsequent seeds c_i (including c_2) can then be selected from the remaining points with probability
+ // proportional to Pr(c_i == x_j) = min_{m < i} || c_m - x_j ||^2.
+
+ // Multinomial distribution of vector indices for the selection seeds. These correspond to
+ // the indices of the vectors in the original datapoints list.
+ Multinomial<Integer> seedSelector = new Multinomial<>();
+ for (int i = 0; i < datapoints.size(); ++i) {
+ double selectionProbability =
+ deltaX + datapoints.size() * distanceMeasure.distance(datapoints.get(i), center);
+ seedSelector.add(i, selectionProbability);
+ }
+
+ int selected = random.nextInt(datapoints.size());
+ Centroid c_1 = new Centroid(datapoints.get(selected).clone());
+ c_1.setIndex(0);
+ // Construct a set of weighted things which can be used for random selection. Initial weights are
+ // set to the squared distance from c_1
+ for (int i = 0; i < datapoints.size(); ++i) {
+ WeightedVector row = datapoints.get(i);
+ double w = distanceMeasure.distance(c_1, row) * 2 * Math.log(1 + row.getWeight());
+ seedSelector.set(i, w);
+ }
+
+ // From here, seeds are selected with probability proportional to:
+ //
+ // r_i = min_{c_j} || x_i - c_j ||^2
+ //
+ // when we only have c_1, we have already set these distances and as we select each new
+ // seed, we update the minimum distances.
+ centroids.add(c_1);
+ int clusterIndex = 1;
+ while (centroids.size() < numClusters) {
+ // Select according to weights.
+ int seedIndex = seedSelector.sample();
+ Centroid nextSeed = new Centroid(datapoints.get(seedIndex));
+ nextSeed.setIndex(clusterIndex++);
+ centroids.add(nextSeed);
+ // Don't select this one again.
+ seedSelector.delete(seedIndex);
+ // Re-weight everything according to the minimum distance to a seed.
+ for (int currSeedIndex : seedSelector) {
+ WeightedVector curr = datapoints.get(currSeedIndex);
+ double newWeight = nextSeed.getWeight() * distanceMeasure.distance(nextSeed, curr);
+ if (newWeight < seedSelector.getWeight(currSeedIndex)) {
+ seedSelector.set(currSeedIndex, newWeight);
+ }
+ }
+ }
+ }
+
+ /**
+ * Examines the datapoints and updates cluster centers to be the centroid of the nearest datapoints points. To
+ * compute a new center for cluster c_i, we average all points that are closer than d_i * trimFraction
+ * where d_i is
+ *
+ * d_i = min_j \sqrt ||c_j - c_i||^2
+ *
+ * By ignoring distant points, the centroids converge more quickly to a good approximation of the
+ * optimal k-means solution (given good starting points).
+ *
+ * @param datapoints the points to cluster.
+ */
+ private void iterativeAssignment(List<? extends WeightedVector> datapoints) {
+ DistanceMeasure distanceMeasure = centroids.getDistanceMeasure();
+ // closestClusterDistances.get(i) is the distance from the i'th cluster to its closest
+ // neighboring cluster.
+ List<Double> closestClusterDistances = Lists.newArrayListWithExpectedSize(numClusters);
+ // clusterAssignments[i] == j means that the i'th point is assigned to the j'th cluster. When
+ // these don't change, we are done.
+ // Each point is assigned to the invalid "-1" cluster initially.
+ List<Integer> clusterAssignments = Lists.newArrayList(Collections.nCopies(datapoints.size(), -1));
+
+ boolean changed = true;
+ for (int i = 0; changed && i < maxNumIterations; i++) {
+ changed = false;
+ // We compute what the distance between each cluster and its closest neighbor is to set a
+ // proportional distance threshold for points that should be involved in calculating the
+ // centroid.
+ closestClusterDistances.clear();
+ for (Vector center : centroids) {
+ // If a centroid has no points assigned to it, the clustering failed.
+ Vector closestOtherCluster = centroids.searchFirst(center, true).getValue();
+ closestClusterDistances.add(distanceMeasure.distance(center, closestOtherCluster));
+ }
+
+ // Copies the current cluster centroids to newClusters and sets their weights to 0. This is
+ // so we calculate the new centroids as we go through the datapoints.
+ List<Centroid> newCentroids = Lists.newArrayList();
+ for (Vector centroid : centroids) {
+ // need a deep copy because we will mutate these values
+ Centroid newCentroid = (Centroid)centroid.clone();
+ newCentroid.setWeight(0);
+ newCentroids.add(newCentroid);
+ }
+
+ // Pass over the datapoints computing new centroids.
+ for (int j = 0; j < datapoints.size(); ++j) {
+ WeightedVector datapoint = datapoints.get(j);
+ // Get the closest cluster this point belongs to.
+ WeightedThing<Vector> closestPair = centroids.searchFirst(datapoint, false);
+ int closestIndex = ((WeightedVector) closestPair.getValue()).getIndex();
+ double closestDistance = closestPair.getWeight();
+ // Update its cluster assignment if necessary.
+ if (closestIndex != clusterAssignments.get(j)) {
+ changed = true;
+ clusterAssignments.set(j, closestIndex);
+ }
+ // Only update if the datapoints point is near enough. What this means is that the weight
+ // of outliers is NOT taken into account and the final weights of the centroids will
+ // reflect this (it will be less or equal to the initial sum of the weights).
+ if (closestDistance < trimFraction * closestClusterDistances.get(closestIndex)) {
+ newCentroids.get(closestIndex).update(datapoint);
+ }
+ }
+ // Add the new centers back into searcher.
+ centroids.clear();
+ centroids.addAll(newCentroids);
+ }
+
+ if (correctWeights) {
+ for (Vector v : centroids) {
+ ((Centroid)v).setWeight(0);
+ }
+ for (WeightedVector datapoint : datapoints) {
+ Centroid closestCentroid = (Centroid) centroids.searchFirst(datapoint, false).getValue();
+ closestCentroid.setWeight(closestCentroid.getWeight() + datapoint.getWeight());
+ }
+ }
+ }
+
+ @Override
+ public Iterator<Centroid> iterator() {
+ return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
+ @Override
+ public Centroid apply(Vector input) {
+ Preconditions.checkArgument(input instanceof Centroid, "Non-centroid in centroids " +
+ "searcher");
+ //noinspection ConstantConditions
+ return (Centroid)input;
+ }
+ });
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
new file mode 100644
index 0000000..0e3f068
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.cluster;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.jet.math.Constants;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.WeightedThing;
+
+/**
+ * Implements a streaming k-means algorithm for weighted vectors.
+ * The goal clustering points one at a time, especially useful for MapReduce mappers that get inputs one at a time.
+ *
+ * A rough description of the algorithm:
+ * Suppose there are l clusters at one point and a new point p is added.
+ * The new point can either be added to one of the existing l clusters or become a new cluster. To decide:
+ * - let c be the closest cluster to point p;
+ * - let d be the distance between c and p;
+ * - if d > distanceCutoff, create a new cluster from p (p is too far away from the clusters to be part of them;
+ * distanceCutoff represents the largest distance from a point its assigned cluster's centroid);
+ * - else (d <= distanceCutoff), create a new cluster with probability d / distanceCutoff (the probability of creating
+ * a new cluster increases as d increases).
+ * There will be either l points or l + 1 points after processing a new point.
+ *
+ * As the number of clusters increases, it will go over the numClusters limit (numClusters represents a recommendation
+ * for the number of clusters that there should be at the end). To decrease the number of clusters the existing clusters
+ * are treated as data points and are re-clustered (collapsed). This tends to make the number of clusters go down.
+ * If the number of clusters is still too high, distanceCutoff is increased.
+ *
+ * For more details, see:
+ * - "Streaming k-means approximation" by N. Ailon, R. Jaiswal, C. Monteleoni
+ * http://books.nips.cc/papers/files/nips22/NIPS2009_1085.pdf
+ * - "Fast and Accurate k-means for Large Datasets" by M. Shindler, A. Wong, A. Meyerson,
+ * http://books.nips.cc/papers/files/nips24/NIPS2011_1271.pdf
+ */
+public class StreamingKMeans implements Iterable<Centroid> {
+ /**
+ * The searcher containing the centroids that resulted from the clustering of points until now. When adding a new
+ * point we either assign it to one of the existing clusters in this searcher or create a new centroid for it.
+ */
+ private final UpdatableSearcher centroids;
+
+ /**
+ * The estimated number of clusters to cluster the data in. If the actual number of clusters increases beyond this
+ * limit, the clusters will be "collapsed" (re-clustered, by treating them as data points). This doesn't happen
+ * recursively and a collapse might not necessarily make the number of actual clusters drop to less than this limit.
+ *
+ * If the goal is clustering a large data set into k clusters, numClusters SHOULD NOT BE SET to k. StreamingKMeans is
+ * useful to reduce the size of the data set by the mappers so that it can fit into memory in one reducer that runs
+ * BallKMeans.
+ *
+ * It is NOT MEANT to cluster the data into k clusters in one pass because it can't guarantee that there will in fact
+ * be k clusters in total. This is because of the dynamic nature of numClusters over the course of the runtime.
+ * To get an exact number of clusters, another clustering algorithm needs to be applied to the results.
+ */
+ private int numClusters;
+
+ /**
+ * The number of data points seen so far. This is important for re-estimating numClusters when deciding to collapse
+ * the existing clusters.
+ */
+ private int numProcessedDatapoints = 0;
+
+ /**
+ * This is the current value of the distance cutoff. Points which are much closer than this to a centroid will stick
+ * to it almost certainly. Points further than this to any centroid will form a new cluster.
+ *
+ * This increases (is multiplied by beta) when a cluster collapse did not make the number of clusters drop to below
+ * numClusters (it effectively increases the tolerance for cluster compactness discouraging the creation of new
+ * clusters). Since a collapse only happens when centroids.size() > clusterOvershoot * numClusters, the cutoff
+ * increases when the collapse didn't at least remove the slack in the number of clusters.
+ */
+ private double distanceCutoff;
+
+ /**
+ * Parameter that controls the growth of the distanceCutoff. After n increases of the
+ * distanceCutoff starting at d_0, the final value is d_0 * beta^n (distance cutoffs increase following a geometric
+ * progression with ratio beta).
+ */
+ private final double beta;
+
+ /**
+ * Multiplying clusterLogFactor with numProcessedDatapoints gets an estimate of the suggested
+ * number of clusters. This mirrors the recommended number of clusters for n points where there should be k actual
+ * clusters, k * log n. In the case of our estimate we use clusterLogFactor * log(numProcessedDataPoints).
+ *
+ * It is important to note that numClusters is NOT k. It is an estimate of k * log n.
+ */
+ private final double clusterLogFactor;
+
+ /**
+ * Centroids are collapsed when the number of clusters becomes greater than clusterOvershoot * numClusters. This
+ * effectively means having a slack in numClusters so that the actual number of centroids, centroids.size() tracks
+ * numClusters approximately. The idea is that the actual number of clusters should be at least numClusters but not
+ * much more (so that we don't end up having 1 cluster / point).
+ */
+ private final double clusterOvershoot;
+
+ /**
+ * Random object to sample values from.
+ */
+ private final Random random = RandomUtils.getRandom();
+
+ /**
+ * Calls StreamingKMeans(searcher, numClusters, 1.3, 10, 2).
+ * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
+ * double, double, double, double)
+ */
+ public StreamingKMeans(UpdatableSearcher searcher, int numClusters) {
+ this(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2);
+ }
+
+ /**
+ * Calls StreamingKMeans(searcher, numClusters, distanceCutoff, 1.3, 10, 2).
+ * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
+ * double, double, double, double)
+ */
+ public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff) {
+ this(searcher, numClusters, distanceCutoff, 1.3, 20, 2);
+ }
+
+ /**
+ * Creates a new StreamingKMeans class given a searcher and the number of clusters to generate.
+ *
+ * @param searcher A Searcher that is used for performing nearest neighbor search. It MUST BE
+ * EMPTY initially because it will be used to keep track of the cluster
+ * centroids.
+ * @param numClusters An estimated number of clusters to generate for the data points.
+ * This can adjusted, but the actual number will depend on the data. The
+ * @param distanceCutoff The initial distance cutoff representing the value of the
+ * distance between a point and its closest centroid after which
+ * the new point will definitely be assigned to a new cluster.
+ * @param beta Ratio of geometric progression to use when increasing distanceCutoff. After n increases, distanceCutoff
+ * becomes distanceCutoff * beta^n. A smaller value increases the distanceCutoff less aggressively.
+ * @param clusterLogFactor Value multiplied with the number of points counted so far estimating the number of clusters
+ * to aim for. If the final number of clusters is known and this clustering is only for a
+ * sketch of the data, this can be the final number of clusters, k.
+ * @param clusterOvershoot Multiplicative slack factor for slowing down the collapse of the clusters.
+ */
+ public StreamingKMeans(UpdatableSearcher searcher, int numClusters,
+ double distanceCutoff, double beta, double clusterLogFactor, double clusterOvershoot) {
+ this.centroids = searcher;
+ this.numClusters = numClusters;
+ this.distanceCutoff = distanceCutoff;
+ this.beta = beta;
+ this.clusterLogFactor = clusterLogFactor;
+ this.clusterOvershoot = clusterOvershoot;
+ }
+
+ /**
+ * @return an Iterator to the Centroids contained in this clusterer.
+ */
+ @Override
+ public Iterator<Centroid> iterator() {
+ return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
+ @Override
+ public Centroid apply(Vector input) {
+ return (Centroid)input;
+ }
+ });
+ }
+
+ /**
+ * Cluster the rows of a matrix, treating them as Centroids with weight 1.
+ * @param data matrix whose rows are to be clustered.
+ * @return the UpdatableSearcher containing the resulting centroids.
+ */
+ public UpdatableSearcher cluster(Matrix data) {
+ return cluster(Iterables.transform(data, new Function<MatrixSlice, Centroid>() {
+ @Override
+ public Centroid apply(MatrixSlice input) {
+ // The key in a Centroid is actually the MatrixSlice's index.
+ return Centroid.create(input.index(), input.vector());
+ }
+ }));
+ }
+
+ /**
+ * Cluster the data points in an Iterable<Centroid>.
+ * @param datapoints Iterable whose elements are to be clustered.
+ * @return the UpdatableSearcher containing the resulting centroids.
+ */
+ public UpdatableSearcher cluster(Iterable<Centroid> datapoints) {
+ return clusterInternal(datapoints, false);
+ }
+
+ /**
+ * Cluster one data point.
+ * @param datapoint to be clustered.
+ * @return the UpdatableSearcher containing the resulting centroids.
+ */
+ public UpdatableSearcher cluster(final Centroid datapoint) {
+ return cluster(new Iterable<Centroid>() {
+ @Override
+ public Iterator<Centroid> iterator() {
+ return new Iterator<Centroid>() {
+ private boolean accessed = false;
+
+ @Override
+ public boolean hasNext() {
+ return !accessed;
+ }
+
+ @Override
+ public Centroid next() {
+ accessed = true;
+ return datapoint;
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+ };
+ }
+ });
+ }
+
+ /**
+ * @return the number of clusters computed from the points until now.
+ */
+ public int getNumClusters() {
+ return centroids.size();
+ }
+
+ /**
+ * Internal clustering method that gets called from the other wrappers.
+ * @param datapoints Iterable of data points to be clustered.
+ * @param collapseClusters whether this is an "inner" clustering and the datapoints are the previously computed
+ * centroids. Some logic is different to ensure counters are consistent but it behaves
+ * nearly the same.
+ * @return the UpdatableSearcher containing the resulting centroids.
+ */
+ private UpdatableSearcher clusterInternal(Iterable<Centroid> datapoints, boolean collapseClusters) {
+ Iterator<Centroid> datapointsIterator = datapoints.iterator();
+ if (!datapointsIterator.hasNext()) {
+ return centroids;
+ }
+
+ int oldNumProcessedDataPoints = numProcessedDatapoints;
+ // We clear the centroids we have in case of cluster collapse, the old clusters are the
+ // datapoints but we need to re-cluster them.
+ if (collapseClusters) {
+ centroids.clear();
+ numProcessedDatapoints = 0;
+ }
+
+ if (centroids.size() == 0) {
+ // Assign the first datapoint to the first cluster.
+ // Adding a vector to a searcher would normally just reference the copy,
+ // but we could potentially mutate it and so we need to make a clone.
+ centroids.add(datapointsIterator.next().clone());
+ ++numProcessedDatapoints;
+ }
+
+ // To cluster, we scan the data and either add each point to the nearest group or create a new group.
+ // when we get too many groups, we need to increase the threshold and rescan our current groups
+ while (datapointsIterator.hasNext()) {
+ Centroid row = datapointsIterator.next();
+ // Get the closest vector and its weight as a WeightedThing<Vector>.
+ // The weight of the WeightedThing is the distance to the query and the value is a
+ // reference to one of the vectors we added to the searcher previously.
+ WeightedThing<Vector> closestPair = centroids.searchFirst(row, false);
+
+ // We get a uniformly distributed random number between 0 and 1 and compare it with the
+ // distance to the closest cluster divided by the distanceCutoff.
+ // This is so that if the closest cluster is further than distanceCutoff,
+ // closestPair.getWeight() / distanceCutoff > 1 which will trigger the creation of a new
+ // cluster anyway.
+ // However, if the ratio is less than 1, we want to create a new cluster with probability
+ // proportional to the distance to the closest cluster.
+ double sample = random.nextDouble();
+ if (sample < row.getWeight() * closestPair.getWeight() / distanceCutoff) {
+ // Add new centroid, note that the vector is copied because we may mutate it later.
+ centroids.add(row.clone());
+ } else {
+ // Merge the new point with the existing centroid. This will update the centroid's actual
+ // position.
+ // We know that all the points we inserted in the centroids searcher are (or extend)
+ // WeightedVector, so the cast will always succeed.
+ Centroid centroid = (Centroid) closestPair.getValue();
+
+ // We will update the centroid by removing it from the searcher and reinserting it to
+ // ensure consistency.
+ if (!centroids.remove(centroid, Constants.EPSILON)) {
+ throw new RuntimeException("Unable to remove centroid");
+ }
+ centroid.update(row);
+ centroids.add(centroid);
+
+ }
+ ++numProcessedDatapoints;
+
+ if (!collapseClusters && centroids.size() > clusterOvershoot * numClusters) {
+ numClusters = (int) Math.max(numClusters, clusterLogFactor * Math.log(numProcessedDatapoints));
+
+ List<Centroid> shuffled = Lists.newArrayList();
+ for (Vector vector : centroids) {
+ shuffled.add((Centroid) vector);
+ }
+ Collections.shuffle(shuffled);
+ // Re-cluster using the shuffled centroids as data points. The centroids member variable
+ // is modified directly.
+ clusterInternal(shuffled, true);
+
+ if (centroids.size() > numClusters) {
+ distanceCutoff *= beta;
+ }
+ }
+ }
+
+ if (collapseClusters) {
+ numProcessedDatapoints = oldNumProcessedDataPoints;
+ }
+ return centroids;
+ }
+
+ public void reindexCentroids() {
+ int numCentroids = 0;
+ for (Centroid centroid : this) {
+ centroid.setIndex(numCentroids++);
+ }
+ }
+
+ /**
+ * @return the distanceCutoff (an upper bound for the maximum distance within a cluster).
+ */
+ public double getDistanceCutoff() {
+ return distanceCutoff;
+ }
+
+ public void setDistanceCutoff(double distanceCutoff) {
+ this.distanceCutoff = distanceCutoff;
+ }
+
+ public DistanceMeasure getDistanceMeasure() {
+ return centroids.getDistanceMeasure();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
new file mode 100644
index 0000000..a41940b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+public class CentroidWritable implements Writable {
+ private Centroid centroid = null;
+
+ public CentroidWritable() {}
+
+ public CentroidWritable(Centroid centroid) {
+ this.centroid = centroid;
+ }
+
+ public Centroid getCentroid() {
+ return centroid;
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) throws IOException {
+ dataOutput.writeInt(centroid.getIndex());
+ dataOutput.writeDouble(centroid.getWeight());
+ VectorWritable.writeVector(dataOutput, centroid.getVector());
+ }
+
+ @Override
+ public void readFields(DataInput dataInput) throws IOException {
+ if (centroid == null) {
+ centroid = read(dataInput);
+ return;
+ }
+ centroid.setIndex(dataInput.readInt());
+ centroid.setWeight(dataInput.readDouble());
+ centroid.assign(VectorWritable.readVector(dataInput));
+ }
+
+ public static Centroid read(DataInput dataInput) throws IOException {
+ int index = dataInput.readInt();
+ double weight = dataInput.readDouble();
+ Vector v = VectorWritable.readVector(dataInput);
+ return new Centroid(index, v, weight);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof CentroidWritable)) {
+ return false;
+ }
+ CentroidWritable writable = (CentroidWritable) o;
+ return centroid.equals(writable.centroid);
+ }
+
+ @Override
+ public int hashCode() {
+ return centroid.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return centroid.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
new file mode 100644
index 0000000..73776b9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java
@@ -0,0 +1,493 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Classifies the vectors into different clusters found by the clustering
+ * algorithm.
+ */
+public final class StreamingKMeansDriver extends AbstractJob {
+ /**
+ * Streaming KMeans options
+ */
+ /**
+ * The number of cluster that Mappers will use should be \(O(k log n)\) where k is the number of clusters
+ * to get at the end and n is the number of points to cluster. This doesn't need to be exact.
+ * It will be adjusted at runtime.
+ */
+ public static final String ESTIMATED_NUM_MAP_CLUSTERS = "estimatedNumMapClusters";
+ /**
+ * The initial estimated distance cutoff between two points for forming new clusters.
+ * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans
+ * Defaults to 10e-6.
+ */
+ public static final String ESTIMATED_DISTANCE_CUTOFF = "estimatedDistanceCutoff";
+
+ /**
+ * Ball KMeans options
+ */
+ /**
+ * After mapping finishes, we get an intermediate set of vectors that represent approximate
+ * clusterings of the data from each Mapper. These can be clustered by the Reducer using
+ * BallKMeans in memory. This variable is the maximum number of iterations in the final
+ * BallKMeans algorithm.
+ * Defaults to 10.
+ */
+ public static final String MAX_NUM_ITERATIONS = "maxNumIterations";
+ /**
+ * The "ball" aspect of ball k-means means that only the closest points to the centroid will actually be used
+ * for updating. The fraction of the points to be used is those points whose distance to the center is within
+ * trimFraction * distance to the closest other center.
+ * Defaults to 0.9.
+ */
+ public static final String TRIM_FRACTION = "trimFraction";
+ /**
+ * Whether to use k-means++ initialization or random initialization of the seed centroids.
+ * Essentially, k-means++ provides better clusters, but takes longer, whereas random initialization takes less
+ * time, but produces worse clusters, and tends to fail more often and needs multiple runs to compare to
+ * k-means++. If set, uses randomInit.
+ * @see org.apache.mahout.clustering.streaming.cluster.BallKMeans
+ */
+ public static final String RANDOM_INIT = "randomInit";
+ /**
+ * Whether to correct the weights of the centroids after the clustering is done. The weights end up being wrong
+ * because of the trimFraction and possible train/test splits. In some cases, especially in a pipeline, having
+ * an accurate count of the weights is useful. If set, ignores the final weights.
+ */
+ public static final String IGNORE_WEIGHTS = "ignoreWeights";
+ /**
+ * The percentage of points that go into the "test" set when evaluating BallKMeans runs in the reducer.
+ */
+ public static final String TEST_PROBABILITY = "testProbability";
+ /**
+ * The percentage of points that go into the "training" set when evaluating BallKMeans runs in the reducer.
+ */
+ public static final String NUM_BALLKMEANS_RUNS = "numBallKMeansRuns";
+
+ /**
+ Searcher options
+ */
+ /**
+ * The Searcher class when performing nearest neighbor search in StreamingKMeans.
+ * Defaults to ProjectionSearch.
+ */
+ public static final String SEARCHER_CLASS_OPTION = "searcherClass";
+ /**
+ * The number of projections to use when using a projection searcher like ProjectionSearch or
+ * FastProjectionSearch. Projection searches work by projection the all the vectors on to a set of
+ * basis vectors and searching for the projected query in that totally ordered set. This
+ * however can produce false positives (vectors that are closer when projected than they would
+ * actually be.
+ * So, there must be more than one projection vectors in the basis. This variable is the number
+ * of vectors in a basis.
+ * Defaults to 3
+ */
+ public static final String NUM_PROJECTIONS_OPTION = "numProjections";
+ /**
+ * When using approximate searches (anything that's not BruteSearch),
+ * more than just the seemingly closest element must be considered. This variable has different
+ * meanings depending on the actual Searcher class used but is a measure of how many candidates
+ * will be considered.
+ * See the ProjectionSearch, FastProjectionSearch, LocalitySensitiveHashSearch classes for more
+ * details.
+ * Defaults to 2.
+ */
+ public static final String SEARCH_SIZE_OPTION = "searchSize";
+
+ /**
+ * Whether to run another pass of StreamingKMeans on the reducer's points before BallKMeans. On some data sets
+ * with a large number of mappers, the intermediate number of clusters passed to the reducer is too large to
+ * fit into memory directly, hence the option to collapse the clusters further with StreamingKMeans.
+ */
+ public static final String REDUCE_STREAMING_KMEANS = "reduceStreamingKMeans";
+
+ private static final Logger log = LoggerFactory.getLogger(StreamingKMeansDriver.class);
+
+ public static final float INVALID_DISTANCE_CUTOFF = -1;
+
+ @Override
+ public int run(String[] args) throws Exception {
+ // Standard options for any Mahout job.
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.overwriteOption().create());
+
+ // The number of clusters to create for the data.
+ addOption(DefaultOptionCreator.numClustersOption().withDescription(
+ "The k in k-Means. Approximately this many clusters will be generated.").create());
+
+ // StreamingKMeans (mapper) options
+ // There will be k final clusters, but in the Map phase to get a good approximation of the data, O(k log n)
+ // clusters are needed. Since n is the number of data points and not knowable until reading all the vectors,
+ // provide a decent estimate.
+ addOption(ESTIMATED_NUM_MAP_CLUSTERS, "km", "The estimated number of clusters to use for the "
+ + "Map phase of the job when running StreamingKMeans. This should be around k * log(n), "
+ + "where k is the final number of clusters and n is the total number of data points to "
+ + "cluster.", String.valueOf(1));
+
+ addOption(ESTIMATED_DISTANCE_CUTOFF, "e", "The initial estimated distance cutoff between two "
+ + "points for forming new clusters. If no value is given, it's estimated from the data set",
+ String.valueOf(INVALID_DISTANCE_CUTOFF));
+
+ // BallKMeans (reducer) options
+ addOption(MAX_NUM_ITERATIONS, "mi", "The maximum number of iterations to run for the "
+ + "BallKMeans algorithm used by the reducer. If no value is given, defaults to 10.", String.valueOf(10));
+
+ addOption(TRIM_FRACTION, "tf", "The 'ball' aspect of ball k-means means that only the closest points "
+ + "to the centroid will actually be used for updating. The fraction of the points to be used is those "
+ + "points whose distance to the center is within trimFraction * distance to the closest other center. "
+ + "If no value is given, defaults to 0.9.", String.valueOf(0.9));
+
+ addFlag(RANDOM_INIT, "ri", "Whether to use k-means++ initialization or random initialization "
+ + "of the seed centroids. Essentially, k-means++ provides better clusters, but takes longer, whereas random "
+ + "initialization takes less time, but produces worse clusters, and tends to fail more often and needs "
+ + "multiple runs to compare to k-means++. If set, uses the random initialization.");
+
+ addFlag(IGNORE_WEIGHTS, "iw", "Whether to correct the weights of the centroids after the clustering is done. "
+ + "The weights end up being wrong because of the trimFraction and possible train/test splits. In some cases, "
+ + "especially in a pipeline, having an accurate count of the weights is useful. If set, ignores the final "
+ + "weights");
+
+ addOption(TEST_PROBABILITY, "testp", "A double value between 0 and 1 that represents the percentage of "
+ + "points to be used for 'testing' different clustering runs in the final BallKMeans "
+ + "step. If no value is given, defaults to 0.1", String.valueOf(0.1));
+
+ addOption(NUM_BALLKMEANS_RUNS, "nbkm", "Number of BallKMeans runs to use at the end to try to cluster the "
+ + "points. If no value is given, defaults to 4", String.valueOf(4));
+
+ // Nearest neighbor search options
+ // The distance measure used for computing the distance between two points. Generally, the
+ // SquaredEuclideanDistance is used for clustering problems (it's equivalent to CosineDistance for normalized
+ // vectors).
+ // WARNING! You can use any metric but most of the literature is for the squared euclidean distance.
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+
+ // The default searcher should be something more efficient that BruteSearch (ProjectionSearch, ...). See
+ // o.a.m.math.neighborhood.*
+ addOption(SEARCHER_CLASS_OPTION, "sc", "The type of searcher to be used when performing nearest "
+ + "neighbor searches. Defaults to ProjectionSearch.", ProjectionSearch.class.getCanonicalName());
+
+ // In the original paper, the authors used 1 projection vector.
+ addOption(NUM_PROJECTIONS_OPTION, "np", "The number of projections considered in estimating the "
+ + "distances between vectors. Only used when the distance measure requested is either "
+ + "ProjectionSearch or FastProjectionSearch. If no value is given, defaults to 3.", String.valueOf(3));
+
+ addOption(SEARCH_SIZE_OPTION, "s", "In more efficient searches (non BruteSearch), "
+ + "not all distances are calculated for determining the nearest neighbors. The number of "
+ + "elements whose distances from the query vector is actually computer is proportional to "
+ + "searchSize. If no value is given, defaults to 1.", String.valueOf(2));
+
+ addFlag(REDUCE_STREAMING_KMEANS, "rskm", "There might be too many intermediate clusters from the mapper "
+ + "to fit into memory, so the reducer can run another pass of StreamingKMeans to collapse them down to a "
+ + "fewer clusters");
+
+ addOption(DefaultOptionCreator.methodOption().create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+ Path output = getOutputPath();
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ configureOptionsForWorkers();
+ run(getConf(), getInputPath(), output);
+ return 0;
+ }
+
+ private void configureOptionsForWorkers() throws ClassNotFoundException {
+ log.info("Starting to configure options for workers");
+
+ String method = getOption(DefaultOptionCreator.METHOD_OPTION);
+
+ int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION));
+
+ // StreamingKMeans
+ int estimatedNumMapClusters = Integer.parseInt(getOption(ESTIMATED_NUM_MAP_CLUSTERS));
+ float estimatedDistanceCutoff = Float.parseFloat(getOption(ESTIMATED_DISTANCE_CUTOFF));
+
+ // BallKMeans
+ int maxNumIterations = Integer.parseInt(getOption(MAX_NUM_ITERATIONS));
+ float trimFraction = Float.parseFloat(getOption(TRIM_FRACTION));
+ boolean randomInit = hasOption(RANDOM_INIT);
+ boolean ignoreWeights = hasOption(IGNORE_WEIGHTS);
+ float testProbability = Float.parseFloat(getOption(TEST_PROBABILITY));
+ int numBallKMeansRuns = Integer.parseInt(getOption(NUM_BALLKMEANS_RUNS));
+
+ // Nearest neighbor search
+ String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ String searcherClass = getOption(SEARCHER_CLASS_OPTION);
+
+ // Get more parameters depending on the kind of search class we're working with. BruteSearch
+ // doesn't need anything else.
+ // LocalitySensitiveHashSearch and ProjectionSearches need searchSize.
+ // ProjectionSearches also need the number of projections.
+ boolean getSearchSize = false;
+ boolean getNumProjections = false;
+ if (!searcherClass.equals(BruteSearch.class.getName())) {
+ getSearchSize = true;
+ getNumProjections = true;
+ }
+
+ // The search size to use. This is quite fuzzy and might end up not being configurable at all.
+ int searchSize = 0;
+ if (getSearchSize) {
+ searchSize = Integer.parseInt(getOption(SEARCH_SIZE_OPTION));
+ }
+
+ // The number of projections to use. This is only useful in projection searches which
+ // project the vectors on multiple basis vectors to get distance estimates that are faster to
+ // calculate.
+ int numProjections = 0;
+ if (getNumProjections) {
+ numProjections = Integer.parseInt(getOption(NUM_PROJECTIONS_OPTION));
+ }
+
+ boolean reduceStreamingKMeans = hasOption(REDUCE_STREAMING_KMEANS);
+
+ configureOptionsForWorkers(getConf(), numClusters,
+ /* StreamingKMeans */
+ estimatedNumMapClusters, estimatedDistanceCutoff,
+ /* BallKMeans */
+ maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns,
+ /* Searcher */
+ measureClass, searcherClass, searchSize, numProjections,
+ method,
+ reduceStreamingKMeans);
+ }
+
+ /**
+ * Checks the parameters for a StreamingKMeans job and prepares a Configuration with them.
+ *
+ * @param conf the Configuration to populate
+ * @param numClusters k, the number of clusters at the end
+ * @param estimatedNumMapClusters O(k log n), the number of clusters requested from each mapper
+ * @param estimatedDistanceCutoff an estimate of the minimum distance that separates two clusters (can be smaller and
+ * will be increased dynamically)
+ * @param maxNumIterations the maximum number of iterations of BallKMeans
+ * @param trimFraction the fraction of the points to be considered in updating a ball k-means
+ * @param randomInit whether to initialize the ball k-means seeds randomly
+ * @param ignoreWeights whether to ignore the invalid final ball k-means weights
+ * @param testProbability the percentage of vectors assigned to the test set for selecting the best final centers
+ * @param numBallKMeansRuns the number of BallKMeans runs in the reducer that determine the centroids to return
+ * (clusters are computed for the training set and the error is computed on the test set)
+ * @param measureClass string, name of the distance measure class; theory works for Euclidean-like distances
+ * @param searcherClass string, name of the searcher that will be used for nearest neighbor search
+ * @param searchSize the number of closest neighbors to look at for selecting the closest one in approximate nearest
+ * neighbor searches
+ * @param numProjections the number of projected vectors to use for faster searching (only useful for ProjectionSearch
+ * or FastProjectionSearch); @see org.apache.mahout.math.neighborhood.ProjectionSearch
+ */
+ public static void configureOptionsForWorkers(Configuration conf,
+ int numClusters,
+ /* StreamingKMeans */
+ int estimatedNumMapClusters, float estimatedDistanceCutoff,
+ /* BallKMeans */
+ int maxNumIterations, float trimFraction, boolean randomInit,
+ boolean ignoreWeights, float testProbability, int numBallKMeansRuns,
+ /* Searcher */
+ String measureClass, String searcherClass,
+ int searchSize, int numProjections,
+ String method,
+ boolean reduceStreamingKMeans) throws ClassNotFoundException {
+ // Checking preconditions for the parameters.
+ Preconditions.checkArgument(numClusters > 0,
+ "Invalid number of clusters requested: " + numClusters + ". Must be: numClusters > 0!");
+
+ // StreamingKMeans
+ Preconditions.checkArgument(estimatedNumMapClusters > numClusters, "Invalid number of estimated map "
+ + "clusters; There must be more than the final number of clusters (k log n vs k)");
+ Preconditions.checkArgument(estimatedDistanceCutoff == INVALID_DISTANCE_CUTOFF || estimatedDistanceCutoff > 0,
+ "estimatedDistanceCutoff must be equal to -1 or must be greater then 0!");
+
+ // BallKMeans
+ Preconditions.checkArgument(maxNumIterations > 0, "Must have at least one BallKMeans iteration");
+ Preconditions.checkArgument(trimFraction > 0, "trimFraction must be positive");
+ Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "test probability is not in the "
+ + "interval [0, 1)");
+ Preconditions.checkArgument(numBallKMeansRuns > 0, "numBallKMeans cannot be negative");
+
+ // Searcher
+ if (!searcherClass.contains("Brute")) {
+ // These tests only make sense when a relevant searcher is being used.
+ Preconditions.checkArgument(searchSize > 0, "Invalid searchSize. Must be positive.");
+ if (searcherClass.contains("Projection")) {
+ Preconditions.checkArgument(numProjections > 0, "Invalid numProjections. Must be positive");
+ }
+ }
+
+ // Setting the parameters in the Configuration.
+ conf.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, numClusters);
+ /* StreamingKMeans */
+ conf.setInt(ESTIMATED_NUM_MAP_CLUSTERS, estimatedNumMapClusters);
+ if (estimatedDistanceCutoff != INVALID_DISTANCE_CUTOFF) {
+ conf.setFloat(ESTIMATED_DISTANCE_CUTOFF, estimatedDistanceCutoff);
+ }
+ /* BallKMeans */
+ conf.setInt(MAX_NUM_ITERATIONS, maxNumIterations);
+ conf.setFloat(TRIM_FRACTION, trimFraction);
+ conf.setBoolean(RANDOM_INIT, randomInit);
+ conf.setBoolean(IGNORE_WEIGHTS, ignoreWeights);
+ conf.setFloat(TEST_PROBABILITY, testProbability);
+ conf.setInt(NUM_BALLKMEANS_RUNS, numBallKMeansRuns);
+ /* Searcher */
+ // Checks if the measureClass is available, throws exception otherwise.
+ Class.forName(measureClass);
+ conf.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, measureClass);
+ // Checks if the searcherClass is available, throws exception otherwise.
+ Class.forName(searcherClass);
+ conf.set(SEARCHER_CLASS_OPTION, searcherClass);
+ conf.setInt(SEARCH_SIZE_OPTION, searchSize);
+ conf.setInt(NUM_PROJECTIONS_OPTION, numProjections);
+ conf.set(DefaultOptionCreator.METHOD_OPTION, method);
+
+ conf.setBoolean(REDUCE_STREAMING_KMEANS, reduceStreamingKMeans);
+
+ log.info("Parameters are: [k] numClusters {}; "
+ + "[SKM] estimatedNumMapClusters {}; estimatedDistanceCutoff {} "
+ + "[BKM] maxNumIterations {}; trimFraction {}; randomInit {}; ignoreWeights {}; "
+ + "testProbability {}; numBallKMeansRuns {}; "
+ + "[S] measureClass {}; searcherClass {}; searcherSize {}; numProjections {}; "
+ + "method {}; reduceStreamingKMeans {}", numClusters, estimatedNumMapClusters, estimatedDistanceCutoff,
+ maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns,
+ measureClass, searcherClass, searchSize, numProjections, method, reduceStreamingKMeans);
+ }
+
+ /**
+ * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to
+ * cluster the input vectors.
+ *
+ * @param input the directory pathname for input points.
+ * @param output the directory pathname for output points.
+ * @return 0 on success, -1 on failure.
+ */
+ public static int run(Configuration conf, Path input, Path output)
+ throws IOException, InterruptedException, ClassNotFoundException, ExecutionException {
+ log.info("Starting StreamingKMeans clustering for vectors in {}; results are output to {}",
+ input.toString(), output.toString());
+
+ if (conf.get(DefaultOptionCreator.METHOD_OPTION,
+ DefaultOptionCreator.MAPREDUCE_METHOD).equals(DefaultOptionCreator.SEQUENTIAL_METHOD)) {
+ return runSequentially(conf, input, output);
+ } else {
+ return runMapReduce(conf, input, output);
+ }
+ }
+
+ private static int runSequentially(Configuration conf, Path input, Path output)
+ throws IOException, ExecutionException, InterruptedException {
+ long start = System.currentTimeMillis();
+ // Run StreamingKMeans step in parallel by spawning 1 thread per input path to process.
+ ExecutorService pool = Executors.newCachedThreadPool();
+ List<Future<Iterable<Centroid>>> intermediateCentroidFutures = Lists.newArrayList();
+ for (FileStatus status : HadoopUtil.listStatus(FileSystem.get(conf), input, PathFilters.logsCRCFilter())) {
+ intermediateCentroidFutures.add(pool.submit(new StreamingKMeansThread(status.getPath(), conf)));
+ }
+ log.info("Finished running Mappers");
+ // Merge the resulting "mapper" centroids.
+ List<Centroid> intermediateCentroids = Lists.newArrayList();
+ for (Future<Iterable<Centroid>> futureIterable : intermediateCentroidFutures) {
+ for (Centroid centroid : futureIterable.get()) {
+ intermediateCentroids.add(centroid);
+ }
+ }
+ pool.shutdown();
+ pool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
+ log.info("Finished StreamingKMeans");
+ SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.get(conf), conf, new Path(output, "part-r-00000"), IntWritable.class,
+ CentroidWritable.class);
+ int numCentroids = 0;
+ // Run BallKMeans on the intermediate centroids.
+ for (Vector finalVector : StreamingKMeansReducer.getBestCentroids(intermediateCentroids, conf)) {
+ Centroid finalCentroid = (Centroid)finalVector;
+ writer.append(new IntWritable(numCentroids++), new CentroidWritable(finalCentroid));
+ }
+ writer.close();
+ long end = System.currentTimeMillis();
+ log.info("Finished BallKMeans. Took {}.", (end - start) / 1000.0);
+ return 0;
+ }
+
+ public static int runMapReduce(Configuration conf, Path input, Path output)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ // Prepare Job for submission.
+ Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class,
+ StreamingKMeansMapper.class, IntWritable.class, CentroidWritable.class,
+ StreamingKMeansReducer.class, IntWritable.class, CentroidWritable.class, SequenceFileOutputFormat.class,
+ conf);
+ job.setJobName(HadoopUtil.getCustomJobName(StreamingKMeansDriver.class.getSimpleName(), job,
+ StreamingKMeansMapper.class, StreamingKMeansReducer.class));
+
+ // There is only one reducer so that the intermediate centroids get collected on one
+ // machine and are clustered in memory to get the right number of clusters.
+ job.setNumReduceTasks(1);
+
+ // Set the JAR (so that the required libraries are available) and run.
+ job.setJarByClass(StreamingKMeansDriver.class);
+
+ // Run job!
+ long start = System.currentTimeMillis();
+ if (!job.waitForCompletion(true)) {
+ return -1;
+ }
+ long end = System.currentTimeMillis();
+
+ log.info("StreamingKMeans clustering complete. Results are in {}. Took {} ms", output.toString(), end - start);
+ return 0;
+ }
+
+ /**
+ * Constructor to be used by the ToolRunner.
+ */
+ private StreamingKMeansDriver() {}
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new StreamingKMeansDriver(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
new file mode 100644
index 0000000..ced11ea
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+
+public class StreamingKMeansMapper extends Mapper<Writable, VectorWritable, IntWritable, CentroidWritable> {
+ private static final int NUM_ESTIMATE_POINTS = 1000;
+
+ /**
+ * The clusterer object used to cluster the points received by this mapper online.
+ */
+ private StreamingKMeans clusterer;
+
+ /**
+ * Number of points clustered so far.
+ */
+ private int numPoints = 0;
+
+ private boolean estimateDistanceCutoff = false;
+
+ private List<Centroid> estimatePoints;
+
+ @Override
+ public void setup(Context context) {
+ // At this point the configuration received from the Driver is assumed to be valid.
+ // No other checks are made.
+ Configuration conf = context.getConfiguration();
+ UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+ int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+ double estimatedDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+ StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+ if (estimatedDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+ estimateDistanceCutoff = true;
+ estimatePoints = Lists.newArrayList();
+ }
+ // There is no way of estimating the distance cutoff unless we have some data.
+ clusterer = new StreamingKMeans(searcher, numClusters, estimatedDistanceCutoff);
+ }
+
+ private void clusterEstimatePoints() {
+ clusterer.setDistanceCutoff(ClusteringUtils.estimateDistanceCutoff(
+ estimatePoints, clusterer.getDistanceMeasure()));
+ clusterer.cluster(estimatePoints);
+ estimateDistanceCutoff = false;
+ }
+
+ @Override
+ public void map(Writable key, VectorWritable point, Context context) {
+ Centroid centroid = new Centroid(numPoints++, point.get(), 1);
+ if (estimateDistanceCutoff) {
+ if (numPoints < NUM_ESTIMATE_POINTS) {
+ estimatePoints.add(centroid);
+ } else if (numPoints == NUM_ESTIMATE_POINTS) {
+ clusterEstimatePoints();
+ }
+ } else {
+ clusterer.cluster(centroid);
+ }
+ }
+
+ @Override
+ public void cleanup(Context context) throws IOException, InterruptedException {
+ // We should cluster the points at the end if they haven't yet been clustered.
+ if (estimateDistanceCutoff) {
+ clusterEstimatePoints();
+ }
+ // Reindex the centroids before passing them to the reducer.
+ clusterer.reindexCentroids();
+ // All outputs have the same key to go to the same final reducer.
+ for (Centroid centroid : clusterer) {
+ context.write(new IntWritable(0), new CentroidWritable(centroid));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
new file mode 100644
index 0000000..2b78acc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.streaming.cluster.BallKMeans;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class StreamingKMeansReducer extends Reducer<IntWritable, CentroidWritable, IntWritable, CentroidWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(StreamingKMeansReducer.class);
+
+ /**
+ * Configuration for the MapReduce job.
+ */
+ private Configuration conf;
+
+ @Override
+ public void setup(Context context) {
+ // At this point the configuration received from the Driver is assumed to be valid.
+ // No other checks are made.
+ conf = context.getConfiguration();
+ }
+
+ @Override
+ public void reduce(IntWritable key, Iterable<CentroidWritable> centroids,
+ Context context) throws IOException, InterruptedException {
+ List<Centroid> intermediateCentroids;
+ // There might be too many intermediate centroids to fit into memory, in which case, we run another pass
+ // of StreamingKMeans to collapse the clusters further.
+ if (conf.getBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, false)) {
+ intermediateCentroids = Lists.newArrayList(
+ new StreamingKMeansThread(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() {
+ @Override
+ public Centroid apply(CentroidWritable input) {
+ Preconditions.checkNotNull(input);
+ return input.getCentroid().clone();
+ }
+ }), conf).call());
+ } else {
+ intermediateCentroids = centroidWritablesToList(centroids);
+ }
+
+ int index = 0;
+ for (Vector centroid : getBestCentroids(intermediateCentroids, conf)) {
+ context.write(new IntWritable(index), new CentroidWritable((Centroid) centroid));
+ ++index;
+ }
+ }
+
+ public static List<Centroid> centroidWritablesToList(Iterable<CentroidWritable> centroids) {
+ // A new list must be created because Hadoop iterators mutate the contents of the Writable in
+ // place, without allocating new references when iterating through the centroids Iterable.
+ return Lists.newArrayList(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() {
+ @Override
+ public Centroid apply(CentroidWritable input) {
+ Preconditions.checkNotNull(input);
+ return input.getCentroid().clone();
+ }
+ }));
+ }
+
+ public static Iterable<Vector> getBestCentroids(List<Centroid> centroids, Configuration conf) {
+
+ if (log.isInfoEnabled()) {
+ log.info("Number of Centroids: {}", centroids.size());
+ }
+
+ int numClusters = conf.getInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1);
+ int maxNumIterations = conf.getInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, 10);
+ float trimFraction = conf.getFloat(StreamingKMeansDriver.TRIM_FRACTION, 0.9f);
+ boolean kMeansPlusPlusInit = !conf.getBoolean(StreamingKMeansDriver.RANDOM_INIT, false);
+ boolean correctWeights = !conf.getBoolean(StreamingKMeansDriver.IGNORE_WEIGHTS, false);
+ float testProbability = conf.getFloat(StreamingKMeansDriver.TEST_PROBABILITY, 0.1f);
+ int numRuns = conf.getInt(StreamingKMeansDriver.NUM_BALLKMEANS_RUNS, 3);
+
+ BallKMeans ballKMeansCluster = new BallKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(conf),
+ numClusters, maxNumIterations, trimFraction, kMeansPlusPlusInit, correctWeights, testProbability, numRuns);
+ return ballKMeansCluster.cluster(centroids);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
new file mode 100644
index 0000000..acb2b56
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class StreamingKMeansThread implements Callable<Iterable<Centroid>> {
+ private static final Logger log = LoggerFactory.getLogger(StreamingKMeansThread.class);
+
+ private static final int NUM_ESTIMATE_POINTS = 1000;
+
+ private final Configuration conf;
+ private final Iterable<Centroid> dataPoints;
+
+ public StreamingKMeansThread(Path input, Configuration conf) {
+ this(StreamingKMeansUtilsMR.getCentroidsFromVectorWritable(
+ new SequenceFileValueIterable<VectorWritable>(input, false, conf)), conf);
+ }
+
+ public StreamingKMeansThread(Iterable<Centroid> dataPoints, Configuration conf) {
+ this.dataPoints = dataPoints;
+ this.conf = conf;
+ }
+
+ @Override
+ public Iterable<Centroid> call() {
+ UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+ int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+ double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+ StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+
+ Iterator<Centroid> dataPointsIterator = dataPoints.iterator();
+
+ if (estimateDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+ List<Centroid> estimatePoints = Lists.newArrayListWithExpectedSize(NUM_ESTIMATE_POINTS);
+ while (dataPointsIterator.hasNext() && estimatePoints.size() < NUM_ESTIMATE_POINTS) {
+ Centroid centroid = dataPointsIterator.next();
+ estimatePoints.add(centroid);
+ }
+
+ if (log.isInfoEnabled()) {
+ log.info("Estimated Points: {}", estimatePoints.size());
+ }
+ estimateDistanceCutoff = ClusteringUtils.estimateDistanceCutoff(estimatePoints, searcher.getDistanceMeasure());
+ }
+
+ StreamingKMeans streamingKMeans = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff);
+
+ // datapointsIterator could be empty if no estimate distance was initially provided
+ // hence creating the iterator again here for the clustering
+ if (!dataPointsIterator.hasNext()) {
+ dataPointsIterator = dataPoints.iterator();
+ }
+
+ while (dataPointsIterator.hasNext()) {
+ streamingKMeans.cluster(dataPointsIterator.next());
+ }
+
+ streamingKMeans.reindexCentroids();
+ return streamingKMeans;
+ }
+
+}
[48/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
new file mode 100644
index 0000000..624a8c4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJob.java
@@ -0,0 +1,419 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.mapreduce.MergeVectorsCombiner;
+import org.apache.mahout.common.mapreduce.MergeVectorsReducer;
+import org.apache.mahout.common.mapreduce.TransposeMapper;
+import org.apache.mahout.common.mapreduce.VectorSumCombiner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>MapReduce implementation of the two factorization algorithms described in
+ *
+ * <p>"Large-scale Parallel Collaborative Filtering for the Netflix Prize" available at
+ * http://www.hpl.hp.com/personal/Robert_Schreiber/papers/2008%20AAIM%20Netflix/netflix_aaim08(submitted).pdf.</p>
+ *
+ * "<p>Collaborative Filtering for Implicit Feedback Datasets" available at
+ * http://research.yahoo.com/pub/2433</p>
+ *
+ * </p>
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the dataset</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--lambda (double): regularization parameter to avoid overfitting</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * <li>--numThreadsPerSolver (int): threads to use per solver mapper, (default: 1)</li>
+ * </ol>
+ */
+public class ParallelALSFactorizationJob extends AbstractJob {
+
+ private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJob.class);
+
+ static final String NUM_FEATURES = ParallelALSFactorizationJob.class.getName() + ".numFeatures";
+ static final String LAMBDA = ParallelALSFactorizationJob.class.getName() + ".lambda";
+ static final String ALPHA = ParallelALSFactorizationJob.class.getName() + ".alpha";
+ static final String NUM_ENTITIES = ParallelALSFactorizationJob.class.getName() + ".numEntities";
+
+ static final String USES_LONG_IDS = ParallelALSFactorizationJob.class.getName() + ".usesLongIDs";
+ static final String TOKEN_POS = ParallelALSFactorizationJob.class.getName() + ".tokenPos";
+
+ private boolean implicitFeedback;
+ private int numIterations;
+ private int numFeatures;
+ private double lambda;
+ private double alpha;
+ private int numThreadsPerSolver;
+
+ enum Stats { NUM_USERS }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new ParallelALSFactorizationJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("lambda", null, "regularization parameter", true);
+ addOption("implicitFeedback", null, "data consists of implicit feedback?", String.valueOf(false));
+ addOption("alpha", null, "confidence parameter (only used on implicit feedback)", String.valueOf(40));
+ addOption("numFeatures", null, "dimension of the feature space", true);
+ addOption("numIterations", null, "number of iterations", true);
+ addOption("numThreadsPerSolver", null, "threads per solver mapper", String.valueOf(1));
+ addOption("usesLongIDs", null, "input contains long IDs that need to be translated");
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ numFeatures = Integer.parseInt(getOption("numFeatures"));
+ numIterations = Integer.parseInt(getOption("numIterations"));
+ lambda = Double.parseDouble(getOption("lambda"));
+ alpha = Double.parseDouble(getOption("alpha"));
+ implicitFeedback = Boolean.parseBoolean(getOption("implicitFeedback"));
+
+ numThreadsPerSolver = Integer.parseInt(getOption("numThreadsPerSolver"));
+ boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs", String.valueOf(false)));
+
+ /*
+ * compute the factorization A = U M'
+ *
+ * where A (users x items) is the matrix of known ratings
+ * U (users x features) is the representation of users in the feature space
+ * M (items x features) is the representation of items in the feature space
+ */
+
+ if (usesLongIDs) {
+ Job mapUsers = prepareJob(getInputPath(), getOutputPath("userIDIndex"), TextInputFormat.class,
+ MapLongIDsMapper.class, VarIntWritable.class, VarLongWritable.class, IDMapReducer.class,
+ VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class);
+ mapUsers.getConfiguration().set(TOKEN_POS, String.valueOf(TasteHadoopUtils.USER_ID_POS));
+ mapUsers.waitForCompletion(true);
+
+ Job mapItems = prepareJob(getInputPath(), getOutputPath("itemIDIndex"), TextInputFormat.class,
+ MapLongIDsMapper.class, VarIntWritable.class, VarLongWritable.class, IDMapReducer.class,
+ VarIntWritable.class, VarLongWritable.class, SequenceFileOutputFormat.class);
+ mapItems.getConfiguration().set(TOKEN_POS, String.valueOf(TasteHadoopUtils.ITEM_ID_POS));
+ mapItems.waitForCompletion(true);
+ }
+
+ /* create A' */
+ Job itemRatings = prepareJob(getInputPath(), pathToItemRatings(),
+ TextInputFormat.class, ItemRatingVectorsMapper.class, IntWritable.class,
+ VectorWritable.class, VectorSumReducer.class, IntWritable.class,
+ VectorWritable.class, SequenceFileOutputFormat.class);
+ itemRatings.setCombinerClass(VectorSumCombiner.class);
+ itemRatings.getConfiguration().set(USES_LONG_IDS, String.valueOf(usesLongIDs));
+ boolean succeeded = itemRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ /* create A */
+ Job userRatings = prepareJob(pathToItemRatings(), pathToUserRatings(),
+ TransposeMapper.class, IntWritable.class, VectorWritable.class, MergeUserVectorsReducer.class,
+ IntWritable.class, VectorWritable.class);
+ userRatings.setCombinerClass(MergeVectorsCombiner.class);
+ succeeded = userRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ //TODO this could be fiddled into one of the upper jobs
+ Job averageItemRatings = prepareJob(pathToItemRatings(), getTempPath("averageRatings"),
+ AverageRatingMapper.class, IntWritable.class, VectorWritable.class, MergeVectorsReducer.class,
+ IntWritable.class, VectorWritable.class);
+ averageItemRatings.setCombinerClass(MergeVectorsCombiner.class);
+ succeeded = averageItemRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ Vector averageRatings = ALS.readFirstRow(getTempPath("averageRatings"), getConf());
+
+ int numItems = averageRatings.getNumNondefaultElements();
+ int numUsers = (int) userRatings.getCounters().findCounter(Stats.NUM_USERS).getValue();
+
+ log.info("Found {} users and {} items", numUsers, numItems);
+
+ /* create an initial M */
+ initializeM(averageRatings);
+
+ for (int currentIteration = 0; currentIteration < numIterations; currentIteration++) {
+ /* broadcast M, read A row-wise, recompute U row-wise */
+ log.info("Recomputing U (iteration {}/{})", currentIteration, numIterations);
+ runSolver(pathToUserRatings(), pathToU(currentIteration), pathToM(currentIteration - 1), currentIteration, "U",
+ numItems);
+ /* broadcast U, read A' row-wise, recompute M row-wise */
+ log.info("Recomputing M (iteration {}/{})", currentIteration, numIterations);
+ runSolver(pathToItemRatings(), pathToM(currentIteration), pathToU(currentIteration), currentIteration, "M",
+ numUsers);
+ }
+
+ return 0;
+ }
+
+ private void initializeM(Vector averageRatings) throws IOException {
+ Random random = RandomUtils.getRandom();
+
+ FileSystem fs = FileSystem.get(pathToM(-1).toUri(), getConf());
+ SequenceFile.Writer writer = null;
+ try {
+ writer = new SequenceFile.Writer(fs, getConf(), new Path(pathToM(-1), "part-m-00000"), IntWritable.class,
+ VectorWritable.class);
+
+ IntWritable index = new IntWritable();
+ VectorWritable featureVector = new VectorWritable();
+
+ for (Vector.Element e : averageRatings.nonZeroes()) {
+ Vector row = new DenseVector(numFeatures);
+ row.setQuick(0, e.get());
+ for (int m = 1; m < numFeatures; m++) {
+ row.setQuick(m, random.nextDouble());
+ }
+ index.set(e.index());
+ featureVector.set(row);
+ writer.append(index, featureVector);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ static class VectorSumReducer
+ extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable();
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx)
+ throws IOException, InterruptedException {
+ Vector sum = Vectors.sum(values.iterator());
+ result.set(new SequentialAccessSparseVector(sum));
+ ctx.write(key, result);
+ }
+ }
+
+ static class MergeUserVectorsReducer extends
+ Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable();
+
+ @Override
+ public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx)
+ throws IOException, InterruptedException {
+ Vector merged = VectorWritable.merge(vectors.iterator()).get();
+ result.set(new SequentialAccessSparseVector(merged));
+ ctx.write(key, result);
+ ctx.getCounter(Stats.NUM_USERS).increment(1);
+ }
+ }
+
+ static class ItemRatingVectorsMapper extends Mapper<LongWritable,Text,IntWritable,VectorWritable> {
+
+ private final IntWritable itemIDWritable = new IntWritable();
+ private final VectorWritable ratingsWritable = new VectorWritable(true);
+ private final Vector ratings = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+
+ private boolean usesLongIDs;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ usesLongIDs = ctx.getConfiguration().getBoolean(USES_LONG_IDS, false);
+ }
+
+ @Override
+ protected void map(LongWritable offset, Text line, Context ctx) throws IOException, InterruptedException {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
+ int userID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.USER_ID_POS], usesLongIDs);
+ int itemID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.ITEM_ID_POS], usesLongIDs);
+ float rating = Float.parseFloat(tokens[2]);
+
+ ratings.setQuick(userID, rating);
+
+ itemIDWritable.set(itemID);
+ ratingsWritable.set(ratings);
+
+ ctx.write(itemIDWritable, ratingsWritable);
+
+ // prepare instance for reuse
+ ratings.setQuick(userID, 0.0d);
+ }
+ }
+
+ private void runSolver(Path ratings, Path output, Path pathToUorM, int currentIteration, String matrixName,
+ int numEntities) throws ClassNotFoundException, IOException, InterruptedException {
+
+ // necessary for local execution in the same JVM only
+ SharingMapper.reset();
+
+ Class<? extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable>> solverMapperClassInternal;
+ String name;
+
+ if (implicitFeedback) {
+ solverMapperClassInternal = SolveImplicitFeedbackMapper.class;
+ name = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations + "), "
+ + '(' + numThreadsPerSolver + " threads, " + numFeatures + " features, implicit feedback)";
+ } else {
+ solverMapperClassInternal = SolveExplicitFeedbackMapper.class;
+ name = "Recompute " + matrixName + ", iteration (" + currentIteration + '/' + numIterations + "), "
+ + '(' + numThreadsPerSolver + " threads, " + numFeatures + " features, explicit feedback)";
+ }
+
+ Job solverForUorI = prepareJob(ratings, output, SequenceFileInputFormat.class, MultithreadedSharingMapper.class,
+ IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, name);
+ Configuration solverConf = solverForUorI.getConfiguration();
+ solverConf.set(LAMBDA, String.valueOf(lambda));
+ solverConf.set(ALPHA, String.valueOf(alpha));
+ solverConf.setInt(NUM_FEATURES, numFeatures);
+ solverConf.set(NUM_ENTITIES, String.valueOf(numEntities));
+
+ FileSystem fs = FileSystem.get(pathToUorM.toUri(), solverConf);
+ FileStatus[] parts = fs.listStatus(pathToUorM, PathFilters.partFilter());
+ for (FileStatus part : parts) {
+ if (log.isDebugEnabled()) {
+ log.debug("Adding {} to distributed cache", part.getPath().toString());
+ }
+ DistributedCache.addCacheFile(part.getPath().toUri(), solverConf);
+ }
+
+ MultithreadedMapper.setMapperClass(solverForUorI, solverMapperClassInternal);
+ MultithreadedMapper.setNumberOfThreads(solverForUorI, numThreadsPerSolver);
+
+ boolean succeeded = solverForUorI.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ static class AverageRatingMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private final IntWritable firstIndex = new IntWritable(0);
+ private final Vector featureVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+ private final VectorWritable featureVectorWritable = new VectorWritable();
+
+ @Override
+ protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
+ RunningAverage avg = new FullRunningAverage();
+ for (Vector.Element e : v.get().nonZeroes()) {
+ avg.addDatum(e.get());
+ }
+
+ featureVector.setQuick(r.get(), avg.getAverage());
+ featureVectorWritable.set(featureVector);
+ ctx.write(firstIndex, featureVectorWritable);
+
+ // prepare instance for reuse
+ featureVector.setQuick(r.get(), 0.0d);
+ }
+ }
+
+ static class MapLongIDsMapper extends Mapper<LongWritable,Text,VarIntWritable,VarLongWritable> {
+
+ private int tokenPos;
+ private final VarIntWritable index = new VarIntWritable();
+ private final VarLongWritable idWritable = new VarLongWritable();
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ tokenPos = ctx.getConfiguration().getInt(TOKEN_POS, -1);
+ Preconditions.checkState(tokenPos >= 0);
+ }
+
+ @Override
+ protected void map(LongWritable key, Text line, Context ctx) throws IOException, InterruptedException {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(line.toString());
+
+ long id = Long.parseLong(tokens[tokenPos]);
+
+ index.set(TasteHadoopUtils.idToIndex(id));
+ idWritable.set(id);
+ ctx.write(index, idWritable);
+ }
+ }
+
+ static class IDMapReducer extends Reducer<VarIntWritable,VarLongWritable,VarIntWritable,VarLongWritable> {
+ @Override
+ protected void reduce(VarIntWritable index, Iterable<VarLongWritable> ids, Context ctx)
+ throws IOException, InterruptedException {
+ ctx.write(index, ids.iterator().next());
+ }
+ }
+
+ private Path pathToM(int iteration) {
+ return iteration == numIterations - 1 ? getOutputPath("M") : getTempPath("M-" + iteration);
+ }
+
+ private Path pathToU(int iteration) {
+ return iteration == numIterations - 1 ? getOutputPath("U") : getTempPath("U-" + iteration);
+ }
+
+ private Path pathToItemRatings() {
+ return getTempPath("itemRatings");
+ }
+
+ private Path pathToUserRatings() {
+ return getOutputPath("userRatings");
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java
new file mode 100644
index 0000000..6e7ea81
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/PredictionMapper.java
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.mahout.cf.taste.hadoop.MutableRecommendedItem;
+import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.hadoop.TopItemsQueue;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.IntObjectProcedure;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.apache.mahout.math.set.OpenIntHashSet;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * a multithreaded mapper that loads the feature matrices U and M into memory. Afterwards it computes recommendations
+ * from these. Can be executed by a {@link MultithreadedSharingMapper}.
+ */
+public class PredictionMapper extends SharingMapper<IntWritable,VectorWritable,LongWritable,RecommendedItemsWritable,
+ Pair<OpenIntObjectHashMap<Vector>,OpenIntObjectHashMap<Vector>>> {
+
+ private int recommendationsPerUser;
+ private float maxRating;
+
+ private boolean usesLongIDs;
+ private OpenIntLongHashMap userIDIndex;
+ private OpenIntLongHashMap itemIDIndex;
+
+ private final LongWritable userIDWritable = new LongWritable();
+ private final RecommendedItemsWritable recommendations = new RecommendedItemsWritable();
+
+ @Override
+ Pair<OpenIntObjectHashMap<Vector>, OpenIntObjectHashMap<Vector>> createSharedInstance(Context ctx) {
+ Configuration conf = ctx.getConfiguration();
+ Path pathToU = new Path(conf.get(RecommenderJob.USER_FEATURES_PATH));
+ Path pathToM = new Path(conf.get(RecommenderJob.ITEM_FEATURES_PATH));
+
+ OpenIntObjectHashMap<Vector> U = ALS.readMatrixByRows(pathToU, conf);
+ OpenIntObjectHashMap<Vector> M = ALS.readMatrixByRows(pathToM, conf);
+
+ return new Pair<>(U, M);
+ }
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ Configuration conf = ctx.getConfiguration();
+ recommendationsPerUser = conf.getInt(RecommenderJob.NUM_RECOMMENDATIONS,
+ RecommenderJob.DEFAULT_NUM_RECOMMENDATIONS);
+ maxRating = Float.parseFloat(conf.get(RecommenderJob.MAX_RATING));
+
+ usesLongIDs = conf.getBoolean(ParallelALSFactorizationJob.USES_LONG_IDS, false);
+ if (usesLongIDs) {
+ userIDIndex = TasteHadoopUtils.readIDIndexMap(conf.get(RecommenderJob.USER_INDEX_PATH), conf);
+ itemIDIndex = TasteHadoopUtils.readIDIndexMap(conf.get(RecommenderJob.ITEM_INDEX_PATH), conf);
+ }
+ }
+
+ @Override
+ protected void map(IntWritable userIndexWritable, VectorWritable ratingsWritable, Context ctx)
+ throws IOException, InterruptedException {
+
+ Pair<OpenIntObjectHashMap<Vector>, OpenIntObjectHashMap<Vector>> uAndM = getSharedInstance();
+ OpenIntObjectHashMap<Vector> U = uAndM.getFirst();
+ OpenIntObjectHashMap<Vector> M = uAndM.getSecond();
+
+ Vector ratings = ratingsWritable.get();
+ int userIndex = userIndexWritable.get();
+ final OpenIntHashSet alreadyRatedItems = new OpenIntHashSet(ratings.getNumNondefaultElements());
+
+ for (Vector.Element e : ratings.nonZeroes()) {
+ alreadyRatedItems.add(e.index());
+ }
+
+ final TopItemsQueue topItemsQueue = new TopItemsQueue(recommendationsPerUser);
+ final Vector userFeatures = U.get(userIndex);
+
+ M.forEachPair(new IntObjectProcedure<Vector>() {
+ @Override
+ public boolean apply(int itemID, Vector itemFeatures) {
+ if (!alreadyRatedItems.contains(itemID)) {
+ double predictedRating = userFeatures.dot(itemFeatures);
+
+ MutableRecommendedItem top = topItemsQueue.top();
+ if (predictedRating > top.getValue()) {
+ top.set(itemID, (float) predictedRating);
+ topItemsQueue.updateTop();
+ }
+ }
+ return true;
+ }
+ });
+
+ List<RecommendedItem> recommendedItems = topItemsQueue.getTopItems();
+
+ if (!recommendedItems.isEmpty()) {
+
+ // cap predictions to maxRating
+ for (RecommendedItem topItem : recommendedItems) {
+ ((MutableRecommendedItem) topItem).capToMaxValue(maxRating);
+ }
+
+ if (usesLongIDs) {
+ long userID = userIDIndex.get(userIndex);
+ userIDWritable.set(userID);
+
+ for (RecommendedItem topItem : recommendedItems) {
+ // remap item IDs
+ long itemID = itemIDIndex.get((int) topItem.getItemID());
+ ((MutableRecommendedItem) topItem).setItemID(itemID);
+ }
+
+ } else {
+ userIDWritable.set(userIndex);
+ }
+
+ recommendations.set(recommendedItems);
+ ctx.write(userIDWritable, recommendations);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java
new file mode 100644
index 0000000..679d227
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/RecommenderJob.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable;
+import org.apache.mahout.common.AbstractJob;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+ * <p>Computes the top-N recommendations per user from a decomposition of the rating matrix</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing the vectorized user ratings</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--numRecommendations (int): maximum number of recommendations per user (default: 10)</li>
+ * <li>--maxRating (double): maximum rating of an item</li>
+ * <li>--numThreads (int): threads to use per mapper, (default: 1)</li>
+ * </ol>
+ */
+public class RecommenderJob extends AbstractJob {
+
+ static final String NUM_RECOMMENDATIONS = RecommenderJob.class.getName() + ".numRecommendations";
+ static final String USER_FEATURES_PATH = RecommenderJob.class.getName() + ".userFeatures";
+ static final String ITEM_FEATURES_PATH = RecommenderJob.class.getName() + ".itemFeatures";
+ static final String MAX_RATING = RecommenderJob.class.getName() + ".maxRating";
+ static final String USER_INDEX_PATH = RecommenderJob.class.getName() + ".userIndex";
+ static final String ITEM_INDEX_PATH = RecommenderJob.class.getName() + ".itemIndex";
+
+ static final int DEFAULT_NUM_RECOMMENDATIONS = 10;
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new RecommenderJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOption("userFeatures", null, "path to the user feature matrix", true);
+ addOption("itemFeatures", null, "path to the item feature matrix", true);
+ addOption("numRecommendations", null, "number of recommendations per user",
+ String.valueOf(DEFAULT_NUM_RECOMMENDATIONS));
+ addOption("maxRating", null, "maximum rating available", true);
+ addOption("numThreads", null, "threads per mapper", String.valueOf(1));
+ addOption("usesLongIDs", null, "input contains long IDs that need to be translated");
+ addOption("userIDIndex", null, "index for user long IDs (necessary if usesLongIDs is true)");
+ addOption("itemIDIndex", null, "index for user long IDs (necessary if usesLongIDs is true)");
+ addOutputOption();
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Job prediction = prepareJob(getInputPath(), getOutputPath(), SequenceFileInputFormat.class,
+ MultithreadedSharingMapper.class, IntWritable.class, RecommendedItemsWritable.class, TextOutputFormat.class);
+ Configuration conf = prediction.getConfiguration();
+
+ int numThreads = Integer.parseInt(getOption("numThreads"));
+
+ conf.setInt(NUM_RECOMMENDATIONS, Integer.parseInt(getOption("numRecommendations")));
+ conf.set(USER_FEATURES_PATH, getOption("userFeatures"));
+ conf.set(ITEM_FEATURES_PATH, getOption("itemFeatures"));
+ conf.set(MAX_RATING, getOption("maxRating"));
+
+ boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs"));
+ if (usesLongIDs) {
+ conf.set(ParallelALSFactorizationJob.USES_LONG_IDS, String.valueOf(true));
+ conf.set(USER_INDEX_PATH, getOption("userIDIndex"));
+ conf.set(ITEM_INDEX_PATH, getOption("itemIDIndex"));
+ }
+
+ MultithreadedMapper.setMapperClass(prediction, PredictionMapper.class);
+ MultithreadedMapper.setNumberOfThreads(prediction, numThreads);
+
+ boolean succeeded = prediction.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ return 0;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
new file mode 100644
index 0000000..9925807
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SharingMapper.java
@@ -0,0 +1,59 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.mapreduce.Mapper;
+
+import java.io.IOException;
+
+/**
+ * Mapper class to be used by {@link MultithreadedSharingMapper}. Offers "global" before() and after() methods
+ * that will typically be used to set up static variables.
+ *
+ * Suitable for mappers that need large, read-only in-memory data to operate.
+ *
+ * @param <K1>
+ * @param <V1>
+ * @param <K2>
+ * @param <V2>
+ */
+public abstract class SharingMapper<K1,V1,K2,V2,S> extends Mapper<K1,V1,K2,V2> {
+
+ private static Object SHARED_INSTANCE;
+
+ /**
+ * Called before the multithreaded execution
+ *
+ * @param context mapper's context
+ */
+ abstract S createSharedInstance(Context context) throws IOException;
+
+ final void setupSharedInstance(Context context) throws IOException {
+ if (SHARED_INSTANCE == null) {
+ SHARED_INSTANCE = createSharedInstance(context);
+ }
+ }
+
+ final S getSharedInstance() {
+ return (S) SHARED_INSTANCE;
+ }
+
+ static void reset() {
+ SHARED_INSTANCE = null;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
new file mode 100644
index 0000000..2569918
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveExplicitFeedbackMapper.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+import java.io.IOException;
+
+/** Solving mapper that can be safely executed using multiple threads */
+public class SolveExplicitFeedbackMapper
+ extends SharingMapper<IntWritable,VectorWritable,IntWritable,VectorWritable,OpenIntObjectHashMap<Vector>> {
+
+ private double lambda;
+ private int numFeatures;
+ private final VectorWritable uiOrmj = new VectorWritable();
+
+ @Override
+ OpenIntObjectHashMap<Vector> createSharedInstance(Context ctx) throws IOException {
+ Configuration conf = ctx.getConfiguration();
+ int numEntities = Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES));
+ return ALS.readMatrixByRowsFromDistributedCache(numEntities, conf);
+ }
+
+ @Override
+ protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
+ lambda = Double.parseDouble(ctx.getConfiguration().get(ParallelALSFactorizationJob.LAMBDA));
+ numFeatures = ctx.getConfiguration().getInt(ParallelALSFactorizationJob.NUM_FEATURES, -1);
+ Preconditions.checkArgument(numFeatures > 0, "numFeatures must be greater then 0!");
+ }
+
+ @Override
+ protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
+ throws IOException, InterruptedException {
+ OpenIntObjectHashMap<Vector> uOrM = getSharedInstance();
+ uiOrmj.set(ALS.solveExplicit(ratingsWritable, uOrM, lambda, numFeatures));
+ ctx.write(userOrItemID, uiOrmj);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
new file mode 100644
index 0000000..fd6657f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/SolveImplicitFeedbackMapper.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.als.ImplicitFeedbackAlternatingLeastSquaresSolver;
+
+import java.io.IOException;
+
+/** Solving mapper that can be safely executed using multiple threads */
+public class SolveImplicitFeedbackMapper
+ extends SharingMapper<IntWritable,VectorWritable,IntWritable,VectorWritable,
+ ImplicitFeedbackAlternatingLeastSquaresSolver> {
+
+ private final VectorWritable uiOrmj = new VectorWritable();
+
+ @Override
+ ImplicitFeedbackAlternatingLeastSquaresSolver createSharedInstance(Context ctx) throws IOException {
+ Configuration conf = ctx.getConfiguration();
+
+ double lambda = Double.parseDouble(conf.get(ParallelALSFactorizationJob.LAMBDA));
+ double alpha = Double.parseDouble(conf.get(ParallelALSFactorizationJob.ALPHA));
+ int numFeatures = conf.getInt(ParallelALSFactorizationJob.NUM_FEATURES, -1);
+ int numEntities = Integer.parseInt(conf.get(ParallelALSFactorizationJob.NUM_ENTITIES));
+
+ Preconditions.checkArgument(numFeatures > 0, "numFeatures must be greater then 0!");
+
+ return new ImplicitFeedbackAlternatingLeastSquaresSolver(numFeatures, lambda, alpha,
+ ALS.readMatrixByRowsFromDistributedCache(numEntities, conf), 1);
+ }
+
+ @Override
+ protected void map(IntWritable userOrItemID, VectorWritable ratingsWritable, Context ctx)
+ throws IOException, InterruptedException {
+ ImplicitFeedbackAlternatingLeastSquaresSolver solver = getSharedInstance();
+ uiOrmj.set(solver.solve(ratingsWritable.get()));
+ ctx.write(userOrItemID, uiOrmj);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java
new file mode 100644
index 0000000..b44fd5b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/AggregateAndRecommendReducer.java
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.cf.taste.hadoop.MutableRecommendedItem;
+import org.apache.mahout.cf.taste.hadoop.RecommendedItemsWritable;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.hadoop.TopItemsQueue;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>computes prediction values for each user</p>
+ *
+ * <pre>
+ * u = a user
+ * i = an item not yet rated by u
+ * N = all items similar to i (where similarity is usually computed by pairwisely comparing the item-vectors
+ * of the user-item matrix)
+ *
+ * Prediction(u,i) = sum(all n from N: similarity(i,n) * rating(u,n)) / sum(all n from N: abs(similarity(i,n)))
+ * </pre>
+ */
+public final class AggregateAndRecommendReducer extends
+ Reducer<VarLongWritable,PrefAndSimilarityColumnWritable,VarLongWritable,RecommendedItemsWritable> {
+
+ private static final Logger log = LoggerFactory.getLogger(AggregateAndRecommendReducer.class);
+
+ static final String ITEMID_INDEX_PATH = "itemIDIndexPath";
+ static final String NUM_RECOMMENDATIONS = "numRecommendations";
+ static final int DEFAULT_NUM_RECOMMENDATIONS = 10;
+ static final String ITEMS_FILE = "itemsFile";
+
+ private boolean booleanData;
+ private int recommendationsPerUser;
+ private IDReader idReader;
+ private FastIDSet itemsToRecommendFor;
+ private OpenIntLongHashMap indexItemIDMap;
+
+ private final RecommendedItemsWritable recommendedItems = new RecommendedItemsWritable();
+
+ private static final float BOOLEAN_PREF_VALUE = 1.0f;
+
+ @Override
+ protected void setup(Context context) throws IOException {
+ Configuration conf = context.getConfiguration();
+ recommendationsPerUser = conf.getInt(NUM_RECOMMENDATIONS, DEFAULT_NUM_RECOMMENDATIONS);
+ booleanData = conf.getBoolean(RecommenderJob.BOOLEAN_DATA, false);
+ indexItemIDMap = TasteHadoopUtils.readIDIndexMap(conf.get(ITEMID_INDEX_PATH), conf);
+
+ idReader = new IDReader(conf);
+ idReader.readIDs();
+ itemsToRecommendFor = idReader.getItemIds();
+ }
+
+ @Override
+ protected void reduce(VarLongWritable userID,
+ Iterable<PrefAndSimilarityColumnWritable> values,
+ Context context) throws IOException, InterruptedException {
+ if (booleanData) {
+ reduceBooleanData(userID, values, context);
+ } else {
+ reduceNonBooleanData(userID, values, context);
+ }
+ }
+
+ private void reduceBooleanData(VarLongWritable userID,
+ Iterable<PrefAndSimilarityColumnWritable> values,
+ Context context) throws IOException, InterruptedException {
+ /* having boolean data, each estimated preference can only be 1,
+ * however we can't use this to rank the recommended items,
+ * so we use the sum of similarities for that. */
+ Iterator<PrefAndSimilarityColumnWritable> columns = values.iterator();
+ Vector predictions = columns.next().getSimilarityColumn();
+ while (columns.hasNext()) {
+ predictions.assign(columns.next().getSimilarityColumn(), Functions.PLUS);
+ }
+ writeRecommendedItems(userID, predictions, context);
+ }
+
+ private void reduceNonBooleanData(VarLongWritable userID,
+ Iterable<PrefAndSimilarityColumnWritable> values,
+ Context context) throws IOException, InterruptedException {
+ /* each entry here is the sum in the numerator of the prediction formula */
+ Vector numerators = null;
+ /* each entry here is the sum in the denominator of the prediction formula */
+ Vector denominators = null;
+ /* each entry here is the number of similar items used in the prediction formula */
+ Vector numberOfSimilarItemsUsed = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+
+ for (PrefAndSimilarityColumnWritable prefAndSimilarityColumn : values) {
+ Vector simColumn = prefAndSimilarityColumn.getSimilarityColumn();
+ float prefValue = prefAndSimilarityColumn.getPrefValue();
+ /* count the number of items used for each prediction */
+ for (Element e : simColumn.nonZeroes()) {
+ int itemIDIndex = e.index();
+ numberOfSimilarItemsUsed.setQuick(itemIDIndex, numberOfSimilarItemsUsed.getQuick(itemIDIndex) + 1);
+ }
+
+ if (denominators == null) {
+ denominators = simColumn.clone();
+ } else {
+ denominators.assign(simColumn, Functions.PLUS_ABS);
+ }
+
+ if (numerators == null) {
+ numerators = simColumn.clone();
+ if (prefValue != BOOLEAN_PREF_VALUE) {
+ numerators.assign(Functions.MULT, prefValue);
+ }
+ } else {
+ if (prefValue != BOOLEAN_PREF_VALUE) {
+ simColumn.assign(Functions.MULT, prefValue);
+ }
+ numerators.assign(simColumn, Functions.PLUS);
+ }
+
+ }
+
+ if (numerators == null) {
+ return;
+ }
+
+ Vector recommendationVector = new RandomAccessSparseVector(Integer.MAX_VALUE, 100);
+ for (Element element : numerators.nonZeroes()) {
+ int itemIDIndex = element.index();
+ /* preference estimations must be based on at least 2 datapoints */
+ if (numberOfSimilarItemsUsed.getQuick(itemIDIndex) > 1) {
+ /* compute normalized prediction */
+ double prediction = element.get() / denominators.getQuick(itemIDIndex);
+ recommendationVector.setQuick(itemIDIndex, prediction);
+ }
+ }
+ writeRecommendedItems(userID, recommendationVector, context);
+ }
+
+ /**
+ * find the top entries in recommendationVector, map them to the real itemIDs and write back the result
+ */
+ private void writeRecommendedItems(VarLongWritable userID, Vector recommendationVector, Context context)
+ throws IOException, InterruptedException {
+ TopItemsQueue topKItems = new TopItemsQueue(recommendationsPerUser);
+ FastIDSet itemsForUser = null;
+
+ if (idReader != null && idReader.isUserItemFilterSpecified()) {
+ itemsForUser = idReader.getItemsToRecommendForUser(userID.get());
+ }
+
+ for (Element element : recommendationVector.nonZeroes()) {
+ int index = element.index();
+ long itemID;
+ if (indexItemIDMap != null && !indexItemIDMap.isEmpty()) {
+ itemID = indexItemIDMap.get(index);
+ } else { // we don't have any mappings, so just use the original
+ itemID = index;
+ }
+
+ if (shouldIncludeItemIntoRecommendations(itemID, itemsToRecommendFor, itemsForUser)) {
+
+ float value = (float) element.get();
+ if (!Float.isNaN(value)) {
+
+ MutableRecommendedItem topItem = topKItems.top();
+ if (value > topItem.getValue()) {
+ topItem.set(itemID, value);
+ topKItems.updateTop();
+ }
+ }
+ }
+ }
+
+ List<RecommendedItem> topItems = topKItems.getTopItems();
+ if (!topItems.isEmpty()) {
+ recommendedItems.set(topItems);
+ context.write(userID, recommendedItems);
+ }
+ }
+
+ private boolean shouldIncludeItemIntoRecommendations(long itemID, FastIDSet allItemsToRecommendFor,
+ FastIDSet itemsForUser) {
+ if (allItemsToRecommendFor == null && itemsForUser == null) {
+ return true;
+ } else if (itemsForUser != null) {
+ return itemsForUser.contains(itemID);
+ } else {
+ return allItemsToRecommendFor.contains(itemID);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java
new file mode 100644
index 0000000..b8cf6bb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/IDReader.java
@@ -0,0 +1,250 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Map;
+import java.util.regex.Pattern;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.iterator.FileLineIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Reads user ids and item ids from files specified in usersFile, itemsFile or userItemFile options in item-based
+ * recommender. Composes a list of users and a list of items which can be used by
+ * {@link org.apache.mahout.cf.taste.hadoop.item.UserVectorSplitterMapper} and
+ * {@link org.apache.mahout.cf.taste.hadoop.item.AggregateAndRecommendReducer}.
+ */
+public class IDReader {
+
+ static final String USER_ITEM_FILE = "userItemFile";
+
+ private static final Logger log = LoggerFactory.getLogger(IDReader.class);
+ private static final Pattern SEPARATOR = Pattern.compile("[\t,]");
+
+ private Configuration conf;
+
+ private String usersFile;
+ private String itemsFile;
+ private String userItemFile;
+
+ private FastIDSet userIds;
+ private FastIDSet itemIds;
+
+ private FastIDSet emptySet;
+
+ /* Key - user id, value - a set of item ids to include into recommendations for this user */
+ private Map<Long, FastIDSet> userItemFilter;
+
+ /**
+ * Creates a new IDReader
+ *
+ * @param conf Job configuration
+ */
+ public IDReader(Configuration conf) {
+ this.conf = conf;
+ emptySet = new FastIDSet();
+
+ usersFile = conf.get(UserVectorSplitterMapper.USERS_FILE);
+ itemsFile = conf.get(AggregateAndRecommendReducer.ITEMS_FILE);
+ userItemFile = conf.get(USER_ITEM_FILE);
+ }
+
+ /**
+ * Reads user ids and item ids from files specified in a job configuration
+ *
+ * @throws IOException if an error occurs during file read operation
+ *
+ * @throws IllegalStateException if userItemFile option is specified together with usersFile or itemsFile
+ */
+ public void readIDs() throws IOException, IllegalStateException {
+ if (isUserItemFileSpecified()) {
+ readUserItemFilterIfNeeded();
+ }
+
+ if (isUsersFileSpecified() || isUserItemFilterSpecified()) {
+ readUserIds();
+ }
+
+ if (isItemsFileSpecified() || isUserItemFilterSpecified()) {
+ readItemIds();
+ }
+ }
+
+ /**
+ * Gets a collection of items which should be recommended for a user
+ *
+ * @param userId ID of a user we are interested in
+ * @return if a userItemFile option is specified, and that file contains at least one item ID for the user,
+ * then this method returns a {@link FastIDSet} object populated with item IDs. Otherwise, this
+ * method returns an empty set.
+ */
+ public FastIDSet getItemsToRecommendForUser(Long userId) {
+ if (isUserItemFilterSpecified() && userItemFilter.containsKey(userId)) {
+ return userItemFilter.get(userId);
+ } else {
+ return emptySet;
+ }
+ }
+
+ private void readUserIds() throws IOException, IllegalStateException {
+ if (isUsersFileSpecified() && !isUserItemFileSpecified()) {
+ userIds = readIDList(usersFile);
+ } else if (isUserItemFileSpecified() && !isUsersFileSpecified()) {
+ readUserItemFilterIfNeeded();
+ userIds = extractAllUserIdsFromUserItemFilter(userItemFilter);
+ } else if (!isUsersFileSpecified()) {
+ throw new IllegalStateException("Neither usersFile nor userItemFile options are specified");
+ } else {
+ throw new IllegalStateException("usersFile and userItemFile options cannot be used simultaneously");
+ }
+ }
+
+ private void readItemIds() throws IOException, IllegalStateException {
+ if (isItemsFileSpecified() && !isUserItemFileSpecified()) {
+ itemIds = readIDList(itemsFile);
+ } else if (isUserItemFileSpecified() && !isItemsFileSpecified()) {
+ readUserItemFilterIfNeeded();
+ itemIds = extractAllItemIdsFromUserItemFilter(userItemFilter);
+ } else if (!isItemsFileSpecified()) {
+ throw new IllegalStateException("Neither itemsFile nor userItemFile options are specified");
+ } else {
+ throw new IllegalStateException("itemsFile and userItemFile options cannot be specified simultaneously");
+ }
+ }
+
+ private void readUserItemFilterIfNeeded() throws IOException {
+ if (!isUserItemFilterSpecified() && isUserItemFileSpecified()) {
+ userItemFilter = readUserItemFilter(userItemFile);
+ }
+ }
+
+ private Map<Long, FastIDSet> readUserItemFilter(String pathString) throws IOException {
+ Map<Long, FastIDSet> result = Maps.newHashMap();
+ InputStream in = openFile(pathString);
+
+ try {
+ for (String line : new FileLineIterable(in)) {
+ try {
+ String[] tokens = SEPARATOR.split(line);
+ Long userId = Long.parseLong(tokens[0]);
+ Long itemId = Long.parseLong(tokens[1]);
+
+ addUserAndItemIdToUserItemFilter(result, userId, itemId);
+ } catch (NumberFormatException nfe) {
+ log.warn("userItemFile line ignored: {}", line);
+ }
+ }
+ } finally {
+ Closeables.close(in, true);
+ }
+
+ return result;
+ }
+
+ void addUserAndItemIdToUserItemFilter(Map<Long, FastIDSet> filter, Long userId, Long itemId) {
+ FastIDSet itemIds;
+
+ if (filter.containsKey(userId)) {
+ itemIds = filter.get(userId);
+ } else {
+ itemIds = new FastIDSet();
+ filter.put(userId, itemIds);
+ }
+
+ itemIds.add(itemId);
+ }
+
+ static FastIDSet extractAllUserIdsFromUserItemFilter(Map<Long, FastIDSet> filter) {
+ FastIDSet result = new FastIDSet();
+
+ for (Long userId : filter.keySet()) {
+ result.add(userId);
+ }
+
+ return result;
+ }
+
+ private FastIDSet extractAllItemIdsFromUserItemFilter(Map<Long, FastIDSet> filter) {
+ FastIDSet result = new FastIDSet();
+
+ for (FastIDSet itemIds : filter.values()) {
+ result.addAll(itemIds);
+ }
+
+ return result;
+ }
+
+ private FastIDSet readIDList(String pathString) throws IOException {
+ FastIDSet result = null;
+
+ if (pathString != null) {
+ result = new FastIDSet();
+ InputStream in = openFile(pathString);
+
+ try {
+ for (String line : new FileLineIterable(in)) {
+ try {
+ result.add(Long.parseLong(line));
+ } catch (NumberFormatException nfe) {
+ log.warn("line ignored: {}", line);
+ }
+ }
+ } finally {
+ Closeables.close(in, true);
+ }
+ }
+
+ return result;
+ }
+
+ private InputStream openFile(String pathString) throws IOException {
+ return HadoopUtil.openStream(new Path(pathString), conf);
+ }
+
+ public boolean isUsersFileSpecified () {
+ return usersFile != null;
+ }
+
+ public boolean isItemsFileSpecified () {
+ return itemsFile != null;
+ }
+
+ public boolean isUserItemFileSpecified () {
+ return userItemFile != null;
+ }
+
+ public boolean isUserItemFilterSpecified() {
+ return userItemFilter != null;
+ }
+
+ public FastIDSet getUserIds() {
+ return userIds;
+ }
+
+ public FastIDSet getItemIds() {
+ return itemIds;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java
new file mode 100644
index 0000000..d9a7d25
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterAsVectorAndPrefsReducer.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * we use a neat little trick to explicitly filter items for some users: we inject a NaN summand into the preference
+ * estimation for those items, which makes {@link org.apache.mahout.cf.taste.hadoop.item.AggregateAndRecommendReducer}
+ * automatically exclude them
+ */
+public class ItemFilterAsVectorAndPrefsReducer
+ extends Reducer<VarLongWritable,VarLongWritable,VarIntWritable,VectorAndPrefsWritable> {
+
+ private final VarIntWritable itemIDIndexWritable = new VarIntWritable();
+ private final VectorAndPrefsWritable vectorAndPrefs = new VectorAndPrefsWritable();
+
+ @Override
+ protected void reduce(VarLongWritable itemID, Iterable<VarLongWritable> values, Context ctx)
+ throws IOException, InterruptedException {
+
+ int itemIDIndex = TasteHadoopUtils.idToIndex(itemID.get());
+ Vector vector = new RandomAccessSparseVector(Integer.MAX_VALUE, 1);
+ /* artificial NaN summand to exclude this item from the recommendations for all users specified in userIDs */
+ vector.set(itemIDIndex, Double.NaN);
+
+ List<Long> userIDs = Lists.newArrayList();
+ List<Float> prefValues = Lists.newArrayList();
+ for (VarLongWritable userID : values) {
+ userIDs.add(userID.get());
+ prefValues.add(1.0f);
+ }
+
+ itemIDIndexWritable.set(itemIDIndex);
+ vectorAndPrefs.set(vector, userIDs, prefValues);
+ ctx.write(itemIDIndexWritable, vectorAndPrefs);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java
new file mode 100644
index 0000000..cdc1ddf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemFilterMapper.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VarLongWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+/**
+ * map out all user/item pairs to filter, keyed by the itemID
+ */
+public class ItemFilterMapper extends Mapper<LongWritable,Text,VarLongWritable,VarLongWritable> {
+
+ private static final Pattern SEPARATOR = Pattern.compile("[\t,]");
+
+ private final VarLongWritable itemIDWritable = new VarLongWritable();
+ private final VarLongWritable userIDWritable = new VarLongWritable();
+
+ @Override
+ protected void map(LongWritable key, Text line, Context ctx) throws IOException, InterruptedException {
+ String[] tokens = SEPARATOR.split(line.toString());
+ long userID = Long.parseLong(tokens[0]);
+ long itemID = Long.parseLong(tokens[1]);
+ itemIDWritable.set(itemID);
+ userIDWritable.set(userID);
+ ctx.write(itemIDWritable, userIDWritable);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java
new file mode 100644
index 0000000..ac8597e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexMapper.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.hadoop.ToEntityPrefsMapper;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+
+public final class ItemIDIndexMapper extends
+ Mapper<LongWritable,Text, VarIntWritable, VarLongWritable> {
+
+ private boolean transpose;
+
+ private final VarIntWritable indexWritable = new VarIntWritable();
+ private final VarLongWritable itemIDWritable = new VarLongWritable();
+
+ @Override
+ protected void setup(Context context) {
+ Configuration jobConf = context.getConfiguration();
+ transpose = jobConf.getBoolean(ToEntityPrefsMapper.TRANSPOSE_USER_ITEM, false);
+ }
+
+ @Override
+ protected void map(LongWritable key,
+ Text value,
+ Context context) throws IOException, InterruptedException {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
+ long itemID = Long.parseLong(tokens[transpose ? 0 : 1]);
+ int index = TasteHadoopUtils.idToIndex(itemID);
+ indexWritable.set(index);
+ itemIDWritable.set(itemID);
+ context.write(indexWritable, itemIDWritable);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java
new file mode 100644
index 0000000..d9ecf5e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/ItemIDIndexReducer.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.IOException;
+
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+
+public final class ItemIDIndexReducer extends
+ Reducer<VarIntWritable, VarLongWritable, VarIntWritable,VarLongWritable> {
+
+ private final VarLongWritable minimumItemIDWritable = new VarLongWritable();
+
+ @Override
+ protected void reduce(VarIntWritable index,
+ Iterable<VarLongWritable> possibleItemIDs,
+ Context context) throws IOException, InterruptedException {
+ long minimumItemID = Long.MAX_VALUE;
+ for (VarLongWritable varLongWritable : possibleItemIDs) {
+ long itemID = varLongWritable.get();
+ if (itemID < minimumItemID) {
+ minimumItemID = itemID;
+ }
+ }
+ if (minimumItemID != Long.MAX_VALUE) {
+ minimumItemIDWritable.set(minimumItemID);
+ context.write(index, minimumItemIDWritable);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java
new file mode 100644
index 0000000..0e818f3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PartialMultiplyMapper.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.Vector;
+
+/**
+ * maps similar items and their preference values per user
+ */
+public final class PartialMultiplyMapper extends
+ Mapper<VarIntWritable,VectorAndPrefsWritable,VarLongWritable,PrefAndSimilarityColumnWritable> {
+
+ private final VarLongWritable userIDWritable = new VarLongWritable();
+ private final PrefAndSimilarityColumnWritable prefAndSimilarityColumn = new PrefAndSimilarityColumnWritable();
+
+ @Override
+ protected void map(VarIntWritable key,
+ VectorAndPrefsWritable vectorAndPrefsWritable,
+ Context context) throws IOException, InterruptedException {
+
+ Vector similarityMatrixColumn = vectorAndPrefsWritable.getVector();
+ List<Long> userIDs = vectorAndPrefsWritable.getUserIDs();
+ List<Float> prefValues = vectorAndPrefsWritable.getValues();
+
+ for (int i = 0; i < userIDs.size(); i++) {
+ long userID = userIDs.get(i);
+ float prefValue = prefValues.get(i);
+ if (!Float.isNaN(prefValue)) {
+ prefAndSimilarityColumn.set(prefValue, similarityMatrixColumn);
+ userIDWritable.set(userID);
+ context.write(userIDWritable, prefAndSimilarityColumn);
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PrefAndSimilarityColumnWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PrefAndSimilarityColumnWritable.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PrefAndSimilarityColumnWritable.java
new file mode 100644
index 0000000..704c74a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/item/PrefAndSimilarityColumnWritable.java
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+public final class PrefAndSimilarityColumnWritable implements Writable {
+
+ private float prefValue;
+ private Vector similarityColumn;
+
+ public PrefAndSimilarityColumnWritable() {
+ }
+
+ public PrefAndSimilarityColumnWritable(float prefValue, Vector similarityColumn) {
+ set(prefValue, similarityColumn);
+ }
+
+ public void set(float prefValue, Vector similarityColumn) {
+ this.prefValue = prefValue;
+ this.similarityColumn = similarityColumn;
+ }
+
+ public float getPrefValue() {
+ return prefValue;
+ }
+
+ public Vector getSimilarityColumn() {
+ return similarityColumn;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ prefValue = in.readFloat();
+ VectorWritable vw = new VectorWritable();
+ vw.readFields(in);
+ similarityColumn = vw.get();
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeFloat(prefValue);
+ VectorWritable vw = new VectorWritable(similarityColumn);
+ vw.setWritesLaxPrecision(true);
+ vw.write(out);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj instanceof PrefAndSimilarityColumnWritable) {
+ PrefAndSimilarityColumnWritable other = (PrefAndSimilarityColumnWritable) obj;
+ return prefValue == other.prefValue && similarityColumn.equals(other.similarityColumn);
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return RandomUtils.hashFloat(prefValue) + 31 * similarityColumn.hashCode();
+ }
+
+
+}
[49/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchItemException.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchItemException.java b/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchItemException.java
new file mode 100644
index 0000000..1ac5b72
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchItemException.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.common;
+
+public final class NoSuchItemException extends TasteException {
+
+ public NoSuchItemException() { }
+
+ public NoSuchItemException(long itemID) {
+ this(String.valueOf(itemID));
+ }
+
+ public NoSuchItemException(String message) {
+ super(message);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchUserException.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchUserException.java b/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchUserException.java
new file mode 100644
index 0000000..cbb60fa
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/common/NoSuchUserException.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.common;
+
+public final class NoSuchUserException extends TasteException {
+
+ public NoSuchUserException() { }
+
+ public NoSuchUserException(long userID) {
+ this(String.valueOf(userID));
+ }
+
+ public NoSuchUserException(String message) {
+ super(message);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/common/Refreshable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/common/Refreshable.java b/mr/src/main/java/org/apache/mahout/cf/taste/common/Refreshable.java
new file mode 100644
index 0000000..9b26bee
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/common/Refreshable.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.common;
+
+import java.util.Collection;
+
+/**
+ * <p>
+ * Implementations of this interface have state that can be periodically refreshed. For example, an
+ * implementation instance might contain some pre-computed information that should be periodically refreshed.
+ * The {@link #refresh(Collection)} method triggers such a refresh.
+ * </p>
+ *
+ * <p>
+ * All Taste components implement this. In particular,
+ * {@link org.apache.mahout.cf.taste.recommender.Recommender}s do. Callers may want to call
+ * {@link #refresh(Collection)} periodically to re-compute information throughout the system and bring it up
+ * to date, though this operation may be expensive.
+ * </p>
+ */
+public interface Refreshable {
+
+ /**
+ * <p>
+ * Triggers "refresh" -- whatever that means -- of the implementation. The general contract is that any
+ * {@link Refreshable} should always leave itself in a consistent, operational state, and that the refresh
+ * atomically updates internal state from old to new.
+ * </p>
+ *
+ * @param alreadyRefreshed
+ * {@link org.apache.mahout.cf.taste.common.Refreshable}s that are known to have already been
+ * refreshed as a result of an initial call to a {#refresh(Collection)} method on some
+ * object. This ensure that objects in a refresh dependency graph aren't refreshed twice
+ * needlessly.
+ */
+ void refresh(Collection<Refreshable> alreadyRefreshed);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/common/TasteException.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/common/TasteException.java b/mr/src/main/java/org/apache/mahout/cf/taste/common/TasteException.java
new file mode 100644
index 0000000..1792eff
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/common/TasteException.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.common;
+
+/**
+ * <p>
+ * An exception thrown when an error occurs inside the Taste engine.
+ * </p>
+ */
+public class TasteException extends Exception {
+
+ public TasteException() { }
+
+ public TasteException(String message) {
+ super(message);
+ }
+
+ public TasteException(Throwable cause) {
+ super(cause);
+ }
+
+ public TasteException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/common/Weighting.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/common/Weighting.java b/mr/src/main/java/org/apache/mahout/cf/taste/common/Weighting.java
new file mode 100644
index 0000000..4e39617
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/common/Weighting.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.common;
+
+/**
+ * <p>
+ * A simple enum which gives symbolic names to the ideas of "weighted" and "unweighted", to make various API
+ * calls which take a weighting parameter more readable.
+ * </p>
+ */
+public enum Weighting {
+
+ WEIGHTED,
+ UNWEIGHTED
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/eval/DataModelBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/eval/DataModelBuilder.java b/mr/src/main/java/org/apache/mahout/cf/taste/eval/DataModelBuilder.java
new file mode 100644
index 0000000..875c65e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/eval/DataModelBuilder.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.eval;
+
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+/**
+ * <p>
+ * Implementations of this inner interface are simple helper classes which create a {@link DataModel} to be
+ * used while evaluating a {@link org.apache.mahout.cf.taste.recommender.Recommender}.
+ *
+ * @see RecommenderBuilder
+ * @see RecommenderEvaluator
+ */
+public interface DataModelBuilder {
+
+ /**
+ * <p>
+ * Builds a {@link DataModel} implementation to be used in an evaluation, given training data.
+ * </p>
+ *
+ * @param trainingData
+ * data to be used in the {@link DataModel}
+ * @return {@link DataModel} based upon the given data
+ */
+ DataModel buildDataModel(FastByIDMap<PreferenceArray> trainingData);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/eval/IRStatistics.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/eval/IRStatistics.java b/mr/src/main/java/org/apache/mahout/cf/taste/eval/IRStatistics.java
new file mode 100644
index 0000000..9c442ff
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/eval/IRStatistics.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.eval;
+
+/**
+ * <p>
+ * Implementations encapsulate information retrieval-related statistics about a
+ * {@link org.apache.mahout.cf.taste.recommender.Recommender}'s recommendations.
+ * </p>
+ *
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Information_retrieval">Information retrieval</a>.
+ * </p>
+ */
+public interface IRStatistics {
+
+ /**
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Information_retrieval#Precision">Precision</a>.
+ * </p>
+ */
+ double getPrecision();
+
+ /**
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Information_retrieval#Recall">Recall</a>.
+ * </p>
+ */
+ double getRecall();
+
+ /**
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Information_retrieval#Fall-Out">Fall-Out</a>.
+ * </p>
+ */
+ double getFallOut();
+
+ /**
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Information_retrieval#F-measure">F-measure</a>.
+ * </p>
+ */
+ double getF1Measure();
+
+ /**
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Information_retrieval#F-measure">F-measure</a>.
+ * </p>
+ */
+ double getFNMeasure(double n);
+
+ /**
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Discounted_cumulative_gain#Normalized_DCG">
+ * Normalized Discounted Cumulative Gain</a>.
+ * </p>
+ */
+ double getNormalizedDiscountedCumulativeGain();
+
+ /**
+ * @return the fraction of all users for whom recommendations could be produced
+ */
+ double getReach();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderBuilder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderBuilder.java b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderBuilder.java
new file mode 100644
index 0000000..1805092
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderBuilder.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.eval;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+/**
+ * <p>
+ * Implementations of this inner interface are simple helper classes which create a {@link Recommender} to be
+ * evaluated based on the given {@link DataModel}.
+ * </p>
+ */
+public interface RecommenderBuilder {
+
+ /**
+ * <p>
+ * Builds a {@link Recommender} implementation to be evaluated, using the given {@link DataModel}.
+ * </p>
+ *
+ * @param dataModel
+ * {@link DataModel} to build the {@link Recommender} on
+ * @return {@link Recommender} based upon the given {@link DataModel}
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ Recommender buildRecommender(DataModel dataModel) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderEvaluator.java
new file mode 100644
index 0000000..dcbbcf8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderEvaluator.java
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.eval;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * <p>
+ * Implementations of this interface evaluate the quality of a
+ * {@link org.apache.mahout.cf.taste.recommender.Recommender}'s recommendations.
+ * </p>
+ */
+public interface RecommenderEvaluator {
+
+ /**
+ * <p>
+ * Evaluates the quality of a {@link org.apache.mahout.cf.taste.recommender.Recommender}'s recommendations.
+ * The range of values that may be returned depends on the implementation, but <em>lower</em> values must
+ * mean better recommendations, with 0 being the lowest / best possible evaluation, meaning a perfect match.
+ * This method does not accept a {@link org.apache.mahout.cf.taste.recommender.Recommender} directly, but
+ * rather a {@link RecommenderBuilder} which can build the
+ * {@link org.apache.mahout.cf.taste.recommender.Recommender} to test on top of a given {@link DataModel}.
+ * </p>
+ *
+ * <p>
+ * Implementations will take a certain percentage of the preferences supplied by the given {@link DataModel}
+ * as "training data". This is typically most of the data, like 90%. This data is used to produce
+ * recommendations, and the rest of the data is compared against estimated preference values to see how much
+ * the {@link org.apache.mahout.cf.taste.recommender.Recommender}'s predicted preferences match the user's
+ * real preferences. Specifically, for each user, this percentage of the user's ratings are used to produce
+ * recommendations, and for each user, the remaining preferences are compared against the user's real
+ * preferences.
+ * </p>
+ *
+ * <p>
+ * For large datasets, it may be desirable to only evaluate based on a small percentage of the data.
+ * {@code evaluationPercentage} controls how many of the {@link DataModel}'s users are used in
+ * evaluation.
+ * </p>
+ *
+ * <p>
+ * To be clear, {@code trainingPercentage} and {@code evaluationPercentage} are not related. They
+ * do not need to add up to 1.0, for example.
+ * </p>
+ *
+ * @param recommenderBuilder
+ * object that can build a {@link org.apache.mahout.cf.taste.recommender.Recommender} to test
+ * @param dataModelBuilder
+ * {@link DataModelBuilder} to use, or if null, a default {@link DataModel}
+ * implementation will be used
+ * @param dataModel
+ * dataset to test on
+ * @param trainingPercentage
+ * percentage of each user's preferences to use to produce recommendations; the rest are compared
+ * to estimated preference values to evaluate
+ * {@link org.apache.mahout.cf.taste.recommender.Recommender} performance
+ * @param evaluationPercentage
+ * percentage of users to use in evaluation
+ * @return a "score" representing how well the {@link org.apache.mahout.cf.taste.recommender.Recommender}'s
+ * estimated preferences match real values; <em>lower</em> scores mean a better match and 0 is a
+ * perfect match
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ double evaluate(RecommenderBuilder recommenderBuilder,
+ DataModelBuilder dataModelBuilder,
+ DataModel dataModel,
+ double trainingPercentage,
+ double evaluationPercentage) throws TasteException;
+
+ /**
+ * @deprecated see {@link DataModel#getMaxPreference()}
+ */
+ @Deprecated
+ float getMaxPreference();
+
+ @Deprecated
+ void setMaxPreference(float maxPreference);
+
+ /**
+ * @deprecated see {@link DataModel#getMinPreference()}
+ */
+ @Deprecated
+ float getMinPreference();
+
+ @Deprecated
+ void setMinPreference(float minPreference);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderIRStatsEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderIRStatsEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderIRStatsEvaluator.java
new file mode 100644
index 0000000..6e4e9c7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RecommenderIRStatsEvaluator.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.eval;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+
+/**
+ * <p>
+ * Implementations collect information retrieval-related statistics on a
+ * {@link org.apache.mahout.cf.taste.recommender.Recommender}'s performance, including precision, recall and
+ * f-measure.
+ * </p>
+ *
+ * <p>
+ * See <a href="http://en.wikipedia.org/wiki/Information_retrieval">Information retrieval</a>.
+ */
+public interface RecommenderIRStatsEvaluator {
+
+ /**
+ * @param recommenderBuilder
+ * object that can build a {@link org.apache.mahout.cf.taste.recommender.Recommender} to test
+ * @param dataModelBuilder
+ * {@link DataModelBuilder} to use, or if null, a default {@link DataModel} implementation will be
+ * used
+ * @param dataModel
+ * dataset to test on
+ * @param rescorer
+ * if any, to use when computing recommendations
+ * @param at
+ * as in, "precision at 5". The number of recommendations to consider when evaluating precision,
+ * etc.
+ * @param relevanceThreshold
+ * items whose preference value is at least this value are considered "relevant" for the purposes
+ * of computations
+ * @return {@link IRStatistics} with resulting precision, recall, etc.
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ IRStatistics evaluate(RecommenderBuilder recommenderBuilder,
+ DataModelBuilder dataModelBuilder,
+ DataModel dataModel,
+ IDRescorer rescorer,
+ int at,
+ double relevanceThreshold,
+ double evaluationPercentage) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java
new file mode 100644
index 0000000..da318d5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/eval/RelevantItemsDataSplitter.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.eval;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+/**
+ * Implementations of this interface determine the items that are considered relevant,
+ * and splits data into a training and test subset, for purposes of precision/recall
+ * tests as implemented by implementations of {@link RecommenderIRStatsEvaluator}.
+ */
+public interface RelevantItemsDataSplitter {
+
+ /**
+ * During testing, relevant items are removed from a particular users' preferences,
+ * and a model is build using this user's other preferences and all other users.
+ *
+ * @param at Maximum number of items to be removed
+ * @param relevanceThreshold Minimum strength of preference for an item to be considered
+ * relevant
+ * @return IDs of relevant items
+ */
+ FastIDSet getRelevantItemsIDs(long userID,
+ int at,
+ double relevanceThreshold,
+ DataModel dataModel) throws TasteException;
+
+ /**
+ * Adds a single user and all their preferences to the training model.
+ *
+ * @param userID ID of user whose preferences we are trying to predict
+ * @param relevantItemIDs IDs of items considered relevant to that user
+ * @param trainingUsers the database of training preferences to which we will
+ * append the ones for otherUserID.
+ * @param otherUserID for whom we are adding preferences to the training model
+ */
+ void processOtherUser(long userID,
+ FastIDSet relevantItemIDs,
+ FastByIDMap<PreferenceArray> trainingUsers,
+ long otherUserID,
+ DataModel dataModel) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java
new file mode 100644
index 0000000..e70a675
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityEntityWritable.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import com.google.common.primitives.Longs;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.mahout.math.Varint;
+
+/** A {@link WritableComparable} encapsulating two items. */
+public final class EntityEntityWritable implements WritableComparable<EntityEntityWritable>, Cloneable {
+
+ private long aID;
+ private long bID;
+
+ public EntityEntityWritable() {
+ // do nothing
+ }
+
+ public EntityEntityWritable(long aID, long bID) {
+ this.aID = aID;
+ this.bID = bID;
+ }
+
+ long getAID() {
+ return aID;
+ }
+
+ long getBID() {
+ return bID;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ Varint.writeSignedVarLong(aID, out);
+ Varint.writeSignedVarLong(bID, out);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ aID = Varint.readSignedVarLong(in);
+ bID = Varint.readSignedVarLong(in);
+ }
+
+ @Override
+ public int compareTo(EntityEntityWritable that) {
+ int aCompare = compare(aID, that.getAID());
+ return aCompare == 0 ? compare(bID, that.getBID()) : aCompare;
+ }
+
+ private static int compare(long a, long b) {
+ return a < b ? -1 : a > b ? 1 : 0;
+ }
+
+ @Override
+ public int hashCode() {
+ return Longs.hashCode(aID) + 31 * Longs.hashCode(bID);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (o instanceof EntityEntityWritable) {
+ EntityEntityWritable that = (EntityEntityWritable) o;
+ return aID == that.getAID() && bID == that.getBID();
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return aID + "\t" + bID;
+ }
+
+ @Override
+ public EntityEntityWritable clone() {
+ return new EntityEntityWritable(aID, bID);
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java
new file mode 100644
index 0000000..2aab63c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/EntityPrefWritable.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.VarLongWritable;
+
+/** A {@link org.apache.hadoop.io.Writable} encapsulating an item ID and a preference value. */
+public final class EntityPrefWritable extends VarLongWritable implements Cloneable {
+
+ private float prefValue;
+
+ public EntityPrefWritable() {
+ // do nothing
+ }
+
+ public EntityPrefWritable(long itemID, float prefValue) {
+ super(itemID);
+ this.prefValue = prefValue;
+ }
+
+ public EntityPrefWritable(EntityPrefWritable other) {
+ this(other.get(), other.getPrefValue());
+ }
+
+ public long getID() {
+ return get();
+ }
+
+ public float getPrefValue() {
+ return prefValue;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ super.write(out);
+ out.writeFloat(prefValue);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ super.readFields(in);
+ prefValue = in.readFloat();
+ }
+
+ @Override
+ public int hashCode() {
+ return super.hashCode() ^ RandomUtils.hashFloat(prefValue);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof EntityPrefWritable)) {
+ return false;
+ }
+ EntityPrefWritable other = (EntityPrefWritable) o;
+ return get() == other.get() && prefValue == other.getPrefValue();
+ }
+
+ @Override
+ public String toString() {
+ return get() + "\t" + prefValue;
+ }
+
+ @Override
+ public EntityPrefWritable clone() {
+ return new EntityPrefWritable(get(), prefValue);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java
new file mode 100644
index 0000000..3de272d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/MutableRecommendedItem.java
@@ -0,0 +1,81 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.RandomUtils;
+
+/**
+ * Mutable variant of {@link RecommendedItem}
+ */
+public class MutableRecommendedItem implements RecommendedItem {
+
+ private long itemID;
+ private float value;
+
+ public MutableRecommendedItem() {}
+
+ public MutableRecommendedItem(long itemID, float value) {
+ this.itemID = itemID;
+ this.value = value;
+ }
+
+ @Override
+ public long getItemID() {
+ return itemID;
+ }
+
+ @Override
+ public float getValue() {
+ return value;
+ }
+
+ public void setItemID(long itemID) {
+ this.itemID = itemID;
+ }
+
+ public void set(long itemID, float value) {
+ this.itemID = itemID;
+ this.value = value;
+ }
+
+ public void capToMaxValue(float maxValue) {
+ if (value > maxValue) {
+ value = maxValue;
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "MutableRecommendedItem[item:" + itemID + ", value:" + value + ']';
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) itemID ^ RandomUtils.hashFloat(value);
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof MutableRecommendedItem)) {
+ return false;
+ }
+ RecommendedItem other = (RecommendedItem) o;
+ return itemID == other.getItemID() && value == other.getValue();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java
new file mode 100644
index 0000000..947204d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/RecommendedItemsWritable.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.cf.taste.impl.recommender.GenericRecommendedItem;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.math.Varint;
+
+/**
+ * A {@link Writable} which encapsulates a list of {@link RecommendedItem}s. This is the mapper (and reducer)
+ * output, and represents items recommended to a user. The first item is the one whose estimated preference is
+ * highest.
+ */
+public final class RecommendedItemsWritable implements Writable {
+
+ private List<RecommendedItem> recommended;
+
+ public RecommendedItemsWritable() {
+ // do nothing
+ }
+
+ public RecommendedItemsWritable(List<RecommendedItem> recommended) {
+ this.recommended = recommended;
+ }
+
+ public List<RecommendedItem> getRecommendedItems() {
+ return recommended;
+ }
+
+ public void set(List<RecommendedItem> recommended) {
+ this.recommended = recommended;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(recommended.size());
+ for (RecommendedItem item : recommended) {
+ Varint.writeSignedVarLong(item.getItemID(), out);
+ out.writeFloat(item.getValue());
+ }
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int size = in.readInt();
+ recommended = Lists.newArrayListWithCapacity(size);
+ for (int i = 0; i < size; i++) {
+ long itemID = Varint.readSignedVarLong(in);
+ float value = in.readFloat();
+ RecommendedItem recommendedItem = new GenericRecommendedItem(itemID, value);
+ recommended.add(recommendedItem);
+ }
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder result = new StringBuilder(200);
+ result.append('[');
+ boolean first = true;
+ for (RecommendedItem item : recommended) {
+ if (first) {
+ first = false;
+ } else {
+ result.append(',');
+ }
+ result.append(String.valueOf(item.getItemID()));
+ result.append(':');
+ result.append(String.valueOf(item.getValue()));
+ }
+ result.append(']');
+ return result.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java
new file mode 100644
index 0000000..e3fab29
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtils.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import com.google.common.primitives.Longs;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.VarIntWritable;
+import org.apache.mahout.math.VarLongWritable;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+
+import java.util.regex.Pattern;
+
+/**
+ * Some helper methods for the hadoop-related stuff in org.apache.mahout.cf.taste
+ */
+public final class TasteHadoopUtils {
+
+ public static final int USER_ID_POS = 0;
+ public static final int ITEM_ID_POS = 1;
+
+ /** Standard delimiter of textual preference data */
+ private static final Pattern PREFERENCE_TOKEN_DELIMITER = Pattern.compile("[\t,]");
+
+ private TasteHadoopUtils() {}
+
+ /**
+ * Splits a preference data line into string tokens
+ */
+ public static String[] splitPrefTokens(CharSequence line) {
+ return PREFERENCE_TOKEN_DELIMITER.split(line);
+ }
+
+ /**
+ * Maps a long to an int with range of 0 to Integer.MAX_VALUE-1
+ */
+ public static int idToIndex(long id) {
+ return 0x7FFFFFFF & Longs.hashCode(id) % 0x7FFFFFFE;
+ }
+
+ public static int readID(String token, boolean usesLongIDs) {
+ return usesLongIDs ? idToIndex(Long.parseLong(token)) : Integer.parseInt(token);
+ }
+
+ /**
+ * Reads a binary mapping file
+ */
+ public static OpenIntLongHashMap readIDIndexMap(String idIndexPathStr, Configuration conf) {
+ OpenIntLongHashMap indexIDMap = new OpenIntLongHashMap();
+ Path itemIDIndexPath = new Path(idIndexPathStr);
+ for (Pair<VarIntWritable,VarLongWritable> record
+ : new SequenceFileDirIterable<VarIntWritable,VarLongWritable>(itemIDIndexPath,
+ PathType.LIST,
+ PathFilters.partFilter(),
+ null,
+ true,
+ conf)) {
+ indexIDMap.put(record.getFirst().get(), record.getSecond().get());
+ }
+ return indexIDMap;
+ }
+
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java
new file mode 100644
index 0000000..fdb552e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToEntityPrefsMapper.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.cf.taste.hadoop.item.RecommenderJob;
+import org.apache.mahout.math.VarLongWritable;
+
+import java.io.IOException;
+import java.util.regex.Pattern;
+
+public abstract class ToEntityPrefsMapper extends
+ Mapper<LongWritable,Text, VarLongWritable,VarLongWritable> {
+
+ public static final String TRANSPOSE_USER_ITEM = ToEntityPrefsMapper.class + "transposeUserItem";
+ public static final String RATING_SHIFT = ToEntityPrefsMapper.class + "shiftRatings";
+
+ private static final Pattern DELIMITER = Pattern.compile("[\t,]");
+
+ private boolean booleanData;
+ private boolean transpose;
+ private final boolean itemKey;
+ private float ratingShift;
+
+ ToEntityPrefsMapper(boolean itemKey) {
+ this.itemKey = itemKey;
+ }
+
+ @Override
+ protected void setup(Context context) {
+ Configuration jobConf = context.getConfiguration();
+ booleanData = jobConf.getBoolean(RecommenderJob.BOOLEAN_DATA, false);
+ transpose = jobConf.getBoolean(TRANSPOSE_USER_ITEM, false);
+ ratingShift = Float.parseFloat(jobConf.get(RATING_SHIFT, "0.0"));
+ }
+
+ @Override
+ public void map(LongWritable key,
+ Text value,
+ Context context) throws IOException, InterruptedException {
+ String[] tokens = DELIMITER.split(value.toString());
+ long userID = Long.parseLong(tokens[0]);
+ long itemID = Long.parseLong(tokens[1]);
+ if (itemKey ^ transpose) {
+ // If using items as keys, and not transposing items and users, then users are items!
+ // Or if not using items as keys (users are, as usual), but transposing items and users,
+ // then users are items! Confused?
+ long temp = userID;
+ userID = itemID;
+ itemID = temp;
+ }
+ if (booleanData) {
+ context.write(new VarLongWritable(userID), new VarLongWritable(itemID));
+ } else {
+ float prefValue = tokens.length > 2 ? Float.parseFloat(tokens[2]) + ratingShift : 1.0f;
+ context.write(new VarLongWritable(userID), new EntityPrefWritable(itemID, prefValue));
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java
new file mode 100644
index 0000000..f5f9574
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/ToItemPrefsMapper.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+/**
+ * <h1>Input</h1>
+ *
+ * <p>
+ * Intended for use with {@link org.apache.hadoop.mapreduce.lib.input.TextInputFormat};
+ * accepts line number / line pairs as
+ * {@link org.apache.hadoop.io.LongWritable}/{@link org.apache.hadoop.io.Text} pairs.
+ * </p>
+ *
+ * <p>
+ * Each line is assumed to be of the form {@code userID,itemID,preference}, or {@code userID,itemID}.
+ * </p>
+ *
+ * <h1>Output</h1>
+ *
+ * <p>
+ * Outputs the user ID as a {@link org.apache.mahout.math.VarLongWritable} mapped to the item ID and preference as a
+ * {@link EntityPrefWritable}.
+ * </p>
+ */
+public final class ToItemPrefsMapper extends ToEntityPrefsMapper {
+
+ public ToItemPrefsMapper() {
+ super(false);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java
new file mode 100644
index 0000000..0f9ea75
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueue.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import com.google.common.collect.Lists;
+import org.apache.lucene.util.PriorityQueue;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+import java.util.Collections;
+import java.util.List;
+
+public class TopItemsQueue extends PriorityQueue<MutableRecommendedItem> {
+
+ private static final long SENTINEL_ID = Long.MIN_VALUE;
+
+ private final int maxSize;
+
+ public TopItemsQueue(int maxSize) {
+ super(maxSize);
+ this.maxSize = maxSize;
+ }
+
+ public List<RecommendedItem> getTopItems() {
+ List<RecommendedItem> recommendedItems = Lists.newArrayListWithCapacity(maxSize);
+ while (size() > 0) {
+ MutableRecommendedItem topItem = pop();
+ // filter out "sentinel" objects necessary for maintaining an efficient priority queue
+ if (topItem.getItemID() != SENTINEL_ID) {
+ recommendedItems.add(topItem);
+ }
+ }
+ Collections.reverse(recommendedItems);
+ return recommendedItems;
+ }
+
+ @Override
+ protected boolean lessThan(MutableRecommendedItem one, MutableRecommendedItem two) {
+ return one.getValue() < two.getValue();
+ }
+
+ @Override
+ protected MutableRecommendedItem getSentinelObject() {
+ return new MutableRecommendedItem(SENTINEL_ID, Float.MIN_VALUE);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
new file mode 100644
index 0000000..c5ccf38
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/ALS.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalFileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.als.AlternatingLeastSquaresSolver;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.List;
+
+final class ALS {
+
+ private ALS() {}
+
+ static Vector readFirstRow(Path dir, Configuration conf) throws IOException {
+ Iterator<VectorWritable> iterator = new SequenceFileDirValueIterator<>(dir, PathType.LIST,
+ PathFilters.partFilter(), null, true, conf);
+ return iterator.hasNext() ? iterator.next().get() : null;
+ }
+
+ public static OpenIntObjectHashMap<Vector> readMatrixByRowsFromDistributedCache(int numEntities,
+ Configuration conf) throws IOException {
+
+ IntWritable rowIndex = new IntWritable();
+ VectorWritable row = new VectorWritable();
+
+
+ OpenIntObjectHashMap<Vector> featureMatrix = numEntities > 0
+ ? new OpenIntObjectHashMap<Vector>(numEntities) : new OpenIntObjectHashMap<Vector>();
+
+ Path[] cachedFiles = HadoopUtil.getCachedFiles(conf);
+ LocalFileSystem localFs = FileSystem.getLocal(conf);
+
+ for (Path cachedFile : cachedFiles) {
+
+ SequenceFile.Reader reader = null;
+ try {
+ reader = new SequenceFile.Reader(localFs, cachedFile, conf);
+ while (reader.next(rowIndex, row)) {
+ featureMatrix.put(rowIndex.get(), row.get());
+ }
+ } finally {
+ Closeables.close(reader, true);
+ }
+ }
+
+ Preconditions.checkState(!featureMatrix.isEmpty(), "Feature matrix is empty");
+ return featureMatrix;
+ }
+
+ public static OpenIntObjectHashMap<Vector> readMatrixByRows(Path dir, Configuration conf) {
+ OpenIntObjectHashMap<Vector> matrix = new OpenIntObjectHashMap<>();
+ for (Pair<IntWritable,VectorWritable> pair
+ : new SequenceFileDirIterable<IntWritable,VectorWritable>(dir, PathType.LIST, PathFilters.partFilter(), conf)) {
+ int rowIndex = pair.getFirst().get();
+ Vector row = pair.getSecond().get();
+ matrix.put(rowIndex, row);
+ }
+ return matrix;
+ }
+
+ public static Vector solveExplicit(VectorWritable ratingsWritable, OpenIntObjectHashMap<Vector> uOrM,
+ double lambda, int numFeatures) {
+ Vector ratings = ratingsWritable.get();
+
+ List<Vector> featureVectors = Lists.newArrayListWithCapacity(ratings.getNumNondefaultElements());
+ for (Vector.Element e : ratings.nonZeroes()) {
+ int index = e.index();
+ featureVectors.add(uOrM.get(index));
+ }
+
+ return AlternatingLeastSquaresSolver.solve(featureVectors, ratings, lambda, numFeatures);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java
new file mode 100644
index 0000000..b061a63
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/DatasetSplitter.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * <p>Split a recommendation dataset into a training and a test set</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--input (path): Directory containing one or more text files with the dataset</li>
+ * <li>--output (path): path where output should go</li>
+ * <li>--trainingPercentage (double): percentage of the data to use as training set (optional, default 0.9)</li>
+ * <li>--probePercentage (double): percentage of the data to use as probe set (optional, default 0.1)</li>
+ * </ol>
+ */
+public class DatasetSplitter extends AbstractJob {
+
+ private static final String TRAINING_PERCENTAGE = DatasetSplitter.class.getName() + ".trainingPercentage";
+ private static final String PROBE_PERCENTAGE = DatasetSplitter.class.getName() + ".probePercentage";
+ private static final String PART_TO_USE = DatasetSplitter.class.getName() + ".partToUse";
+
+ private static final Text INTO_TRAINING_SET = new Text("T");
+ private static final Text INTO_PROBE_SET = new Text("P");
+
+ private static final double DEFAULT_TRAINING_PERCENTAGE = 0.9;
+ private static final double DEFAULT_PROBE_PERCENTAGE = 0.1;
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new DatasetSplitter(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("trainingPercentage", "t", "percentage of the data to use as training set (default: "
+ + DEFAULT_TRAINING_PERCENTAGE + ')', String.valueOf(DEFAULT_TRAINING_PERCENTAGE));
+ addOption("probePercentage", "p", "percentage of the data to use as probe set (default: "
+ + DEFAULT_PROBE_PERCENTAGE + ')', String.valueOf(DEFAULT_PROBE_PERCENTAGE));
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ double trainingPercentage = Double.parseDouble(getOption("trainingPercentage"));
+ double probePercentage = Double.parseDouble(getOption("probePercentage"));
+ String tempDir = getOption("tempDir");
+
+ Path markedPrefs = new Path(tempDir, "markedPreferences");
+ Path trainingSetPath = new Path(getOutputPath(), "trainingSet");
+ Path probeSetPath = new Path(getOutputPath(), "probeSet");
+
+ Job markPreferences = prepareJob(getInputPath(), markedPrefs, TextInputFormat.class, MarkPreferencesMapper.class,
+ Text.class, Text.class, SequenceFileOutputFormat.class);
+ markPreferences.getConfiguration().set(TRAINING_PERCENTAGE, String.valueOf(trainingPercentage));
+ markPreferences.getConfiguration().set(PROBE_PERCENTAGE, String.valueOf(probePercentage));
+ boolean succeeded = markPreferences.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ Job createTrainingSet = prepareJob(markedPrefs, trainingSetPath, SequenceFileInputFormat.class,
+ WritePrefsMapper.class, NullWritable.class, Text.class, TextOutputFormat.class);
+ createTrainingSet.getConfiguration().set(PART_TO_USE, INTO_TRAINING_SET.toString());
+ succeeded = createTrainingSet.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ Job createProbeSet = prepareJob(markedPrefs, probeSetPath, SequenceFileInputFormat.class,
+ WritePrefsMapper.class, NullWritable.class, Text.class, TextOutputFormat.class);
+ createProbeSet.getConfiguration().set(PART_TO_USE, INTO_PROBE_SET.toString());
+ succeeded = createProbeSet.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ return 0;
+ }
+
+ static class MarkPreferencesMapper extends Mapper<LongWritable,Text,Text,Text> {
+
+ private Random random;
+ private double trainingBound;
+ private double probeBound;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ random = RandomUtils.getRandom();
+ trainingBound = Double.parseDouble(ctx.getConfiguration().get(TRAINING_PERCENTAGE));
+ probeBound = trainingBound + Double.parseDouble(ctx.getConfiguration().get(PROBE_PERCENTAGE));
+ }
+
+ @Override
+ protected void map(LongWritable key, Text text, Context ctx) throws IOException, InterruptedException {
+ double randomValue = random.nextDouble();
+ if (randomValue <= trainingBound) {
+ ctx.write(INTO_TRAINING_SET, text);
+ } else if (randomValue <= probeBound) {
+ ctx.write(INTO_PROBE_SET, text);
+ }
+ }
+ }
+
+ static class WritePrefsMapper extends Mapper<Text,Text,NullWritable,Text> {
+
+ private String partToUse;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ partToUse = ctx.getConfiguration().get(PART_TO_USE);
+ }
+
+ @Override
+ protected void map(Text key, Text text, Context ctx) throws IOException, InterruptedException {
+ if (partToUse.equals(key.toString())) {
+ ctx.write(NullWritable.get(), text);
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java
new file mode 100644
index 0000000..3048b77
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/FactorizationEvaluator.java
@@ -0,0 +1,172 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.TextInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+
+/**
+ * <p>Measures the root-mean-squared error of a rating matrix factorization against a test set.</p>
+ *
+ * <p>Command line arguments specific to this class are:</p>
+ *
+ * <ol>
+ * <li>--output (path): path where output should go</li>
+ * <li>--pairs (path): path containing the test ratings, each line must be userID,itemID,rating</li>
+ * <li>--userFeatures (path): path to the user feature matrix</li>
+ * <li>--itemFeatures (path): path to the item feature matrix</li>
+ * </ol>
+ */
+public class FactorizationEvaluator extends AbstractJob {
+
+ private static final String USER_FEATURES_PATH = RecommenderJob.class.getName() + ".userFeatures";
+ private static final String ITEM_FEATURES_PATH = RecommenderJob.class.getName() + ".itemFeatures";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new FactorizationEvaluator(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOption("userFeatures", null, "path to the user feature matrix", true);
+ addOption("itemFeatures", null, "path to the item feature matrix", true);
+ addOption("usesLongIDs", null, "input contains long IDs that need to be translated");
+ addOutputOption();
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ Path errors = getTempPath("errors");
+
+ Job predictRatings = prepareJob(getInputPath(), errors, TextInputFormat.class, PredictRatingsMapper.class,
+ DoubleWritable.class, NullWritable.class, SequenceFileOutputFormat.class);
+
+ Configuration conf = predictRatings.getConfiguration();
+ conf.set(USER_FEATURES_PATH, getOption("userFeatures"));
+ conf.set(ITEM_FEATURES_PATH, getOption("itemFeatures"));
+
+ boolean usesLongIDs = Boolean.parseBoolean(getOption("usesLongIDs"));
+ if (usesLongIDs) {
+ conf.set(ParallelALSFactorizationJob.USES_LONG_IDS, String.valueOf(true));
+ }
+
+
+ boolean succeeded = predictRatings.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+
+ BufferedWriter writer = null;
+ try {
+ FileSystem fs = FileSystem.get(getOutputPath().toUri(), getConf());
+ FSDataOutputStream outputStream = fs.create(getOutputPath("rmse.txt"));
+ double rmse = computeRmse(errors);
+ writer = new BufferedWriter(new OutputStreamWriter(outputStream, Charsets.UTF_8));
+ writer.write(String.valueOf(rmse));
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ return 0;
+ }
+
+ double computeRmse(Path errors) {
+ RunningAverage average = new FullRunningAverage();
+ for (Pair<DoubleWritable,NullWritable> entry
+ : new SequenceFileDirIterable<DoubleWritable, NullWritable>(errors, PathType.LIST, PathFilters.logsCRCFilter(),
+ getConf())) {
+ DoubleWritable error = entry.getFirst();
+ average.addDatum(error.get() * error.get());
+ }
+
+ return Math.sqrt(average.getAverage());
+ }
+
+ public static class PredictRatingsMapper extends Mapper<LongWritable,Text,DoubleWritable,NullWritable> {
+
+ private OpenIntObjectHashMap<Vector> U;
+ private OpenIntObjectHashMap<Vector> M;
+
+ private boolean usesLongIDs;
+
+ private final DoubleWritable error = new DoubleWritable();
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ Configuration conf = ctx.getConfiguration();
+
+ Path pathToU = new Path(conf.get(USER_FEATURES_PATH));
+ Path pathToM = new Path(conf.get(ITEM_FEATURES_PATH));
+
+ U = ALS.readMatrixByRows(pathToU, conf);
+ M = ALS.readMatrixByRows(pathToM, conf);
+
+ usesLongIDs = conf.getBoolean(ParallelALSFactorizationJob.USES_LONG_IDS, false);
+ }
+
+ @Override
+ protected void map(LongWritable key, Text value, Context ctx) throws IOException, InterruptedException {
+
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(value.toString());
+
+ int userID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.USER_ID_POS], usesLongIDs);
+ int itemID = TasteHadoopUtils.readID(tokens[TasteHadoopUtils.ITEM_ID_POS], usesLongIDs);
+ double rating = Double.parseDouble(tokens[2]);
+
+ if (U.containsKey(userID) && M.containsKey(itemID)) {
+ double estimate = U.get(userID).dot(M.get(itemID));
+ error.set(rating - estimate);
+ ctx.write(error, NullWritable.get());
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java
new file mode 100644
index 0000000..d93e3a4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/hadoop/als/MultithreadedSharingMapper.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.JobContext;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.map.MultithreadedMapper;
+import org.apache.hadoop.util.ReflectionUtils;
+
+import java.io.IOException;
+
+/**
+ * Multithreaded Mapper for {@link SharingMapper}s. Will call setupSharedInstance() once in the controlling thread
+ * before executing the mappers using a thread pool.
+ *
+ * @param <K1>
+ * @param <V1>
+ * @param <K2>
+ * @param <V2>
+ */
+public class MultithreadedSharingMapper<K1, V1, K2, V2> extends MultithreadedMapper<K1, V1, K2, V2> {
+
+ @Override
+ public void run(Context ctx) throws IOException, InterruptedException {
+ Class<Mapper<K1, V1, K2, V2>> mapperClass =
+ MultithreadedSharingMapper.getMapperClass((JobContext) ctx);
+ Preconditions.checkNotNull(mapperClass, "Could not find Multithreaded Mapper class.");
+
+ Configuration conf = ctx.getConfiguration();
+ // instantiate the mapper
+ Mapper<K1, V1, K2, V2> mapper1 = ReflectionUtils.newInstance(mapperClass, conf);
+ SharingMapper<K1, V1, K2, V2, ?> mapper = null;
+ if (mapper1 instanceof SharingMapper) {
+ mapper = (SharingMapper<K1, V1, K2, V2, ?>) mapper1;
+ }
+ Preconditions.checkNotNull(mapper, "Could not instantiate SharingMapper. Class was: %s",
+ mapper1.getClass().getName());
+
+ // single threaded call to setup the sharing mapper
+ mapper.setupSharedInstance(ctx);
+
+ // multithreaded execution
+ super.run(ctx);
+ }
+}
[50/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/CHANGELOG
----------------------------------------------------------------------
diff --git a/CHANGELOG b/CHANGELOG
index 2a82c5a..054e96b 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -2,6 +2,8 @@ Mahout Change Log
Release 0.10.0 - unreleased
+ MAHOUT-1655: Refactors mr-legacy into mahout-hdfs and mahout-mr, Spark now depends on much reduced mahout-hdfs
+
MAHOUT-1522: Handle logging levels via log4j.xml (akm)
MAHOUT-1602: Euclidean Distance Similarity Math (Leonardo Fernandez Sanchez, smarthi)
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/bin/mahout
----------------------------------------------------------------------
diff --git a/bin/mahout b/bin/mahout
index c51c239..772c184 100755
--- a/bin/mahout
+++ b/bin/mahout
@@ -182,7 +182,7 @@ then
done
if [ "$H2O" == "1" ]; then
- for f in $MAHOUT_HOME/mrlegacy/target/mahout-mrlegacy-*.jar; do
+ for f in $MAHOUT_HOME/hdfs/target/mahout-hdfs-*.jar; do
CLASSPATH=${CLASSPATH}:$f;
done
@@ -194,7 +194,7 @@ then
# add jars for running from the command line if we requested shell or spark CLI driver
if [ "$SPARK" == "1" ]; then
- for f in $MAHOUT_HOME/mrlegacy/target/mahout-mrlegacy-*.jar ; do
+ for f in $MAHOUT_HOME/hdfs/target/mahout-hdfs-*.jar ; do
CLASSPATH=${CLASSPATH}:$f;
done
@@ -227,11 +227,11 @@ then
done
else
CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math/target/classes
- CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/mrlegacy/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/hdfs/target/classes
+ CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/mr/target/classes
CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/integration/target/classes
CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/examples/target/classes
CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/math-scala/target/classes
- #CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/mrlegacy/src/main/resources
CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark/target/classes
CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/spark-shell/target/classes
CLASSPATH=${CLASSPATH}:$MAHOUT_HOME/h2o/target/classes
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/distribution/pom.xml
----------------------------------------------------------------------
diff --git a/distribution/pom.xml b/distribution/pom.xml
index b2e5071..f3e6336 100644
--- a/distribution/pom.xml
+++ b/distribution/pom.xml
@@ -91,7 +91,11 @@
</dependency>
<dependency>
<groupId>org.apache.mahout</groupId>
- <artifactId>mahout-mrlegacy</artifactId>
+ <artifactId>mahout-hdfs</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.mahout</groupId>
+ <artifactId>mahout-mr</artifactId>
</dependency>
<dependency>
<groupId>org.apache.mahout</groupId>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/distribution/src/main/assembly/bin.xml
----------------------------------------------------------------------
diff --git a/distribution/src/main/assembly/bin.xml b/distribution/src/main/assembly/bin.xml
index a9b933e..d6c9076 100644
--- a/distribution/src/main/assembly/bin.xml
+++ b/distribution/src/main/assembly/bin.xml
@@ -65,7 +65,20 @@
<outputDirectory/>
</fileSet>
<fileSet>
- <directory>${project.basedir}/../mrlegacy/target</directory>
+ <directory>${project.basedir}/../hdfs/target</directory>
+ <includes>
+ <include>mahout-*.job</include>
+ <include>mahout-*.jar</include>
+ </includes>
+ <excludes>
+ <exclude>*sources.jar</exclude>
+ <exclude>*javadoc.jar</exclude>
+ <exclude>*tests.jar</exclude>
+ </excludes>
+ <outputDirectory/>
+ </fileSet>
+ <fileSet>
+ <directory>${project.basedir}/../mr/target</directory>
<includes>
<include>mahout-*.job</include>
<include>mahout-*.jar</include>
@@ -112,8 +125,12 @@
<outputDirectory>docs/mahout-math</outputDirectory>
</fileSet>
<fileSet>
- <directory>${project.basedir}/../mrlegacy/target/apidocs</directory>
- <outputDirectory>docs/mahout-mrlegacy</outputDirectory>
+ <directory>${project.basedir}/../hdfs/target/apidocs</directory>
+ <outputDirectory>docs/mahout-hdfs</outputDirectory>
+ </fileSet>
+ <fileSet>
+ <directory>${project.basedir}/../mr/target/apidocs</directory>
+ <outputDirectory>docs/mahout-mr</outputDirectory>
</fileSet>
<fileSet>
<directory>${project.basedir}/../integration/target/apidocs</directory>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/examples/pom.xml
----------------------------------------------------------------------
diff --git a/examples/pom.xml b/examples/pom.xml
index dbf4bd5..b710388 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -108,11 +108,21 @@
<!-- our modules -->
<dependency>
<groupId>${project.groupId}</groupId>
- <artifactId>mahout-mrlegacy</artifactId>
+ <artifactId>mahout-hdfs</artifactId>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
- <artifactId>mahout-mrlegacy</artifactId>
+ <artifactId>mahout-mr</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-hdfs</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-mr</artifactId>
<type>test-jar</type>
<scope>test</scope>
</dependency>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java b/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java
index 9ce2104..411c68f 100644
--- a/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java
+++ b/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java
@@ -227,7 +227,7 @@ public class TestForest extends Configured implements Tool {
Random rng = RandomUtils.getRandom();
List<double[]> resList = Lists.newArrayList();
- if (dataFS.getFileStatus(dataPath).isDirectory()) {
+ if (dataFS.getFileStatus(dataPath).isDir()) {
//the input is a directory of files
testDirectory(outputPath, converter, forest, dataset, resList, rng);
} else {
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/h2o/pom.xml
----------------------------------------------------------------------
diff --git a/h2o/pom.xml b/h2o/pom.xml
index 9dc4e62..92beeca 100644
--- a/h2o/pom.xml
+++ b/h2o/pom.xml
@@ -132,10 +132,10 @@
<version>${project.version}</version>
</dependency>
+ <!-- for MatrixWritable and VectorWritable -->
<dependency>
- <!-- for MatrixWritable and VectorWritable -->
<groupId>org.apache.mahout</groupId>
- <artifactId>mahout-mrlegacy</artifactId>
+ <artifactId>mahout-hdfs</artifactId>
<version>${project.version}</version>
</dependency>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/pom.xml
----------------------------------------------------------------------
diff --git a/hdfs/pom.xml b/hdfs/pom.xml
new file mode 100644
index 0000000..7e77162
--- /dev/null
+++ b/hdfs/pom.xml
@@ -0,0 +1,216 @@
+<?xml version="1.0" encoding="UTF-8"?>
+
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <parent>
+ <groupId>org.apache.mahout</groupId>
+ <artifactId>mahout</artifactId>
+ <version>1.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <!-- modules inherit parent's group id and version. -->
+ <artifactId>mahout-hdfs</artifactId>
+ <name>Mahout HDFS</name>
+ <description>Scalable machine learning libraries</description>
+
+ <packaging>jar</packaging>
+
+ <build>
+ <resources>
+ <resource>
+ <directory>src/main/resources</directory>
+ </resource>
+ <resource>
+ <directory>../src/conf</directory>
+ <includes>
+ <include>driver.classes.default.props</include>
+ </includes>
+ </resource>
+ </resources>
+ <plugins>
+ <!-- create test jar so other modules can reuse the core test utility classes. -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <executions>
+ <execution>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+
+ <plugin>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ </plugin>
+
+ <plugin>
+ <artifactId>maven-source-plugin</artifactId>
+ </plugin>
+
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-remote-resources-plugin</artifactId>
+ <configuration>
+ <appendedResourcesDirectory>../src/main/appended-resources</appendedResourcesDirectory>
+ <resourceBundles>
+ <resourceBundle>org.apache:apache-jar-resource-bundle:1.4</resourceBundle>
+ </resourceBundles>
+ <supplementalModels>
+ <supplementalModel>supplemental-models.xml</supplementalModel>
+ </supplementalModels>
+ </configuration>
+ </plugin>
+
+ </plugins>
+ </build>
+
+ <dependencies>
+
+ <!-- our modules -->
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-math</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-math</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <!-- Third Party -->
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.codehaus.jackson</groupId>
+ <artifactId>jackson-core-asl</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.codehaus.jackson</groupId>
+ <artifactId>jackson-mapper-asl</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-jcl</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-lang3</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>commons-cli</groupId>
+ <artifactId>commons-cli</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.thoughtworks.xstream</groupId>
+ <artifactId>xstream</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-core</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-analyzers-common</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.mahout.commons</groupId>
+ <artifactId>commons-cli</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math3</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>com.carrotsearch.randomizedtesting</groupId>
+ <artifactId>randomizedtesting-runner</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.mrunit</groupId>
+ <artifactId>mrunit</artifactId>
+ <version>1.0.0</version>
+ <classifier>${hadoop.classifier}</classifier>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>commons-httpclient</groupId>
+ <artifactId>commons-httpclient</artifactId>
+ <version>3.0.1</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.solr</groupId>
+ <artifactId>solr-commons-csv</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+
+ </dependencies>
+
+</project>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/main/java/org/apache/mahout/common/IOUtils.java
----------------------------------------------------------------------
diff --git a/hdfs/src/main/java/org/apache/mahout/common/IOUtils.java b/hdfs/src/main/java/org/apache/mahout/common/IOUtils.java
new file mode 100644
index 0000000..0372ed4
--- /dev/null
+++ b/hdfs/src/main/java/org/apache/mahout/common/IOUtils.java
@@ -0,0 +1,194 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.Collection;
+
+import org.apache.hadoop.mapred.lib.MultipleOutputs;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>
+ * I/O-related utility methods that don't have a better home.
+ * </p>
+ */
+public final class IOUtils {
+
+ private static final Logger log = LoggerFactory.getLogger(IOUtils.class);
+
+ private IOUtils() { }
+
+ // Sheez, why can't ResultSet, Statement and Connection implement Closeable?
+
+ public static void quietClose(ResultSet closeable) {
+ if (closeable != null) {
+ try {
+ closeable.close();
+ } catch (SQLException sqle) {
+ log.warn("Unexpected exception while closing; continuing", sqle);
+ }
+ }
+ }
+
+ public static void quietClose(Statement closeable) {
+ if (closeable != null) {
+ try {
+ closeable.close();
+ } catch (SQLException sqle) {
+ log.warn("Unexpected exception while closing; continuing", sqle);
+ }
+ }
+ }
+
+ public static void quietClose(Connection closeable) {
+ if (closeable != null) {
+ try {
+ closeable.close();
+ } catch (SQLException sqle) {
+ log.warn("Unexpected exception while closing; continuing", sqle);
+ }
+ }
+ }
+
+ /**
+ * Closes a {@link ResultSet}, {@link Statement} and {@link Connection} (if not null) and logs (but does not
+ * rethrow) any resulting {@link SQLException}. This is useful for cleaning up after a database query.
+ *
+ * @param resultSet
+ * {@link ResultSet} to close
+ * @param statement
+ * {@link Statement} to close
+ * @param connection
+ * {@link Connection} to close
+ */
+ public static void quietClose(ResultSet resultSet, Statement statement, Connection connection) {
+ quietClose(resultSet);
+ quietClose(statement);
+ quietClose(connection);
+ }
+
+ /**
+ * make sure to close all sources, log all of the problems occurred, clear
+ * {@code closeables} (to prevent repeating close attempts), re-throw the
+ * last one at the end. Helps resource scope management (e.g. compositions of
+ * {@link Closeable}s objects)
+ * <P>
+ * <p/>
+ * Typical pattern:
+ * <p/>
+ *
+ * <pre>
+ * LinkedList<Closeable> closeables = new LinkedList<Closeable>();
+ * try {
+ * InputStream stream1 = new FileInputStream(...);
+ * closeables.addFirst(stream1);
+ * ...
+ * InputStream streamN = new FileInputStream(...);
+ * closeables.addFirst(streamN);
+ * ...
+ * } finally {
+ * IOUtils.close(closeables);
+ * }
+ * </pre>
+ *
+ * @param closeables
+ * must be a modifiable collection of {@link Closeable}s
+ * @throws IOException
+ * the last exception (if any) of all closed resources
+ */
+ public static void close(Collection<? extends Closeable> closeables)
+ throws IOException {
+ Throwable lastThr = null;
+
+ for (Closeable closeable : closeables) {
+ try {
+ closeable.close();
+ } catch (Throwable thr) {
+ log.error(thr.getMessage(), thr);
+ lastThr = thr;
+ }
+ }
+
+ // make sure we don't double-close
+ // but that has to be modifiable collection
+ closeables.clear();
+
+ if (lastThr != null) {
+ if (lastThr instanceof IOException) {
+ throw (IOException) lastThr;
+ } else if (lastThr instanceof RuntimeException) {
+ throw (RuntimeException) lastThr;
+ } else {
+ throw (Error) lastThr;
+ }
+ }
+
+ }
+
+
+ /**
+ * for temporary files, a file may be considered as a {@link Closeable} too,
+ * where file is wiped on close and thus the disk resource is released
+ * ('closed').
+ *
+ *
+ */
+ public static class DeleteFileOnClose implements Closeable {
+
+ private final File file;
+
+ public DeleteFileOnClose(File file) {
+ this.file = file;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (file.isFile()) {
+ file.delete();
+ }
+ }
+ }
+
+ /**
+ * MultipleOutputs to closeable adapter.
+ *
+ */
+ public static class MultipleOutputsCloseableAdapter implements Closeable {
+ private final MultipleOutputs mo;
+
+ public MultipleOutputsCloseableAdapter(MultipleOutputs mo) {
+ this.mo = mo;
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (mo != null) {
+ mo.close();
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java
----------------------------------------------------------------------
diff --git a/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java b/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java
new file mode 100644
index 0000000..c521f3e
--- /dev/null
+++ b/hdfs/src/main/java/org/apache/mahout/math/MatrixWritable.java
@@ -0,0 +1,202 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Maps;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.list.IntArrayList;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Map;
+
+public class MatrixWritable implements Writable {
+
+ private static final int FLAG_DENSE = 0x01;
+ private static final int FLAG_SEQUENTIAL = 0x02;
+ private static final int FLAG_LABELS = 0x04;
+ private static final int FLAG_SPARSE_ROW = 0x08;
+ private static final int NUM_FLAGS = 4;
+
+ private Matrix matrix;
+
+ public MatrixWritable() {}
+
+ public MatrixWritable(Matrix m) {
+ this.matrix = m;
+ }
+
+ public Matrix get() {
+ return matrix;
+ }
+
+ public void set(Matrix matrix) {
+ this.matrix = matrix;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ writeMatrix(out, matrix);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ matrix = readMatrix(in);
+ }
+
+ public static void readLabels(DataInput in,
+ Map<String, Integer> columnLabelBindings,
+ Map<String, Integer> rowLabelBindings) throws IOException {
+ int colSize = in.readInt();
+ if (colSize > 0) {
+ for (int i = 0; i < colSize; i++) {
+ columnLabelBindings.put(in.readUTF(), in.readInt());
+ }
+ }
+ int rowSize = in.readInt();
+ if (rowSize > 0) {
+ for (int i = 0; i < rowSize; i++) {
+ rowLabelBindings.put(in.readUTF(), in.readInt());
+ }
+ }
+ }
+
+ public static void writeLabelBindings(DataOutput out,
+ Map<String, Integer> columnLabelBindings,
+ Map<String, Integer> rowLabelBindings) throws IOException {
+ if (columnLabelBindings == null) {
+ out.writeInt(0);
+ } else {
+ out.writeInt(columnLabelBindings.size());
+ for (Map.Entry<String, Integer> stringIntegerEntry : columnLabelBindings.entrySet()) {
+ out.writeUTF(stringIntegerEntry.getKey());
+ out.writeInt(stringIntegerEntry.getValue());
+ }
+ }
+ if (rowLabelBindings == null) {
+ out.writeInt(0);
+ } else {
+ out.writeInt(rowLabelBindings.size());
+ for (Map.Entry<String, Integer> stringIntegerEntry : rowLabelBindings.entrySet()) {
+ out.writeUTF(stringIntegerEntry.getKey());
+ out.writeInt(stringIntegerEntry.getValue());
+ }
+ }
+ }
+
+ /** Reads a typed Matrix instance from the input stream */
+ public static Matrix readMatrix(DataInput in) throws IOException {
+ int flags = in.readInt();
+ Preconditions.checkArgument(flags >> NUM_FLAGS == 0, "Unknown flags set: %d", Integer.toString(flags, 2));
+ boolean dense = (flags & FLAG_DENSE) != 0;
+ boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
+ boolean hasLabels = (flags & FLAG_LABELS) != 0;
+ boolean isSparseRowMatrix = (flags & FLAG_SPARSE_ROW) != 0;
+
+ int rows = in.readInt();
+ int columns = in.readInt();
+
+ byte vectorFlags = in.readByte();
+
+ Matrix matrix;
+
+ if (dense) {
+ matrix = new DenseMatrix(rows, columns);
+ for (int row = 0; row < rows; row++) {
+ matrix.assignRow(row, VectorWritable.readVector(in, vectorFlags, columns));
+ }
+ } else if (isSparseRowMatrix) {
+ Vector[] rowVectors = new Vector[rows];
+ for (int row = 0; row < rows; row++) {
+ rowVectors[row] = VectorWritable.readVector(in, vectorFlags, columns);
+ }
+ matrix = new SparseRowMatrix(rows, columns, rowVectors, true, !sequential);
+ } else {
+ matrix = new SparseMatrix(rows, columns);
+ int numNonZeroRows = in.readInt();
+ int rowsRead = 0;
+ while (rowsRead++ < numNonZeroRows) {
+ int rowIndex = in.readInt();
+ matrix.assignRow(rowIndex, VectorWritable.readVector(in, vectorFlags, columns));
+ }
+ }
+
+ if (hasLabels) {
+ Map<String,Integer> columnLabelBindings = Maps.newHashMap();
+ Map<String,Integer> rowLabelBindings = Maps.newHashMap();
+ readLabels(in, columnLabelBindings, rowLabelBindings);
+ if (!columnLabelBindings.isEmpty()) {
+ matrix.setColumnLabelBindings(columnLabelBindings);
+ }
+ if (!rowLabelBindings.isEmpty()) {
+ matrix.setRowLabelBindings(rowLabelBindings);
+ }
+ }
+
+ return matrix;
+ }
+
+ /** Writes a typed Matrix instance to the output stream */
+ public static void writeMatrix(final DataOutput out, Matrix matrix) throws IOException {
+ int flags = 0;
+ Vector row = matrix.viewRow(0);
+ boolean isDense = row.isDense();
+ if (isDense) {
+ flags |= FLAG_DENSE;
+ }
+ if (row.isSequentialAccess()) {
+ flags |= FLAG_SEQUENTIAL;
+ }
+ if (matrix.getRowLabelBindings() != null || matrix.getColumnLabelBindings() != null) {
+ flags |= FLAG_LABELS;
+ }
+ boolean isSparseRowMatrix = matrix instanceof SparseRowMatrix;
+ if (isSparseRowMatrix) {
+ flags |= FLAG_SPARSE_ROW;
+ }
+
+ out.writeInt(flags);
+ out.writeInt(matrix.rowSize());
+ out.writeInt(matrix.columnSize());
+
+ // We only use vectors of the same type, so we write out the type information only once!
+ byte vectorFlags = VectorWritable.flags(matrix.viewRow(0), false);
+ out.writeByte(vectorFlags);
+
+ if (isDense || isSparseRowMatrix) {
+ for (int i = 0; i < matrix.rowSize(); i++) {
+ VectorWritable.writeVectorContents(out, matrix.viewRow(i), vectorFlags);
+ }
+ } else {
+ IntArrayList rowIndices = ((SparseMatrix) matrix).nonZeroRowIndices();
+ int numNonZeroRows = rowIndices.size();
+ out.writeInt(numNonZeroRows);
+ for (int i = 0; i < numNonZeroRows; i++) {
+ int rowIndex = rowIndices.getQuick(i);
+ out.writeInt(rowIndex);
+ VectorWritable.writeVectorContents(out, matrix.viewRow(rowIndex), vectorFlags);
+ }
+ }
+
+ if ((flags & FLAG_LABELS) != 0) {
+ writeLabelBindings(out, matrix.getColumnLabelBindings(), matrix.getRowLabelBindings());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/main/java/org/apache/mahout/math/VarIntWritable.java
----------------------------------------------------------------------
diff --git a/hdfs/src/main/java/org/apache/mahout/math/VarIntWritable.java b/hdfs/src/main/java/org/apache/mahout/math/VarIntWritable.java
new file mode 100644
index 0000000..e5cb173
--- /dev/null
+++ b/hdfs/src/main/java/org/apache/mahout/math/VarIntWritable.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.WritableComparable;
+
+public class VarIntWritable implements WritableComparable<VarIntWritable>, Cloneable {
+
+ private int value;
+
+ public VarIntWritable() {
+ }
+
+ public VarIntWritable(int value) {
+ this.value = value;
+ }
+
+ public int get() {
+ return value;
+ }
+
+ public void set(int value) {
+ this.value = value;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ return other instanceof VarIntWritable && ((VarIntWritable) other).value == value;
+ }
+
+ @Override
+ public int hashCode() {
+ return value;
+ }
+
+ @Override
+ public String toString() {
+ return String.valueOf(value);
+ }
+
+ @Override
+ public VarIntWritable clone() {
+ return new VarIntWritable(value);
+ }
+
+ @Override
+ public int compareTo(VarIntWritable other) {
+ if (value < other.value) {
+ return -1;
+ }
+ if (value > other.value) {
+ return 1;
+ }
+ return 0;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ Varint.writeSignedVarInt(value, out);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ value = Varint.readSignedVarInt(in);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/main/java/org/apache/mahout/math/VarLongWritable.java
----------------------------------------------------------------------
diff --git a/hdfs/src/main/java/org/apache/mahout/math/VarLongWritable.java b/hdfs/src/main/java/org/apache/mahout/math/VarLongWritable.java
new file mode 100644
index 0000000..7b0d9c4
--- /dev/null
+++ b/hdfs/src/main/java/org/apache/mahout/math/VarLongWritable.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import com.google.common.primitives.Longs;
+import org.apache.hadoop.io.WritableComparable;
+
+public class VarLongWritable implements WritableComparable<VarLongWritable> {
+
+ private long value;
+
+ public VarLongWritable() {
+ }
+
+ public VarLongWritable(long value) {
+ this.value = value;
+ }
+
+ public long get() {
+ return value;
+ }
+
+ public void set(long value) {
+ this.value = value;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ return other != null && getClass().equals(other.getClass()) && ((VarLongWritable) other).value == value;
+ }
+
+ @Override
+ public int hashCode() {
+ return Longs.hashCode(value);
+ }
+
+ @Override
+ public String toString() {
+ return String.valueOf(value);
+ }
+
+ @Override
+ public int compareTo(VarLongWritable other) {
+ if (value >= other.value) {
+ if (value > other.value) {
+ return 1;
+ }
+ } else {
+ return -1;
+ }
+ return 0;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ Varint.writeSignedVarLong(value, out);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ value = Varint.readSignedVarLong(in);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/main/java/org/apache/mahout/math/Varint.java
----------------------------------------------------------------------
diff --git a/hdfs/src/main/java/org/apache/mahout/math/Varint.java b/hdfs/src/main/java/org/apache/mahout/math/Varint.java
new file mode 100644
index 0000000..f380c6c
--- /dev/null
+++ b/hdfs/src/main/java/org/apache/mahout/math/Varint.java
@@ -0,0 +1,167 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>Encodes signed and unsigned values using a common variable-length
+ * scheme, found for example in
+ * <a href="http://code.google.com/apis/protocolbuffers/docs/encoding.html">
+ * Google's Protocol Buffers</a>. It uses fewer bytes to encode smaller values,
+ * but will use slightly more bytes to encode large values.</p>
+ *
+ * <p>Signed values are further encoded using so-called zig-zag encoding
+ * in order to make them "compatible" with variable-length encoding.</p>
+ */
+public final class Varint {
+
+ private Varint() {
+ }
+
+ /**
+ * Encodes a value using the variable-length encoding from
+ * <a href="http://code.google.com/apis/protocolbuffers/docs/encoding.html">
+ * Google Protocol Buffers</a>. It uses zig-zag encoding to efficiently
+ * encode signed values. If values are known to be nonnegative,
+ * {@link #writeUnsignedVarLong(long, java.io.DataOutput)} should be used.
+ *
+ * @param value value to encode
+ * @param out to write bytes to
+ * @throws java.io.IOException if {@link java.io.DataOutput} throws {@link java.io.IOException}
+ */
+ public static void writeSignedVarLong(long value, DataOutput out) throws IOException {
+ // Great trick from http://code.google.com/apis/protocolbuffers/docs/encoding.html#types
+ writeUnsignedVarLong((value << 1) ^ (value >> 63), out);
+ }
+
+ /**
+ * Encodes a value using the variable-length encoding from
+ * <a href="http://code.google.com/apis/protocolbuffers/docs/encoding.html">
+ * Google Protocol Buffers</a>. Zig-zag is not used, so input must not be negative.
+ * If values can be negative, use {@link #writeSignedVarLong(long, java.io.DataOutput)}
+ * instead. This method treats negative input as like a large unsigned value.
+ *
+ * @param value value to encode
+ * @param out to write bytes to
+ * @throws java.io.IOException if {@link java.io.DataOutput} throws {@link java.io.IOException}
+ */
+ public static void writeUnsignedVarLong(long value, DataOutput out) throws IOException {
+ while ((value & 0xFFFFFFFFFFFFFF80L) != 0L) {
+ out.writeByte(((int) value & 0x7F) | 0x80);
+ value >>>= 7;
+ }
+ out.writeByte((int) value & 0x7F);
+ }
+
+ /**
+ * @see #writeSignedVarLong(long, java.io.DataOutput)
+ */
+ public static void writeSignedVarInt(int value, DataOutput out) throws IOException {
+ // Great trick from http://code.google.com/apis/protocolbuffers/docs/encoding.html#types
+ writeUnsignedVarInt((value << 1) ^ (value >> 31), out);
+ }
+
+ /**
+ * @see #writeUnsignedVarLong(long, java.io.DataOutput)
+ */
+ public static void writeUnsignedVarInt(int value, DataOutput out) throws IOException {
+ while ((value & 0xFFFFFF80) != 0L) {
+ out.writeByte((value & 0x7F) | 0x80);
+ value >>>= 7;
+ }
+ out.writeByte(value & 0x7F);
+ }
+
+ /**
+ * @param in to read bytes from
+ * @return decode value
+ * @throws java.io.IOException if {@link java.io.DataInput} throws {@link java.io.IOException}
+ * @throws IllegalArgumentException if variable-length value does not terminate
+ * after 9 bytes have been read
+ * @see #writeSignedVarLong(long, java.io.DataOutput)
+ */
+ public static long readSignedVarLong(DataInput in) throws IOException {
+ long raw = readUnsignedVarLong(in);
+ // This undoes the trick in writeSignedVarLong()
+ long temp = (((raw << 63) >> 63) ^ raw) >> 1;
+ // This extra step lets us deal with the largest signed values by treating
+ // negative results from read unsigned methods as like unsigned values
+ // Must re-flip the top bit if the original read value had it set.
+ return temp ^ (raw & (1L << 63));
+ }
+
+ /**
+ * @param in to read bytes from
+ * @return decode value
+ * @throws java.io.IOException if {@link java.io.DataInput} throws {@link java.io.IOException}
+ * @throws IllegalArgumentException if variable-length value does not terminate
+ * after 9 bytes have been read
+ * @see #writeUnsignedVarLong(long, java.io.DataOutput)
+ */
+ public static long readUnsignedVarLong(DataInput in) throws IOException {
+ long value = 0L;
+ int i = 0;
+ long b;
+ while (((b = in.readByte()) & 0x80L) != 0) {
+ value |= (b & 0x7F) << i;
+ i += 7;
+ Preconditions.checkArgument(i <= 63, "Variable length quantity is too long (must be <= 63)");
+ }
+ return value | (b << i);
+ }
+
+ /**
+ * @throws IllegalArgumentException if variable-length value does not terminate
+ * after 5 bytes have been read
+ * @throws java.io.IOException if {@link java.io.DataInput} throws {@link java.io.IOException}
+ * @see #readSignedVarLong(java.io.DataInput)
+ */
+ public static int readSignedVarInt(DataInput in) throws IOException {
+ int raw = readUnsignedVarInt(in);
+ // This undoes the trick in writeSignedVarInt()
+ int temp = (((raw << 31) >> 31) ^ raw) >> 1;
+ // This extra step lets us deal with the largest signed values by treating
+ // negative results from read unsigned methods as like unsigned values.
+ // Must re-flip the top bit if the original read value had it set.
+ return temp ^ (raw & (1 << 31));
+ }
+
+ /**
+ * @throws IllegalArgumentException if variable-length value does not terminate
+ * after 5 bytes have been read
+ * @throws java.io.IOException if {@link java.io.DataInput} throws {@link java.io.IOException}
+ * @see #readUnsignedVarLong(java.io.DataInput)
+ */
+ public static int readUnsignedVarInt(DataInput in) throws IOException {
+ int value = 0;
+ int i = 0;
+ int b;
+ while (((b = in.readByte()) & 0x80) != 0) {
+ value |= (b & 0x7F) << i;
+ i += 7;
+ Preconditions.checkArgument(i <= 35, "Variable length quantity is too long (must be <= 35)");
+ }
+ return value | (b << i);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/main/java/org/apache/mahout/math/VectorWritable.java
----------------------------------------------------------------------
diff --git a/hdfs/src/main/java/org/apache/mahout/math/VectorWritable.java b/hdfs/src/main/java/org/apache/mahout/math/VectorWritable.java
new file mode 100644
index 0000000..491ae3b
--- /dev/null
+++ b/hdfs/src/main/java/org/apache/mahout/math/VectorWritable.java
@@ -0,0 +1,267 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for additional information regarding
+ * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configured;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Vector.Element;
+
+import com.google.common.base.Preconditions;
+
+public final class VectorWritable extends Configured implements Writable {
+
+ public static final int FLAG_DENSE = 0x01;
+ public static final int FLAG_SEQUENTIAL = 0x02;
+ public static final int FLAG_NAMED = 0x04;
+ public static final int FLAG_LAX_PRECISION = 0x08;
+ public static final int NUM_FLAGS = 4;
+
+ private Vector vector;
+ private boolean writesLaxPrecision;
+
+ public VectorWritable() {}
+
+ public VectorWritable(boolean writesLaxPrecision) {
+ setWritesLaxPrecision(writesLaxPrecision);
+ }
+
+ public VectorWritable(Vector vector) {
+ this.vector = vector;
+ }
+
+ public VectorWritable(Vector vector, boolean writesLaxPrecision) {
+ this(vector);
+ setWritesLaxPrecision(writesLaxPrecision);
+ }
+
+ /**
+ * @return {@link org.apache.mahout.math.Vector} that this is to write, or has
+ * just read
+ */
+ public Vector get() {
+ return vector;
+ }
+
+ public void set(Vector vector) {
+ this.vector = vector;
+ }
+
+ /**
+ * @return true if this is allowed to encode {@link org.apache.mahout.math.Vector}
+ * values using fewer bytes, possibly losing precision. In particular this means
+ * that floating point values will be encoded as floats, not doubles.
+ */
+ public boolean isWritesLaxPrecision() {
+ return writesLaxPrecision;
+ }
+
+ public void setWritesLaxPrecision(boolean writesLaxPrecision) {
+ this.writesLaxPrecision = writesLaxPrecision;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ writeVector(out, this.vector, this.writesLaxPrecision);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int flags = in.readByte();
+ int size = Varint.readUnsignedVarInt(in);
+ readFields(in, (byte) flags, size);
+ }
+
+ private void readFields(DataInput in, byte flags, int size) throws IOException {
+
+ Preconditions.checkArgument(flags >> NUM_FLAGS == 0, "Unknown flags set: %d", Integer.toString(flags, 2));
+ boolean dense = (flags & FLAG_DENSE) != 0;
+ boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
+ boolean named = (flags & FLAG_NAMED) != 0;
+ boolean laxPrecision = (flags & FLAG_LAX_PRECISION) != 0;
+
+ Vector v;
+ if (dense) {
+ double[] values = new double[size];
+ for (int i = 0; i < size; i++) {
+ values[i] = laxPrecision ? in.readFloat() : in.readDouble();
+ }
+ v = new DenseVector(values);
+ } else {
+ int numNonDefaultElements = Varint.readUnsignedVarInt(in);
+ v = sequential
+ ? new SequentialAccessSparseVector(size, numNonDefaultElements)
+ : new RandomAccessSparseVector(size, numNonDefaultElements);
+ if (sequential) {
+ int lastIndex = 0;
+ for (int i = 0; i < numNonDefaultElements; i++) {
+ int delta = Varint.readUnsignedVarInt(in);
+ int index = lastIndex + delta;
+ lastIndex = index;
+ double value = laxPrecision ? in.readFloat() : in.readDouble();
+ v.setQuick(index, value);
+ }
+ } else {
+ for (int i = 0; i < numNonDefaultElements; i++) {
+ int index = Varint.readUnsignedVarInt(in);
+ double value = laxPrecision ? in.readFloat() : in.readDouble();
+ v.setQuick(index, value);
+ }
+ }
+ }
+ if (named) {
+ String name = in.readUTF();
+ v = new NamedVector(v, name);
+ }
+ vector = v;
+ }
+
+ /** Write the vector to the output */
+ public static void writeVector(DataOutput out, Vector vector) throws IOException {
+ writeVector(out, vector, false);
+ }
+
+ public static byte flags(Vector vector, boolean laxPrecision) {
+ boolean dense = vector.isDense();
+ boolean sequential = vector.isSequentialAccess();
+ boolean named = vector instanceof NamedVector;
+
+ return (byte) ((dense ? FLAG_DENSE : 0)
+ | (sequential ? FLAG_SEQUENTIAL : 0)
+ | (named ? FLAG_NAMED : 0)
+ | (laxPrecision ? FLAG_LAX_PRECISION : 0));
+ }
+
+ /** Write out type information and size of the vector */
+ public static void writeVectorFlagsAndSize(DataOutput out, byte flags, int size) throws IOException {
+ out.writeByte(flags);
+ Varint.writeUnsignedVarInt(size, out);
+ }
+
+ public static void writeVector(DataOutput out, Vector vector, boolean laxPrecision) throws IOException {
+
+ byte flags = flags(vector, laxPrecision);
+
+ writeVectorFlagsAndSize(out, flags, vector.size());
+ writeVectorContents(out, vector, flags);
+ }
+
+ /** Write out contents of the vector */
+ public static void writeVectorContents(DataOutput out, Vector vector, byte flags) throws IOException {
+
+ boolean dense = (flags & FLAG_DENSE) != 0;
+ boolean sequential = (flags & FLAG_SEQUENTIAL) != 0;
+ boolean named = (flags & FLAG_NAMED) != 0;
+ boolean laxPrecision = (flags & FLAG_LAX_PRECISION) != 0;
+
+ if (dense) {
+ for (Element element : vector.all()) {
+ if (laxPrecision) {
+ out.writeFloat((float) element.get());
+ } else {
+ out.writeDouble(element.get());
+ }
+ }
+ } else {
+ Varint.writeUnsignedVarInt(vector.getNumNonZeroElements(), out);
+ Iterator<Element> iter = vector.nonZeroes().iterator();
+ if (sequential) {
+ int lastIndex = 0;
+ while (iter.hasNext()) {
+ Element element = iter.next();
+ if (element.get() == 0) {
+ continue;
+ }
+ int thisIndex = element.index();
+ // Delta-code indices:
+ Varint.writeUnsignedVarInt(thisIndex - lastIndex, out);
+ lastIndex = thisIndex;
+ if (laxPrecision) {
+ out.writeFloat((float) element.get());
+ } else {
+ out.writeDouble(element.get());
+ }
+ }
+ } else {
+ while (iter.hasNext()) {
+ Element element = iter.next();
+ if (element.get() == 0) {
+ // TODO(robinanil): Fix the damn iterator for the zero element.
+ continue;
+ }
+ Varint.writeUnsignedVarInt(element.index(), out);
+ if (laxPrecision) {
+ out.writeFloat((float) element.get());
+ } else {
+ out.writeDouble(element.get());
+ }
+ }
+ }
+ }
+ if (named) {
+ String name = ((NamedVector) vector).getName();
+ out.writeUTF(name == null ? "" : name);
+ }
+ }
+
+ public static Vector readVector(DataInput in) throws IOException {
+ VectorWritable v = new VectorWritable();
+ v.readFields(in);
+ return v.get();
+ }
+
+ public static Vector readVector(DataInput in, byte vectorFlags, int size) throws IOException {
+ VectorWritable v = new VectorWritable();
+ v.readFields(in, vectorFlags, size);
+ return v.get();
+ }
+
+ public static VectorWritable merge(Iterator<VectorWritable> vectors) {
+ return new VectorWritable(mergeToVector(vectors));
+ }
+
+ public static Vector mergeToVector(Iterator<VectorWritable> vectors) {
+ Vector accumulator = vectors.next().get();
+ while (vectors.hasNext()) {
+ VectorWritable v = vectors.next();
+ if (v != null) {
+ for (Element nonZeroElement : v.get().nonZeroes()) {
+ accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
+ }
+ }
+ }
+ return accumulator;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return o instanceof VectorWritable && vector.equals(((VectorWritable) o).get());
+ }
+
+ @Override
+ public int hashCode() {
+ return vector.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return vector.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
----------------------------------------------------------------------
diff --git a/hdfs/src/test/java/org/apache/mahout/math/MatrixWritableTest.java b/hdfs/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
new file mode 100644
index 0000000..226d4b1
--- /dev/null
+++ b/hdfs/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.junit.Test;
+
+public final class MatrixWritableTest extends MahoutTestCase {
+
+ @Test
+ public void testSparseMatrixWritable() throws Exception {
+ Matrix m = new SparseMatrix(5, 5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = Maps.newHashMap();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
+ m.setColumnLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ @Test
+ public void testSparseRowMatrixWritable() throws Exception {
+ Matrix m = new SparseRowMatrix(5, 5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = Maps.newHashMap();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
+ m.setColumnLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ @Test
+ public void testDenseMatrixWritable() throws Exception {
+ Matrix m = new DenseMatrix(5,5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = Maps.newHashMap();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
+ m.setColumnLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ private static void doTestMatrixWritableEquals(Matrix m) throws IOException {
+ Writable matrixWritable = new MatrixWritable(m);
+ MatrixWritable matrixWritable2 = new MatrixWritable();
+ writeAndRead(matrixWritable, matrixWritable2);
+ Matrix m2 = matrixWritable2.get();
+ compareMatrices(m, m2);
+ doCheckBindings(m2.getRowLabelBindings());
+ doCheckBindings(m2.getColumnLabelBindings());
+ }
+
+ private static void compareMatrices(Matrix m, Matrix m2) {
+ assertEquals(m.numRows(), m2.numRows());
+ assertEquals(m.numCols(), m2.numCols());
+ for (int r = 0; r < m.numRows(); r++) {
+ for (int c = 0; c < m.numCols(); c++) {
+ assertEquals(m.get(r, c), m2.get(r, c), EPSILON);
+ }
+ }
+ Map<String,Integer> bindings = m.getRowLabelBindings();
+ Map<String, Integer> bindings2 = m2.getRowLabelBindings();
+ assertEquals(bindings == null, bindings2 == null);
+ if (bindings != null) {
+ assertEquals(bindings.size(), m.numRows());
+ assertEquals(bindings.size(), bindings2.size());
+ for (Map.Entry<String,Integer> entry : bindings.entrySet()) {
+ assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+ }
+ }
+ bindings = m.getColumnLabelBindings();
+ bindings2 = m2.getColumnLabelBindings();
+ assertEquals(bindings == null, bindings2 == null);
+ if (bindings != null) {
+ assertEquals(bindings.size(), bindings2.size());
+ for (Map.Entry<String,Integer> entry : bindings.entrySet()) {
+ assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+ }
+ }
+ }
+
+ private static void doCheckBindings(Map<String,Integer> labels) {
+ assertTrue("Missing label", labels.keySet().contains("A"));
+ assertTrue("Missing label", labels.keySet().contains("B"));
+ assertTrue("Missing label", labels.keySet().contains("C"));
+ assertTrue("Missing label", labels.keySet().contains("D"));
+ assertTrue("Missing label", labels.keySet().contains("default"));
+ }
+
+ private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(baos);
+ try {
+ toWrite.write(dos);
+ } finally {
+ Closeables.close(dos, false);
+ }
+
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ DataInputStream dis = new DataInputStream(bais);
+ try {
+ toRead.readFields(dis);
+ } finally {
+ Closeables.close(dis, true);
+ }
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/test/java/org/apache/mahout/math/VarintTest.java
----------------------------------------------------------------------
diff --git a/hdfs/src/test/java/org/apache/mahout/math/VarintTest.java b/hdfs/src/test/java/org/apache/mahout/math/VarintTest.java
new file mode 100644
index 0000000..0b1a664
--- /dev/null
+++ b/hdfs/src/test/java/org/apache/mahout/math/VarintTest.java
@@ -0,0 +1,189 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+
+/**
+ * Tests {@link Varint}.
+ */
+public final class VarintTest extends MahoutTestCase {
+
+ @Test
+ public void testUnsignedLong() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeUnsignedVarLong(0L, out);
+ for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) {
+ Varint.writeUnsignedVarLong(i-1, out);
+ Varint.writeUnsignedVarLong(i, out);
+ }
+ Varint.writeUnsignedVarLong(Long.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0L, Varint.readUnsignedVarLong(in));
+ for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) {
+ assertEquals(i-1, Varint.readUnsignedVarLong(in));
+ assertEquals(i, Varint.readUnsignedVarLong(in));
+ }
+ assertEquals(Long.MAX_VALUE, Varint.readUnsignedVarLong(in));
+ }
+
+ @Test
+ public void testSignedPositiveLong() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeSignedVarLong(0L, out);
+ for (long i = 1L; i <= (1L << 61); i <<= 1) {
+ Varint.writeSignedVarLong(i-1, out);
+ Varint.writeSignedVarLong(i, out);
+ }
+ Varint.writeSignedVarLong((1L << 62) - 1, out);
+ Varint.writeSignedVarLong((1L << 62), out);
+ Varint.writeSignedVarLong(Long.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0L, Varint.readSignedVarLong(in));
+ for (long i = 1L; i <= (1L << 61); i <<= 1) {
+ assertEquals(i-1, Varint.readSignedVarLong(in));
+ assertEquals(i, Varint.readSignedVarLong(in));
+ }
+ assertEquals((1L << 62) - 1, Varint.readSignedVarLong(in));
+ assertEquals((1L << 62), Varint.readSignedVarLong(in));
+ assertEquals(Long.MAX_VALUE, Varint.readSignedVarLong(in));
+ }
+
+ @Test
+ public void testSignedNegativeLong() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ for (long i = -1L; i >= -(1L << 62); i <<= 1) {
+ Varint.writeSignedVarLong(i, out);
+ Varint.writeSignedVarLong(i+1, out);
+ }
+ Varint.writeSignedVarLong(Long.MIN_VALUE, out);
+ Varint.writeSignedVarLong(Long.MIN_VALUE+1, out);
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ for (long i = -1L; i >= -(1L << 62); i <<= 1) {
+ assertEquals(i, Varint.readSignedVarLong(in));
+ assertEquals(i+1, Varint.readSignedVarLong(in));
+ }
+ assertEquals(Long.MIN_VALUE, Varint.readSignedVarLong(in));
+ assertEquals(Long.MIN_VALUE+1, Varint.readSignedVarLong(in));
+ }
+
+ @Test
+ public void testUnsignedInt() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeUnsignedVarInt(0, out);
+ for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) {
+ Varint.writeUnsignedVarLong(i-1, out);
+ Varint.writeUnsignedVarLong(i, out);
+ }
+ Varint.writeUnsignedVarLong(Integer.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0, Varint.readUnsignedVarInt(in));
+ for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) {
+ assertEquals(i-1, Varint.readUnsignedVarInt(in));
+ assertEquals(i, Varint.readUnsignedVarInt(in));
+ }
+ assertEquals(Integer.MAX_VALUE, Varint.readUnsignedVarInt(in));
+ }
+
+ @Test
+ public void testSignedPositiveInt() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeSignedVarInt(0, out);
+ for (int i = 1; i <= (1 << 29); i <<= 1) {
+ Varint.writeSignedVarLong(i-1, out);
+ Varint.writeSignedVarLong(i, out);
+ }
+ Varint.writeSignedVarInt((1 << 30) - 1, out);
+ Varint.writeSignedVarInt((1 << 30), out);
+ Varint.writeSignedVarInt(Integer.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0, Varint.readSignedVarInt(in));
+ for (int i = 1; i <= (1 << 29); i <<= 1) {
+ assertEquals(i-1, Varint.readSignedVarInt(in));
+ assertEquals(i, Varint.readSignedVarInt(in));
+ }
+ assertEquals((1L << 30) - 1, Varint.readSignedVarInt(in));
+ assertEquals((1L << 30), Varint.readSignedVarInt(in));
+ assertEquals(Integer.MAX_VALUE, Varint.readSignedVarInt(in));
+ }
+
+ @Test
+ public void testSignedNegativeInt() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ for (int i = -1; i >= -(1 << 30); i <<= 1) {
+ Varint.writeSignedVarInt(i, out);
+ Varint.writeSignedVarInt(i+1, out);
+ }
+ Varint.writeSignedVarInt(Integer.MIN_VALUE, out);
+ Varint.writeSignedVarInt(Integer.MIN_VALUE+1, out);
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ for (int i = -1; i >= -(1 << 30); i <<= 1) {
+ assertEquals(i, Varint.readSignedVarInt(in));
+ assertEquals(i+1, Varint.readSignedVarInt(in));
+ }
+ assertEquals(Integer.MIN_VALUE, Varint.readSignedVarInt(in));
+ assertEquals(Integer.MIN_VALUE+1, Varint.readSignedVarInt(in));
+ }
+
+ @Test
+ public void testUnsignedSize() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ int expectedSize = 0;
+ for (int exponent = 0; exponent <= 62; exponent++) {
+ Varint.writeUnsignedVarLong(1L << exponent, out);
+ expectedSize += 1 + exponent / 7;
+ assertEquals(expectedSize, baos.size());
+ }
+ }
+
+ @Test
+ public void testSignedSize() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ int expectedSize = 0;
+ for (int exponent = 0; exponent <= 61; exponent++) {
+ Varint.writeSignedVarLong(1L << exponent, out);
+ expectedSize += 1 + ((exponent + 1) / 7);
+ assertEquals(expectedSize, baos.size());
+ }
+ for (int exponent = 0; exponent <= 61; exponent++) {
+ Varint.writeSignedVarLong(-(1L << exponent)-1, out);
+ expectedSize += 1 + ((exponent + 1) / 7);
+ assertEquals(expectedSize, baos.size());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/hdfs/src/test/java/org/apache/mahout/math/VectorWritableTest.java
----------------------------------------------------------------------
diff --git a/hdfs/src/test/java/org/apache/mahout/math/VectorWritableTest.java b/hdfs/src/test/java/org/apache/mahout/math/VectorWritableTest.java
new file mode 100644
index 0000000..60fb8b4
--- /dev/null
+++ b/hdfs/src/test/java/org/apache/mahout/math/VectorWritableTest.java
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for additional information regarding
+ * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Vector.Element;
+import org.junit.Test;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import com.carrotsearch.randomizedtesting.annotations.Repeat;
+import com.google.common.io.Closeables;
+
+public final class VectorWritableTest extends RandomizedTest {
+ private static final int MAX_VECTOR_SIZE = 100;
+
+ public void createRandom(Vector v) {
+ int size = randomInt(v.size() - 1);
+ for (int i = 0; i < size; ++i) {
+ v.set(randomInt(v.size() - 1), randomDouble());
+ }
+
+ int zeros = Math.max(2, size / 4);
+ for (Element e : v.nonZeroes()) {
+ if (e.index() % zeros == 0) {
+ e.set(0.0);
+ }
+ }
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testViewSequentialAccessSparseVectorWritable() throws Exception {
+ Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ Vector view = new VectorView(v, 0, v.size());
+ doTestVectorWritableEquals(view);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testSequentialAccessSparseVectorWritable() throws Exception {
+ Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testRandomAccessSparseVectorWritable() throws Exception {
+ Vector v = new RandomAccessSparseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testDenseVectorWritable() throws Exception {
+ Vector v = new DenseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testNamedVectorWritable() throws Exception {
+ Vector v = new DenseVector(MAX_VECTOR_SIZE);
+ v = new NamedVector(v, "Victor");
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ private static void doTestVectorWritableEquals(Vector v) throws IOException {
+ Writable vectorWritable = new VectorWritable(v);
+ VectorWritable vectorWritable2 = new VectorWritable();
+ writeAndRead(vectorWritable, vectorWritable2);
+ Vector v2 = vectorWritable2.get();
+ if (v instanceof NamedVector) {
+ assertTrue(v2 instanceof NamedVector);
+ NamedVector nv = (NamedVector) v;
+ NamedVector nv2 = (NamedVector) v2;
+ assertEquals(nv.getName(), nv2.getName());
+ assertEquals("Victor", nv.getName());
+ }
+ assertEquals(v, v2);
+ }
+
+ private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(baos);
+ try {
+ toWrite.write(dos);
+ } finally {
+ Closeables.close(dos, false);
+ }
+
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ DataInputStream dis = new DataInputStream(bais);
+ try {
+ toRead.readFields(dis);
+ } finally {
+ Closeables.close(dos, true);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/pom.xml
----------------------------------------------------------------------
diff --git a/integration/pom.xml b/integration/pom.xml
index bb7077c..f9c1690 100644
--- a/integration/pom.xml
+++ b/integration/pom.xml
@@ -68,11 +68,21 @@
<!-- own modules -->
<dependency>
<groupId>${project.groupId}</groupId>
- <artifactId>mahout-mrlegacy</artifactId>
+ <artifactId>mahout-hdfs</artifactId>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
- <artifactId>mahout-mrlegacy</artifactId>
+ <artifactId>mahout-mr</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-hdfs</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-mr</artifactId>
<type>test-jar</type>
<scope>test</scope>
</dependency>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/src/main/java/org/apache/mahout/text/PrefixAdditionFilter.java
----------------------------------------------------------------------
diff --git a/integration/src/main/java/org/apache/mahout/text/PrefixAdditionFilter.java b/integration/src/main/java/org/apache/mahout/text/PrefixAdditionFilter.java
index 9c0bc11..a13341b 100644
--- a/integration/src/main/java/org/apache/mahout/text/PrefixAdditionFilter.java
+++ b/integration/src/main/java/org/apache/mahout/text/PrefixAdditionFilter.java
@@ -48,7 +48,7 @@ public final class PrefixAdditionFilter extends SequenceFilesFromDirectoryFilter
protected void process(FileStatus fst, Path current) throws IOException {
FileSystem fs = getFs();
ChunkedWriter writer = getWriter();
- if (fst.isDirectory()) {
+ if (fst.isDir()) {
String dirPath = getPrefix() + Path.SEPARATOR + current.getName() + Path.SEPARATOR + fst.getPath().getName();
fs.listStatus(fst.getPath(),
new PrefixAdditionFilter(getConf(), dirPath, getOptions(), writer, getCharset(), fs));
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/src/main/java/org/apache/mahout/text/ReadOnlyFileSystemDirectory.java
----------------------------------------------------------------------
diff --git a/integration/src/main/java/org/apache/mahout/text/ReadOnlyFileSystemDirectory.java b/integration/src/main/java/org/apache/mahout/text/ReadOnlyFileSystemDirectory.java
index 18c1252..e97e35b 100644
--- a/integration/src/main/java/org/apache/mahout/text/ReadOnlyFileSystemDirectory.java
+++ b/integration/src/main/java/org/apache/mahout/text/ReadOnlyFileSystemDirectory.java
@@ -79,7 +79,7 @@ public class ReadOnlyFileSystemDirectory extends BaseDirectory {
try {
FileStatus status = fs.getFileStatus(directory);
if (status != null) {
- isDir = status.isDirectory();
+ isDir = status.isDir();
}
} catch (IOException e) {
log.error(e.getMessage(), e);
@@ -99,7 +99,7 @@ public class ReadOnlyFileSystemDirectory extends BaseDirectory {
try {
FileStatus status = fs.getFileStatus(directory);
if (status != null) {
- isDir = status.isDirectory();
+ isDir = status.isDir();
}
} catch (IOException e) {
log.error(e.getMessage(), e);
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java
----------------------------------------------------------------------
diff --git a/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java b/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java
index bf6691f..2dcc8b0 100644
--- a/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java
+++ b/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java
@@ -62,7 +62,7 @@ public final class SequenceFileDumper extends AbstractJob {
Configuration conf = new Configuration();
Path input = getInputPath();
FileSystem fs = input.getFileSystem(conf);
- if (fs.getFileStatus(input).isDirectory()) {
+ if (fs.getFileStatus(input).isDir()) {
pathArr = FileUtil.stat2Paths(fs.listStatus(input, PathFilters.logsCRCFilter()));
} else {
pathArr = new Path[1];
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/src/main/java/org/apache/mahout/utils/SplitInput.java
----------------------------------------------------------------------
diff --git a/integration/src/main/java/org/apache/mahout/utils/SplitInput.java b/integration/src/main/java/org/apache/mahout/utils/SplitInput.java
index 834d5cd..af22422 100644
--- a/integration/src/main/java/org/apache/mahout/utils/SplitInput.java
+++ b/integration/src/main/java/org/apache/mahout/utils/SplitInput.java
@@ -289,7 +289,7 @@ public class SplitInput extends AbstractJob {
if (fs.getFileStatus(inputDir) == null) {
throw new IOException(inputDir + " does not exist");
}
- if (!fs.getFileStatus(inputDir).isDirectory()) {
+ if (!fs.getFileStatus(inputDir).isDir()) {
throw new IOException(inputDir + " is not a directory");
}
@@ -317,7 +317,7 @@ public class SplitInput extends AbstractJob {
if (fs.getFileStatus(inputFile) == null) {
throw new IOException(inputFile + " does not exist");
}
- if (fs.getFileStatus(inputFile).isDirectory()) {
+ if (fs.getFileStatus(inputFile).isDir()) {
throw new IOException(inputFile + " is a directory");
}
@@ -650,10 +650,10 @@ public class SplitInput extends AbstractJob {
Configuration conf = getConf();
FileSystem fs = trainingOutputDirectory.getFileSystem(conf);
FileStatus trainingOutputDirStatus = fs.getFileStatus(trainingOutputDirectory);
- Preconditions.checkArgument(trainingOutputDirStatus != null && trainingOutputDirStatus.isDirectory(),
+ Preconditions.checkArgument(trainingOutputDirStatus != null && trainingOutputDirStatus.isDir(),
"%s is not a directory", trainingOutputDirectory);
FileStatus testOutputDirStatus = fs.getFileStatus(testOutputDirectory);
- Preconditions.checkArgument(testOutputDirStatus != null && testOutputDirStatus.isDirectory(),
+ Preconditions.checkArgument(testOutputDirStatus != null && testOutputDirStatus.isDir(),
"%s is not a directory", testOutputDirectory);
}
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java
----------------------------------------------------------------------
diff --git a/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java b/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java
index 63399b5..d564a73 100644
--- a/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java
+++ b/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java
@@ -83,6 +83,7 @@ public class JsonClusterWriter extends AbstractClusterWriter {
if (dictionary != null) {
Map<String,Object> fmtStr = cluster.asJson(dictionary);
res.put("cluster", fmtStr);
+
// get points
List<Object> points = getPoints(cluster, dictionary);
res.put("points", points);
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java
----------------------------------------------------------------------
diff --git a/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java b/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java
index 2a8a42b..9214434 100644
--- a/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java
+++ b/integration/src/main/java/org/apache/mahout/utils/vectors/VectorDumper.java
@@ -97,7 +97,7 @@ public final class VectorDumper extends AbstractJob {
FileSystem fs = FileSystem.get(conf);
Path input = getInputPath();
FileStatus fileStatus = fs.getFileStatus(input);
- if (fileStatus.isDirectory()) {
+ if (fileStatus.isDir()) {
pathArr = FileUtil.stat2Paths(fs.listStatus(input, PathFilters.logsCRCFilter()));
} else {
FileStatus[] inputPaths = fs.globStatus(input);
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/integration/src/test/java/org/apache/mahout/utils/vectors/lucene/LuceneIterableTest.java
----------------------------------------------------------------------
diff --git a/integration/src/test/java/org/apache/mahout/utils/vectors/lucene/LuceneIterableTest.java b/integration/src/test/java/org/apache/mahout/utils/vectors/lucene/LuceneIterableTest.java
index 1152936..ba49a2d 100644
--- a/integration/src/test/java/org/apache/mahout/utils/vectors/lucene/LuceneIterableTest.java
+++ b/integration/src/test/java/org/apache/mahout/utils/vectors/lucene/LuceneIterableTest.java
@@ -117,7 +117,7 @@ public final class LuceneIterableTest extends MahoutTestCase {
LuceneIterable iterable = new LuceneIterable(reader, "id", "content", termInfo,weight);
Iterator<Vector> iterator = iterable.iterator();
- Iterators.skip(iterator, 1);
+ Iterators.advance(iterator, 1);
}
@Test
@@ -157,10 +157,10 @@ public final class LuceneIterableTest extends MahoutTestCase {
//50 percent tolerance
iterable = new LuceneIterable(reader, "id", "content", termInfo,weight, -1, 0.5);
Iterator<Vector> iterator = iterable.iterator();
- Iterators.skip(iterator, 5);
+ Iterators.advance(iterator, 5);
try {
- Iterators.skip(iterator, Iterators.size(iterator));
+ Iterators.advance(iterator, Iterators.size(iterator));
exceptionThrown = false;
}
catch(IllegalStateException ise) {
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/pom.xml
----------------------------------------------------------------------
diff --git a/mr/pom.xml b/mr/pom.xml
new file mode 100644
index 0000000..0a48150
--- /dev/null
+++ b/mr/pom.xml
@@ -0,0 +1,249 @@
+<?xml version="1.0" encoding="UTF-8"?>
+
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+
+ <parent>
+ <groupId>org.apache.mahout</groupId>
+ <artifactId>mahout</artifactId>
+ <version>1.0-SNAPSHOT</version>
+ <relativePath>../pom.xml</relativePath>
+ </parent>
+
+ <!-- modules inherit parent's group id and version. -->
+ <artifactId>mahout-mr</artifactId>
+ <name>Mahout Map-Reduce</name>
+ <description>Scalable machine learning libraries</description>
+
+ <packaging>jar</packaging>
+
+ <build>
+ <resources>
+ <resource>
+ <directory>src/main/resources</directory>
+ </resource>
+ <resource>
+ <directory>../src/conf</directory>
+ <includes>
+ <include>driver.classes.default.props</include>
+ </includes>
+ </resource>
+ </resources>
+ <plugins>
+ <!-- create test jar so other modules can reuse the core test utility classes. -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-jar-plugin</artifactId>
+ <executions>
+ <execution>
+ <goals>
+ <goal>test-jar</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+
+ <!-- create core hadoop job jar -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-assembly-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>job</id>
+ <phase>package</phase>
+ <goals>
+ <goal>single</goal>
+ </goals>
+ <configuration>
+ <descriptors>
+ <descriptor>src/main/assembly/job.xml</descriptor>
+ </descriptors>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+
+ <plugin>
+ <artifactId>maven-javadoc-plugin</artifactId>
+ </plugin>
+
+ <plugin>
+ <artifactId>maven-source-plugin</artifactId>
+ </plugin>
+
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-remote-resources-plugin</artifactId>
+ <configuration>
+ <appendedResourcesDirectory>../src/main/appended-resources</appendedResourcesDirectory>
+ <resourceBundles>
+ <resourceBundle>org.apache:apache-jar-resource-bundle:1.4</resourceBundle>
+ </resourceBundles>
+ <supplementalModels>
+ <supplementalModel>supplemental-models.xml</supplementalModel>
+ </supplementalModels>
+ </configuration>
+ </plugin>
+
+ </plugins>
+ </build>
+
+ <dependencies>
+
+ <!-- our modules -->
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-math</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-math</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-hdfs</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>${project.groupId}</groupId>
+ <artifactId>mahout-hdfs</artifactId>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+
+ <!-- Third Party -->
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <version>11.0.2</version>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-client</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.codehaus.jackson</groupId>
+ <artifactId>jackson-core-asl</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.codehaus.jackson</groupId>
+ <artifactId>jackson-mapper-asl</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-jcl</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-lang3</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>commons-cli</groupId>
+ <artifactId>commons-cli</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.thoughtworks.xstream</groupId>
+ <artifactId>xstream</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-core</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.lucene</groupId>
+ <artifactId>lucene-analyzers-common</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.mahout.commons</groupId>
+ <artifactId>commons-cli</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math3</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-all</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>com.carrotsearch.randomizedtesting</groupId>
+ <artifactId>randomizedtesting-runner</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.easymock</groupId>
+ <artifactId>easymock</artifactId>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.mrunit</groupId>
+ <artifactId>mrunit</artifactId>
+ <version>1.0.0</version>
+ <classifier>${hadoop.classifier}</classifier>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>commons-httpclient</groupId>
+ <artifactId>commons-httpclient</artifactId>
+ <version>3.0.1</version>
+ <scope>test</scope>
+ </dependency>
+
+ <dependency>
+ <groupId>org.apache.solr</groupId>
+ <artifactId>solr-commons-csv</artifactId>
+ <version>3.5.0</version>
+ </dependency>
+
+ </dependencies>
+
+</project>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/assembly/job.xml
----------------------------------------------------------------------
diff --git a/mr/src/main/assembly/job.xml b/mr/src/main/assembly/job.xml
new file mode 100644
index 0000000..2bdb3ce
--- /dev/null
+++ b/mr/src/main/assembly/job.xml
@@ -0,0 +1,61 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<assembly
+ xmlns="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/plugins/maven-assembly-plugin/assembly/1.1.0
+ http://maven.apache.org/xsd/assembly-1.1.0.xsd">
+ <id>job</id>
+ <formats>
+ <format>jar</format>
+ </formats>
+ <includeBaseDirectory>false</includeBaseDirectory>
+ <dependencySets>
+ <dependencySet>
+ <unpack>true</unpack>
+ <unpackOptions>
+ <!-- MAHOUT-1126 -->
+ <excludes>
+ <exclude>META-INF/LICENSE</exclude>
+ </excludes>
+ </unpackOptions>
+ <scope>runtime</scope>
+ <outputDirectory>/</outputDirectory>
+ <useTransitiveFiltering>true</useTransitiveFiltering>
+ <excludes>
+ <exclude>org.apache.hadoop:hadoop-core</exclude>
+ </excludes>
+ </dependencySet>
+ </dependencySets>
+ <fileSets>
+ <fileSet>
+ <directory>${basedir}/target/classes</directory>
+ <outputDirectory>/</outputDirectory>
+ <excludes>
+ <exclude>*.jar</exclude>
+ </excludes>
+ </fileSet>
+ <fileSet>
+ <directory>${basedir}/target/classes</directory>
+ <outputDirectory>/</outputDirectory>
+ <includes>
+ <include>driver.classes.default.props</include>
+ </includes>
+ </fileSet>
+ </fileSets>
+</assembly>
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/Version.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/Version.java b/mr/src/main/java/org/apache/mahout/Version.java
new file mode 100644
index 0000000..5f3c879
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/Version.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Resources;
+
+import java.io.IOException;
+
+public final class Version {
+
+ private Version() {
+ }
+
+ public static String version() {
+ return Version.class.getPackage().getImplementationVersion();
+ }
+
+ public static String versionFromResource() throws IOException {
+ return Resources.toString(Resources.getResource("version"), Charsets.UTF_8);
+ }
+
+ public static void main(String[] args) throws IOException {
+ System.out.println(version() + ' ' + versionFromResource());
+ }
+}
[18/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java
new file mode 100644
index 0000000..a5f32ad
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/Omega.java
@@ -0,0 +1,257 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * simplistic implementation for Omega matrix in Stochastic SVD method
+ */
+public class Omega {
+
+ private static final double UNIFORM_DIVISOR = Math.pow(2.0, 64);
+
+ private final long seed;
+ private final int kp;
+
+ public Omega(long seed, int kp) {
+ this.seed = seed;
+ this.kp = kp;
+ }
+
+ /**
+ * Get omega element at (x,y) uniformly distributed within [-1...1)
+ *
+ * @param row
+ * omega row
+ * @param column
+ * omega column
+ */
+ public double getQuick(int row, int column) {
+ long hash = murmur64((long) row << Integer.SIZE | column, 8, seed);
+ return hash / UNIFORM_DIVISOR;
+ }
+
+ /**
+ * compute YRow=ARow*Omega.
+ *
+ * @param aRow
+ * row of matrix A (size n)
+ * @param yRow
+ * row of matrix Y (result) must be pre-allocated to size of (k+p)
+ */
+ @Deprecated
+ public void computeYRow(Vector aRow, double[] yRow) {
+ // assert yRow.length == kp;
+ Arrays.fill(yRow, 0.0);
+ if (aRow.isDense()) {
+ int n = aRow.size();
+ for (int j = 0; j < n; j++) {
+ accumDots(j, aRow.getQuick(j), yRow);
+ }
+ } else {
+ for (Element el : aRow.nonZeroes()) {
+ accumDots(el.index(), el.get(), yRow);
+ }
+ }
+ }
+
+ /**
+ * A version to compute yRow as a sparse vector in case of extremely sparse
+ * matrices
+ *
+ * @param aRow
+ * @param yRowOut
+ */
+ public void computeYRow(Vector aRow, Vector yRowOut) {
+ yRowOut.assign(0.0);
+ if (aRow.isDense()) {
+ int n = aRow.size();
+ for (int j = 0; j < n; j++) {
+ accumDots(j, aRow.getQuick(j), yRowOut);
+ }
+ } else {
+ for (Element el : aRow.nonZeroes()) {
+ accumDots(el.index(), el.get(), yRowOut);
+ }
+ }
+ }
+
+ /*
+ * computes t(Omega) %*% v in multithreaded fashion
+ */
+ public Vector mutlithreadedTRightMultiply(final Vector v) {
+
+ int nThreads = Runtime.getRuntime().availableProcessors();
+ ExecutorService es =
+ new ThreadPoolExecutor(nThreads,
+ nThreads,
+ 1,
+ TimeUnit.SECONDS,
+ new ArrayBlockingQueue<Runnable>(kp));
+
+ try {
+
+ List<Future<Double>> dotFutures = Lists.newArrayListWithCapacity(kp);
+
+ for (int i = 0; i < kp; i++) {
+ final int index = i;
+
+ Future<Double> dotFuture = es.submit(new Callable<Double>() {
+ @Override
+ public Double call() throws Exception {
+ double result = 0.0;
+ if (v.isDense()) {
+ for (int k = 0; k < v.size(); k++) {
+ // it's ok, this is reentrant
+ result += getQuick(k, index) * v.getQuick(k);
+ }
+
+ } else {
+ for (Element el : v.nonZeroes()) {
+ int k = el.index();
+ result += getQuick(k, index) * el.get();
+ }
+ }
+ return result;
+ }
+ });
+ dotFutures.add(dotFuture);
+ }
+
+ try {
+ Vector res = new DenseVector(kp);
+ for (int i = 0; i < kp; i++) {
+ res.setQuick(i, dotFutures.get(i).get());
+ }
+ return res;
+ } catch (InterruptedException exc) {
+ throw new IllegalStateException("Interrupted", exc);
+ } catch (ExecutionException exc) {
+ if (exc.getCause() instanceof RuntimeException) {
+ throw (RuntimeException) exc.getCause();
+ } else {
+ throw new IllegalStateException(exc.getCause());
+ }
+ }
+
+ } finally {
+ es.shutdown();
+ }
+ }
+
+ protected void accumDots(int aIndex, double aElement, double[] yRow) {
+ for (int i = 0; i < kp; i++) {
+ yRow[i] += getQuick(aIndex, i) * aElement;
+ }
+ }
+
+ protected void accumDots(int aIndex, double aElement, Vector yRow) {
+ for (int i = 0; i < kp; i++) {
+ yRow.setQuick(i, yRow.getQuick(i) + getQuick(aIndex, i) * aElement);
+ }
+ }
+
+ /**
+ * Shortened version for data < 8 bytes packed into {@code len} lowest bytes
+ * of {@code val}.
+ *
+ * @param val
+ * the value
+ * @param len
+ * the length of data packed into this many low bytes of {@code val}
+ * @param seed
+ * the seed to use
+ * @return murmur hash
+ */
+ public static long murmur64(long val, int len, long seed) {
+
+ // assert len > 0 && len <= 8;
+ long m = 0xc6a4a7935bd1e995L;
+ long h = seed ^ len * m;
+
+ long k = val;
+
+ k *= m;
+ int r = 47;
+ k ^= k >>> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+
+ h ^= h >>> r;
+ h *= m;
+ h ^= h >>> r;
+ return h;
+ }
+
+ public static long murmur64(byte[] val, int offset, int len, long seed) {
+
+ long m = 0xc6a4a7935bd1e995L;
+ int r = 47;
+ long h = seed ^ (len * m);
+
+ int lt = len >>> 3;
+ for (int i = 0; i < lt; i++, offset += 8) {
+ long k = 0;
+ for (int j = 0; j < 8; j++) {
+ k <<= 8;
+ k |= val[offset + j] & 0xff;
+ }
+
+ k *= m;
+ k ^= k >>> r;
+ k *= m;
+
+ h ^= k;
+ h *= m;
+ }
+
+ if (offset < len) {
+ long k = 0;
+ while (offset < len) {
+ k <<= 8;
+ k |= val[offset] & 0xff;
+ offset++;
+ }
+ h ^= k;
+ h *= m;
+ }
+
+ h ^= h >>> r;
+ h *= m;
+ h ^= h >>> r;
+ return h;
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/QJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/QJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/QJob.java
new file mode 100644
index 0000000..76dc299
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/QJob.java
@@ -0,0 +1,237 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Deque;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.lib.MultipleOutputs;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep;
+
+/**
+ * Compute first level of QHat-transpose blocks.
+ * <P>
+ *
+ * See Mahout-376 working notes for details.
+ * <P>
+ *
+ * Uses some of Hadoop deprecated api wherever newer api is not available.
+ * Hence, @SuppressWarnings("deprecation") for imports (MAHOUT-593).
+ * <P>
+ *
+ */
+@SuppressWarnings("deprecation")
+public final class QJob {
+
+ public static final String PROP_OMEGA_SEED = "ssvd.omegaseed";
+ public static final String PROP_K = QRFirstStep.PROP_K;
+ public static final String PROP_P = QRFirstStep.PROP_P;
+ public static final String PROP_SB_PATH = "ssvdpca.sb.path";
+ public static final String PROP_AROWBLOCK_SIZE =
+ QRFirstStep.PROP_AROWBLOCK_SIZE;
+
+ public static final String OUTPUT_RHAT = "R";
+ public static final String OUTPUT_QHAT = "QHat";
+
+ private QJob() {
+ }
+
+ public static class QMapper
+ extends
+ Mapper<Writable, VectorWritable, SplitPartitionedWritable, VectorWritable> {
+
+ private MultipleOutputs outputs;
+ private final Deque<Closeable> closeables = Lists.newLinkedList();
+ private SplitPartitionedWritable qHatKey;
+ private SplitPartitionedWritable rHatKey;
+ private Vector yRow;
+ private Vector sb;
+ private Omega omega;
+ private int kp;
+
+ private QRFirstStep qr;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+
+ Configuration conf = context.getConfiguration();
+ int k = Integer.parseInt(conf.get(PROP_K));
+ int p = Integer.parseInt(conf.get(PROP_P));
+ kp = k + p;
+ long omegaSeed = Long.parseLong(conf.get(PROP_OMEGA_SEED));
+ omega = new Omega(omegaSeed, k + p);
+
+ String sbPathStr = conf.get(PROP_SB_PATH);
+ if (sbPathStr != null) {
+ sb = SSVDHelper.loadAndSumUpVectors(new Path(sbPathStr), conf);
+ if (sb == null)
+ throw new IOException(String.format("Unable to load s_omega from path %s.", sbPathStr));
+ }
+
+ outputs = new MultipleOutputs(new JobConf(conf));
+ closeables.addFirst(new Closeable() {
+ @Override
+ public void close() throws IOException {
+ outputs.close();
+ }
+ });
+
+ qHatKey = new SplitPartitionedWritable(context);
+ rHatKey = new SplitPartitionedWritable(context);
+
+ OutputCollector<Writable, DenseBlockWritable> qhatCollector =
+ new OutputCollector<Writable, DenseBlockWritable>() {
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void collect(Writable nil, DenseBlockWritable dbw)
+ throws IOException {
+ outputs.getCollector(OUTPUT_QHAT, null).collect(qHatKey, dbw);
+ qHatKey.incrementItemOrdinal();
+ }
+ };
+
+ OutputCollector<Writable, VectorWritable> rhatCollector =
+ new OutputCollector<Writable, VectorWritable>() {
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void collect(Writable nil, VectorWritable rhat)
+ throws IOException {
+ outputs.getCollector(OUTPUT_RHAT, null).collect(rHatKey, rhat);
+ rHatKey.incrementItemOrdinal();
+ }
+ };
+
+ qr = new QRFirstStep(conf, qhatCollector, rhatCollector);
+ closeables.addFirst(qr); // important: qr closes first!!
+ yRow = new DenseVector(kp);
+ }
+
+ @Override
+ protected void map(Writable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ omega.computeYRow(value.get(), yRow);
+ if (sb != null) {
+ yRow.assign(sb, Functions.MINUS);
+ }
+ qr.collect(key, yRow);
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ IOUtils.close(closeables);
+ }
+ }
+
+ public static void run(Configuration conf,
+ Path[] inputPaths,
+ Path sbPath,
+ Path outputPath,
+ int aBlockRows,
+ int minSplitSize,
+ int k,
+ int p,
+ long seed,
+ int numReduceTasks) throws ClassNotFoundException,
+ InterruptedException, IOException {
+
+ JobConf oldApiJob = new JobConf(conf);
+ MultipleOutputs.addNamedOutput(oldApiJob,
+ OUTPUT_QHAT,
+ org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ SplitPartitionedWritable.class,
+ DenseBlockWritable.class);
+ MultipleOutputs.addNamedOutput(oldApiJob,
+ OUTPUT_RHAT,
+ org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ SplitPartitionedWritable.class,
+ VectorWritable.class);
+
+ Job job = new Job(oldApiJob);
+ job.setJobName("Q-job");
+ job.setJarByClass(QJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ FileInputFormat.setInputPaths(job, inputPaths);
+ if (minSplitSize > 0) {
+ FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+ }
+
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ FileOutputFormat.setCompressOutput(job, true);
+ FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
+ SequenceFileOutputFormat.setOutputCompressionType(job,
+ CompressionType.BLOCK);
+
+ job.setMapOutputKeyClass(SplitPartitionedWritable.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+
+ job.setOutputKeyClass(SplitPartitionedWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.setMapperClass(QMapper.class);
+
+ job.getConfiguration().setInt(PROP_AROWBLOCK_SIZE, aBlockRows);
+ job.getConfiguration().setLong(PROP_OMEGA_SEED, seed);
+ job.getConfiguration().setInt(PROP_K, k);
+ job.getConfiguration().setInt(PROP_P, p);
+ if (sbPath != null) {
+ job.getConfiguration().set(PROP_SB_PATH, sbPath.toString());
+ }
+
+ /*
+ * number of reduce tasks doesn't matter. we don't actually send anything to
+ * reducers.
+ */
+
+ job.setNumReduceTasks(0 /* numReduceTasks */);
+
+ job.submit();
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("Q job unsuccessful.");
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDCli.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDCli.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDCli.java
new file mode 100644
index 0000000..7b4fefb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDCli.java
@@ -0,0 +1,201 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.hadoop.MatrixColumnMeansJob;
+
+/**
+ * Mahout CLI adapter for SSVDSolver
+ */
+public class SSVDCli extends AbstractJob {
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption("rank", "k", "decomposition rank", true);
+ addOption("oversampling", "p", "oversampling", String.valueOf(15));
+ addOption("blockHeight",
+ "r",
+ "Y block height (must be > (k+p))",
+ String.valueOf(10000));
+ addOption("outerProdBlockHeight",
+ "oh",
+ "block height of outer products during multiplication, increase for sparse inputs",
+ String.valueOf(30000));
+ addOption("abtBlockHeight",
+ "abth",
+ "block height of Y_i in ABtJob during AB' multiplication, increase for extremely sparse inputs",
+ String.valueOf(200000));
+ addOption("minSplitSize", "s", "minimum split size", String.valueOf(-1));
+ addOption("computeU", "U", "compute U (true/false)", String.valueOf(true));
+ addOption("uHalfSigma",
+ "uhs",
+ "Compute U * Sigma^0.5",
+ String.valueOf(false));
+ addOption("uSigma", "us", "Compute U * Sigma", String.valueOf(false));
+ addOption("computeV", "V", "compute V (true/false)", String.valueOf(true));
+ addOption("vHalfSigma",
+ "vhs",
+ "compute V * Sigma^0.5",
+ String.valueOf(false));
+ addOption("reduceTasks",
+ "t",
+ "number of reduce tasks (where applicable)",
+ true);
+ addOption("powerIter",
+ "q",
+ "number of additional power iterations (0..2 is good)",
+ String.valueOf(0));
+ addOption("broadcast",
+ "br",
+ "whether use distributed cache to broadcast matrices wherever possible",
+ String.valueOf(true));
+ addOption("pca",
+ "pca",
+ "run in pca mode: compute column-wise mean and subtract from input",
+ String.valueOf(false));
+ addOption("pcaOffset",
+ "xi",
+ "path(glob) of external pca mean (optional, dont compute, use external mean");
+ addOption(DefaultOptionCreator.overwriteOption().create());
+
+ Map<String, List<String>> pargs = parseArguments(args);
+ if (pargs == null) {
+ return -1;
+ }
+
+ int k = Integer.parseInt(getOption("rank"));
+ int p = Integer.parseInt(getOption("oversampling"));
+ int r = Integer.parseInt(getOption("blockHeight"));
+ int h = Integer.parseInt(getOption("outerProdBlockHeight"));
+ int abh = Integer.parseInt(getOption("abtBlockHeight"));
+ int q = Integer.parseInt(getOption("powerIter"));
+ int minSplitSize = Integer.parseInt(getOption("minSplitSize"));
+ boolean computeU = Boolean.parseBoolean(getOption("computeU"));
+ boolean computeV = Boolean.parseBoolean(getOption("computeV"));
+ boolean cUHalfSigma = Boolean.parseBoolean(getOption("uHalfSigma"));
+ boolean cUSigma = Boolean.parseBoolean(getOption("uSigma"));
+ boolean cVHalfSigma = Boolean.parseBoolean(getOption("vHalfSigma"));
+ int reduceTasks = Integer.parseInt(getOption("reduceTasks"));
+ boolean broadcast = Boolean.parseBoolean(getOption("broadcast"));
+ String xiPathStr = getOption("pcaOffset");
+ Path xiPath = xiPathStr == null ? null : new Path(xiPathStr);
+ boolean pca = Boolean.parseBoolean(getOption("pca")) || xiPath != null;
+
+ boolean overwrite = hasOption(DefaultOptionCreator.OVERWRITE_OPTION);
+
+ Configuration conf = getConf();
+ if (conf == null) {
+ throw new IOException("No Hadoop configuration present");
+ }
+
+ Path[] inputPaths = { getInputPath() };
+ Path tempPath = getTempPath();
+ FileSystem fs = FileSystem.get(getTempPath().toUri(), conf);
+
+ // housekeeping
+ if (overwrite) {
+ // clear the output path
+ HadoopUtil.delete(getConf(), getOutputPath());
+ // clear the temp path
+ HadoopUtil.delete(getConf(), getTempPath());
+ }
+
+ fs.mkdirs(getOutputPath());
+
+ // MAHOUT-817
+ if (pca && xiPath == null) {
+ xiPath = new Path(tempPath, "xi");
+ if (overwrite) {
+ fs.delete(xiPath, true);
+ }
+ MatrixColumnMeansJob.run(conf, inputPaths[0], xiPath);
+ }
+
+ SSVDSolver solver =
+ new SSVDSolver(conf,
+ inputPaths,
+ new Path(tempPath, "ssvd"),
+ r,
+ k,
+ p,
+ reduceTasks);
+
+ solver.setMinSplitSize(minSplitSize);
+ solver.setComputeU(computeU);
+ solver.setComputeV(computeV);
+ solver.setcUHalfSigma(cUHalfSigma);
+ solver.setcVHalfSigma(cVHalfSigma);
+ solver.setcUSigma(cUSigma);
+ solver.setOuterBlockHeight(h);
+ solver.setAbtBlockHeight(abh);
+ solver.setQ(q);
+ solver.setBroadcast(broadcast);
+ solver.setOverwrite(overwrite);
+
+ if (xiPath != null) {
+ solver.setPcaMeanPath(new Path(xiPath, "part-*"));
+ }
+
+ solver.run();
+
+ Vector svalues = solver.getSingularValues().viewPart(0, k);
+ SSVDHelper.saveVector(svalues, getOutputPath("sigma"), conf);
+
+ if (computeU && !fs.rename(new Path(solver.getUPath()), getOutputPath())) {
+ throw new IOException("Unable to move U results to the output path.");
+ }
+ if (cUHalfSigma
+ && !fs.rename(new Path(solver.getuHalfSigmaPath()), getOutputPath())) {
+ throw new IOException("Unable to move U*Sigma^0.5 results to the output path.");
+ }
+ if (cUSigma
+ && !fs.rename(new Path(solver.getuSigmaPath()), getOutputPath())) {
+ throw new IOException("Unable to move U*Sigma results to the output path.");
+ }
+ if (computeV && !fs.rename(new Path(solver.getVPath()), getOutputPath())) {
+ throw new IOException("Unable to move V results to the output path.");
+ }
+ if (cVHalfSigma
+ && !fs.rename(new Path(solver.getvHalfSigmaPath()), getOutputPath())) {
+ throw new IOException("Unable to move V*Sigma^0.5 results to the output path.");
+ }
+
+ // Delete the temp path on exit
+ fs.deleteOnExit(getTempPath());
+
+ return 0;
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new SSVDCli(), args);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDHelper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDHelper.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDHelper.java
new file mode 100644
index 0000000..c585f33
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDHelper.java
@@ -0,0 +1,322 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+import com.google.common.io.Closeables;
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseSymmetricMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.UpperTriangular;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * set of small file manipulation helpers.
+ */
+public final class SSVDHelper {
+
+ private static final Pattern OUTPUT_FILE_PATTERN = Pattern.compile("(\\w+)-(m|r)-(\\d+)(\\.\\w+)?");
+
+ private SSVDHelper() {
+ }
+
+ /**
+ * load single vector from an hdfs file (possibly presented as glob).
+ */
+ static Vector loadVector(Path glob, Configuration conf) throws IOException {
+
+ SequenceFileDirValueIterator<VectorWritable> iter =
+ new SequenceFileDirValueIterator<>(glob,
+ PathType.GLOB,
+ null,
+ null,
+ true,
+ conf);
+
+ try {
+ if (!iter.hasNext()) {
+ throw new IOException("Empty input while reading vector");
+ }
+ VectorWritable vw = iter.next();
+
+ if (iter.hasNext()) {
+ throw new IOException("Unexpected data after the end of vector file");
+ }
+
+ return vw.get();
+
+ } finally {
+ Closeables.close(iter, true);
+ }
+ }
+
+ /**
+ * save single vector into hdfs file.
+ *
+ * @param v vector to save
+ */
+ public static void saveVector(Vector v,
+ Path vectorFilePath,
+ Configuration conf) throws IOException {
+ VectorWritable vw = new VectorWritable(v);
+ FileSystem fs = FileSystem.get(conf);
+ try (SequenceFile.Writer w = new SequenceFile.Writer(fs,
+ conf,
+ vectorFilePath,
+ IntWritable.class,
+ VectorWritable.class)) {
+ w.append(new IntWritable(), vw);
+ }
+ /*
+ * this is a writer, no quiet close please. we must bail out on incomplete
+ * close.
+ */
+
+ }
+
+ /**
+ * sniff label type in the input files
+ */
+ static Class<? extends Writable> sniffInputLabelType(Path[] inputPath,
+ Configuration conf)
+ throws IOException {
+ FileSystem fs = FileSystem.get(conf);
+ for (Path p : inputPath) {
+ FileStatus[] fstats = fs.globStatus(p);
+ if (fstats == null || fstats.length == 0) {
+ continue;
+ }
+
+ FileStatus firstSeqFile;
+ if (fstats[0].isDir()) {
+ firstSeqFile = fs.listStatus(fstats[0].getPath(), PathFilters.logsCRCFilter())[0];
+ } else {
+ firstSeqFile = fstats[0];
+ }
+
+ SequenceFile.Reader r = null;
+ try {
+ r = new SequenceFile.Reader(fs, firstSeqFile.getPath(), conf);
+ return r.getKeyClass().asSubclass(Writable.class);
+ } finally {
+ Closeables.close(r, true);
+ }
+ }
+ throw new IOException("Unable to open input files to determine input label type.");
+ }
+
+ static final Comparator<FileStatus> PARTITION_COMPARATOR =
+ new Comparator<FileStatus>() {
+ private final Matcher matcher = OUTPUT_FILE_PATTERN.matcher("");
+
+ @Override
+ public int compare(FileStatus o1, FileStatus o2) {
+ matcher.reset(o1.getPath().getName());
+ if (!matcher.matches()) {
+ throw new IllegalArgumentException("Unexpected file name, unable to deduce partition #:"
+ + o1.getPath());
+ }
+ int p1 = Integer.parseInt(matcher.group(3));
+ matcher.reset(o2.getPath().getName());
+ if (!matcher.matches()) {
+ throw new IllegalArgumentException("Unexpected file name, unable to deduce partition #:"
+ + o2.getPath());
+ }
+
+ int p2 = Integer.parseInt(matcher.group(3));
+ return p1 - p2;
+ }
+
+ };
+
+ public static Iterator<Pair<Writable, Vector>> drmIterator(FileSystem fs, Path glob, Configuration conf,
+ Deque<Closeable> closeables)
+ throws IOException {
+ SequenceFileDirIterator<Writable, VectorWritable> ret =
+ new SequenceFileDirIterator<>(glob,
+ PathType.GLOB,
+ PathFilters.logsCRCFilter(),
+ PARTITION_COMPARATOR,
+ true,
+ conf);
+ closeables.addFirst(ret);
+ return Iterators.transform(ret, new Function<Pair<Writable, VectorWritable>, Pair<Writable, Vector>>() {
+ @Override
+ public Pair<Writable, Vector> apply(Pair<Writable, VectorWritable> p) {
+ return new Pair(p.getFirst(), p.getSecond().get());
+ }
+ });
+ }
+
+ /**
+ * helper capabiltiy to load distributed row matrices into dense matrix (to
+ * support tests mainly).
+ *
+ * @param fs filesystem
+ * @param glob FS glob
+ * @param conf configuration
+ * @return Dense matrix array
+ */
+ public static DenseMatrix drmLoadAsDense(FileSystem fs, Path glob, Configuration conf) throws IOException {
+
+ Deque<Closeable> closeables = new ArrayDeque<>();
+ try {
+ List<double[]> denseData = new ArrayList<>();
+ for (Iterator<Pair<Writable, Vector>> iter = drmIterator(fs, glob, conf, closeables);
+ iter.hasNext(); ) {
+ Pair<Writable, Vector> p = iter.next();
+ Vector v = p.getSecond();
+ double[] dd = new double[v.size()];
+ if (v.isDense()) {
+ for (int i = 0; i < v.size(); i++) {
+ dd[i] = v.getQuick(i);
+ }
+ } else {
+ for (Vector.Element el : v.nonZeroes()) {
+ dd[el.index()] = el.get();
+ }
+ }
+ denseData.add(dd);
+ }
+ if (denseData.size() == 0) {
+ return null;
+ } else {
+ return new DenseMatrix(denseData.toArray(new double[denseData.size()][]));
+ }
+ } finally {
+ IOUtils.close(closeables);
+ }
+ }
+
+ /**
+ * Load multiple upper triangular matrices and sum them up.
+ *
+ * @return the sum of upper triangular inputs.
+ */
+ public static DenseSymmetricMatrix loadAndSumUpperTriangularMatricesAsSymmetric(Path glob, Configuration conf) throws IOException {
+ Vector v = loadAndSumUpVectors(glob, conf);
+ return v == null ? null : new DenseSymmetricMatrix(v);
+ }
+
+ /**
+ * @return sum of all vectors in different files specified by glob
+ */
+ public static Vector loadAndSumUpVectors(Path glob, Configuration conf)
+ throws IOException {
+
+ SequenceFileDirValueIterator<VectorWritable> iter =
+ new SequenceFileDirValueIterator<>(glob,
+ PathType.GLOB,
+ null,
+ PARTITION_COMPARATOR,
+ true,
+ conf);
+
+ try {
+ Vector v = null;
+ while (iter.hasNext()) {
+ if (v == null) {
+ v = new DenseVector(iter.next().get());
+ } else {
+ v.assign(iter.next().get(), Functions.PLUS);
+ }
+ }
+ return v;
+
+ } finally {
+ Closeables.close(iter, true);
+ }
+
+ }
+
+ /**
+ * Load only one upper triangular matrix and issue error if mroe than one is
+ * found.
+ */
+ public static UpperTriangular loadUpperTriangularMatrix(Path glob, Configuration conf) throws IOException {
+
+ /*
+ * there still may be more than one file in glob and only one of them must
+ * contain the matrix.
+ */
+
+ try (SequenceFileDirValueIterator<VectorWritable> iter = new SequenceFileDirValueIterator<>(glob,
+ PathType.GLOB,
+ null,
+ null,
+ true,
+ conf)) {
+ if (!iter.hasNext()) {
+ throw new IOException("No triangular matrices found");
+ }
+ Vector v = iter.next().get();
+ UpperTriangular result = new UpperTriangular(v);
+ if (iter.hasNext()) {
+ throw new IOException("Unexpected overrun in upper triangular matrix files");
+ }
+ return result;
+
+ }
+ }
+
+ /**
+ * extracts row-wise raw data from a Mahout matrix for 3rd party solvers.
+ * Unfortunately values member is 100% encapsulated in {@link org.apache.mahout.math.DenseMatrix} at
+ * this point, so we have to resort to abstract element-wise copying.
+ */
+ public static double[][] extractRawData(Matrix m) {
+ int rows = m.numRows();
+ int cols = m.numCols();
+ double[][] result = new double[rows][];
+ for (int i = 0; i < rows; i++) {
+ result[i] = new double[cols];
+ for (int j = 0; j < cols; j++) {
+ result[i][j] = m.getQuick(i, j);
+ }
+ }
+ return result;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDSolver.java
new file mode 100644
index 0000000..94be450
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SSVDSolver.java
@@ -0,0 +1,662 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.*;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.solver.EigenDecomposition;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Deque;
+import java.util.Random;
+
+/**
+ * Stochastic SVD solver (API class).
+ * <p/>
+ * <p/>
+ * Implementation details are in my working notes in MAHOUT-376
+ * (https://issues.apache.org/jira/browse/MAHOUT-376).
+ * <p/>
+ * <p/>
+ * As of the time of this writing, I don't have benchmarks for this method in
+ * comparison to other methods. However, non-hadoop differentiating
+ * characteristics of this method are thought to be :
+ * <LI>"faster" and precision is traded off in favor of speed. However, there's
+ * lever in terms of "oversampling parameter" p. Higher values of p produce
+ * better precision but are trading off speed (and minimum RAM requirement).
+ * This also means that this method is almost guaranteed to be less precise than
+ * Lanczos unless full rank SVD decomposition is sought.
+ * <LI>"more scale" -- can presumably take on larger problems than Lanczos one
+ * (not confirmed by benchmark at this time)
+ * <p/>
+ * <p/>
+ * <p/>
+ * Specifically in regards to this implementation, <i>I think</i> couple of
+ * other differentiating points are:
+ * <LI>no need to specify input matrix height or width in command line, it is
+ * what it gets to be.
+ * <LI>supports any Writable as DRM row keys and copies them to correspondent
+ * rows of U matrix;
+ * <LI>can request U or V or U<sub>σ</sub>=U* Σ<sup>0.5</sup> or
+ * V<sub>σ</sub>=V* Σ<sup>0.5</sup> none of which would require pass
+ * over input A and these jobs are parallel map-only jobs.
+ * <p/>
+ * <p/>
+ * <p/>
+ * This class is central public API for SSVD solver. The use pattern is as
+ * follows:
+ * <p/>
+ * <UL>
+ * <LI>create the solver using constructor and supplying computation parameters.
+ * <LI>set optional parameters thru setter methods.
+ * <LI>call {@link #run()}.
+ * <LI> {@link #getUPath()} (if computed) returns the path to the directory
+ * containing m x k U matrix file(s).
+ * <LI> {@link #getVPath()} (if computed) returns the path to the directory
+ * containing n x k V matrix file(s).
+ * <p/>
+ * </UL>
+ */
+public final class SSVDSolver {
+
+ private Vector svalues;
+ private boolean computeU = true;
+ private boolean computeV = true;
+ private String uPath;
+ private String vPath;
+ private String uSigmaPath;
+ private String uHalfSigmaPath;
+ private String vSigmaPath;
+ private String vHalfSigmaPath;
+ private int outerBlockHeight = 30000;
+ private int abtBlockHeight = 200000;
+
+ // configured stuff
+ private final Configuration conf;
+ private final Path[] inputPath;
+ private final Path outputPath;
+ private final int ablockRows;
+ private final int k;
+ private final int p;
+ private int q;
+ private final int reduceTasks;
+ private int minSplitSize = -1;
+ private boolean cUHalfSigma;
+ private boolean cUSigma;
+ private boolean cVHalfSigma;
+ private boolean cVSigma;
+ private boolean overwrite;
+ private boolean broadcast = true;
+ private Path pcaMeanPath;
+
+ // for debugging
+ private long omegaSeed;
+
+ /**
+ * create new SSVD solver. Required parameters are passed to constructor to
+ * ensure they are set. Optional parameters can be set using setters .
+ * <p/>
+ *
+ * @param conf hadoop configuration
+ * @param inputPath Input path (should be compatible with DistributedRowMatrix as of
+ * the time of this writing).
+ * @param outputPath Output path containing U, V and singular values vector files.
+ * @param ablockRows The vertical hight of a q-block (bigger value require more memory
+ * in mappers+ perhaps larger {@code minSplitSize} values
+ * @param k desired rank
+ * @param p SSVD oversampling parameter
+ * @param reduceTasks Number of reduce tasks (where applicable)
+ */
+ public SSVDSolver(Configuration conf,
+ Path[] inputPath,
+ Path outputPath,
+ int ablockRows,
+ int k,
+ int p,
+ int reduceTasks) {
+ this.conf = conf;
+ this.inputPath = inputPath;
+ this.outputPath = outputPath;
+ this.ablockRows = ablockRows;
+ this.k = k;
+ this.p = p;
+ this.reduceTasks = reduceTasks;
+ }
+
+ public int getQ() {
+ return q;
+ }
+
+ /**
+ * sets q, amount of additional power iterations to increase precision
+ * (0..2!). Defaults to 0.
+ *
+ * @param q
+ */
+ public void setQ(int q) {
+ this.q = q;
+ }
+
+ /**
+ * The setting controlling whether to compute U matrix of low rank SSVD.
+ * Default true.
+ */
+ public void setComputeU(boolean val) {
+ computeU = val;
+ }
+
+ /**
+ * Setting controlling whether to compute V matrix of low-rank SSVD.
+ *
+ * @param val true if we want to output V matrix. Default is true.
+ */
+ public void setComputeV(boolean val) {
+ computeV = val;
+ }
+
+ /**
+ * @param cUHat whether produce U*Sigma^0.5 as well (default false)
+ */
+ public void setcUHalfSigma(boolean cUHat) {
+ this.cUHalfSigma = cUHat;
+ }
+
+ /**
+ * @param cVHat whether produce V*Sigma^0.5 as well (default false)
+ */
+ public void setcVHalfSigma(boolean cVHat) {
+ this.cVHalfSigma = cVHat;
+ }
+
+ /**
+ * @param cUSigma whether produce U*Sigma output as well (default false)
+ */
+ public void setcUSigma(boolean cUSigma) {
+ this.cUSigma = cUSigma;
+ }
+
+ /**
+ * @param cVSigma whether produce V*Sigma output as well (default false)
+ */
+ public void setcVSigma(boolean cVSigma) {
+ this.cVSigma = cVSigma;
+ }
+
+ /**
+ * Sometimes, if requested A blocks become larger than a split, we may need to
+ * use that to ensure at least k+p rows of A get into a split. This is
+ * requirement necessary to obtain orthonormalized Q blocks of SSVD.
+ *
+ * @param size the minimum split size to use
+ */
+ public void setMinSplitSize(int size) {
+ minSplitSize = size;
+ }
+
+ /**
+ * This contains k+p singular values resulted from the solver run.
+ *
+ * @return singlular values (largest to smallest)
+ */
+ public Vector getSingularValues() {
+ return svalues;
+ }
+
+ /**
+ * returns U path (if computation were requested and successful).
+ *
+ * @return U output hdfs path, or null if computation was not completed for
+ * whatever reason.
+ */
+ public String getUPath() {
+ return uPath;
+ }
+
+ /**
+ * return V path ( if computation was requested and successful ) .
+ *
+ * @return V output hdfs path, or null if computation was not completed for
+ * whatever reason.
+ */
+ public String getVPath() {
+ return vPath;
+ }
+
+ public String getuSigmaPath() {
+ return uSigmaPath;
+ }
+
+ public String getuHalfSigmaPath() {
+ return uHalfSigmaPath;
+ }
+
+ public String getvSigmaPath() {
+ return vSigmaPath;
+ }
+
+ public String getvHalfSigmaPath() {
+ return vHalfSigmaPath;
+ }
+
+ public boolean isOverwrite() {
+ return overwrite;
+ }
+
+ /**
+ * if true, driver to clean output folder first if exists.
+ *
+ * @param overwrite
+ */
+ public void setOverwrite(boolean overwrite) {
+ this.overwrite = overwrite;
+ }
+
+ public int getOuterBlockHeight() {
+ return outerBlockHeight;
+ }
+
+ /**
+ * The height of outer blocks during Q'A multiplication. Higher values allow
+ * to produce less keys for combining and shuffle and sort therefore somewhat
+ * improving running time; but require larger blocks to be formed in RAM (so
+ * setting this too high can lead to OOM).
+ *
+ * @param outerBlockHeight
+ */
+ public void setOuterBlockHeight(int outerBlockHeight) {
+ this.outerBlockHeight = outerBlockHeight;
+ }
+
+ public int getAbtBlockHeight() {
+ return abtBlockHeight;
+ }
+
+ /**
+ * the block height of Y_i during power iterations. It is probably important
+ * to set it higher than default 200,000 for extremely sparse inputs and when
+ * more ram is available. y_i block height and ABt job would occupy approx.
+ * abtBlockHeight x (k+p) x sizeof (double) (as dense).
+ *
+ * @param abtBlockHeight
+ */
+ public void setAbtBlockHeight(int abtBlockHeight) {
+ this.abtBlockHeight = abtBlockHeight;
+ }
+
+ public boolean isBroadcast() {
+ return broadcast;
+ }
+
+ /**
+ * If this property is true, use DestributedCache mechanism to broadcast some
+ * stuff around. May improve efficiency. Default is false.
+ *
+ * @param broadcast
+ */
+ public void setBroadcast(boolean broadcast) {
+ this.broadcast = broadcast;
+ }
+
+ /**
+ * Optional. Single-vector file path for a vector (aka xi in MAHOUT-817
+ * working notes) to be subtracted from each row of input.
+ * <p/>
+ * <p/>
+ * Brute force approach would force would turn input into a dense input, which
+ * is often not very desirable. By supplying this offset to SSVD solver, we
+ * can avoid most of that overhead due to increased input density.
+ * <p/>
+ * <p/>
+ * The vector size for this offest is n (width of A input). In PCA and R this
+ * is known as "column means", but in this case it can be any offset of row
+ * vectors of course to propagate into SSVD solution.
+ * <p/>
+ */
+ public Path getPcaMeanPath() {
+ return pcaMeanPath;
+ }
+
+ public void setPcaMeanPath(Path pcaMeanPath) {
+ this.pcaMeanPath = pcaMeanPath;
+ }
+
+ long getOmegaSeed() {
+ return omegaSeed;
+ }
+
+ /**
+ * run all SSVD jobs.
+ *
+ * @throws IOException if I/O condition occurs.
+ */
+ public void run() throws IOException {
+
+ Deque<Closeable> closeables = Lists.newLinkedList();
+ try {
+ Class<? extends Writable> labelType =
+ SSVDHelper.sniffInputLabelType(inputPath, conf);
+ FileSystem fs = FileSystem.get(conf);
+
+ Path qPath = new Path(outputPath, "Q-job");
+ Path btPath = new Path(outputPath, "Bt-job");
+ Path uHatPath = new Path(outputPath, "UHat");
+ Path svPath = new Path(outputPath, "Sigma");
+ Path uPath = new Path(outputPath, "U");
+ Path uSigmaPath = new Path(outputPath, "USigma");
+ Path uHalfSigmaPath = new Path(outputPath, "UHalfSigma");
+ Path vPath = new Path(outputPath, "V");
+ Path vHalfSigmaPath = new Path(outputPath, "VHalfSigma");
+ Path vSigmaPath = new Path(outputPath, "VSigma");
+
+ Path pcaBasePath = new Path(outputPath, "pca");
+
+ if (overwrite) {
+ fs.delete(outputPath, true);
+ }
+
+ if (pcaMeanPath != null) {
+ fs.mkdirs(pcaBasePath);
+ }
+ Random rnd = RandomUtils.getRandom();
+ omegaSeed = rnd.nextLong();
+
+ Path sbPath = null;
+ double xisquaredlen = 0.0;
+ if (pcaMeanPath != null) {
+ /*
+ * combute s_b0 if pca offset present.
+ *
+ * Just in case, we treat xi path as a possible reduce or otherwise
+ * multiple task output that we assume we need to sum up partial
+ * components. If it is just one file, it will work too.
+ */
+
+ Vector xi = SSVDHelper.loadAndSumUpVectors(pcaMeanPath, conf);
+ if (xi == null) {
+ throw new IOException(String.format("unable to load mean path xi from %s.",
+ pcaMeanPath.toString()));
+ }
+
+ xisquaredlen = xi.dot(xi);
+ Omega omega = new Omega(omegaSeed, k + p);
+ Vector s_b0 = omega.mutlithreadedTRightMultiply(xi);
+
+ SSVDHelper.saveVector(s_b0, sbPath = new Path(pcaBasePath, "somega.seq"), conf);
+ }
+
+ /*
+ * if we work with pca offset, we need to precompute s_bq0 aka s_omega for
+ * jobs to use.
+ */
+
+ QJob.run(conf,
+ inputPath,
+ sbPath,
+ qPath,
+ ablockRows,
+ minSplitSize,
+ k,
+ p,
+ omegaSeed,
+ reduceTasks);
+
+ /*
+ * restrict number of reducers to a reasonable number so we don't have to
+ * run too many additions in the frontend when reconstructing BBt for the
+ * last B' and BB' computations. The user may not realize that and gives a
+ * bit too many (I would be happy i that were ever the case though).
+ */
+
+ BtJob.run(conf,
+ inputPath,
+ qPath,
+ pcaMeanPath,
+ btPath,
+ minSplitSize,
+ k,
+ p,
+ outerBlockHeight,
+ q <= 0 ? Math.min(1000, reduceTasks) : reduceTasks,
+ broadcast,
+ labelType,
+ q <= 0);
+
+ sbPath = new Path(btPath, BtJob.OUTPUT_SB + "-*");
+ Path sqPath = new Path(btPath, BtJob.OUTPUT_SQ + "-*");
+
+ // power iterations
+ for (int i = 0; i < q; i++) {
+
+ qPath = new Path(outputPath, String.format("ABt-job-%d", i + 1));
+ Path btPathGlob = new Path(btPath, BtJob.OUTPUT_BT + "-*");
+ ABtDenseOutJob.run(conf,
+ inputPath,
+ btPathGlob,
+ pcaMeanPath,
+ sqPath,
+ sbPath,
+ qPath,
+ ablockRows,
+ minSplitSize,
+ k,
+ p,
+ abtBlockHeight,
+ reduceTasks,
+ broadcast);
+
+ btPath = new Path(outputPath, String.format("Bt-job-%d", i + 1));
+
+ BtJob.run(conf,
+ inputPath,
+ qPath,
+ pcaMeanPath,
+ btPath,
+ minSplitSize,
+ k,
+ p,
+ outerBlockHeight,
+ i == q - 1 ? Math.min(1000, reduceTasks) : reduceTasks,
+ broadcast,
+ labelType,
+ i == q - 1);
+ sbPath = new Path(btPath, BtJob.OUTPUT_SB + "-*");
+ sqPath = new Path(btPath, BtJob.OUTPUT_SQ + "-*");
+ }
+
+ DenseSymmetricMatrix bbt =
+ SSVDHelper.loadAndSumUpperTriangularMatricesAsSymmetric(new Path(btPath,
+ BtJob.OUTPUT_BBT
+ + "-*"), conf);
+
+ // convert bbt to something our eigensolver could understand
+ assert bbt.columnSize() == k + p;
+
+ /*
+ * we currently use a 3rd party in-core eigensolver. So we need just a
+ * dense array representation for it.
+ */
+ Matrix bbtSquare = new DenseMatrix(k + p, k + p);
+ bbtSquare.assign(bbt);
+
+ // MAHOUT-817
+ if (pcaMeanPath != null) {
+ Vector sq = SSVDHelper.loadAndSumUpVectors(sqPath, conf);
+ Vector sb = SSVDHelper.loadAndSumUpVectors(sbPath, conf);
+ Matrix mC = sq.cross(sb);
+
+ bbtSquare.assign(mC, Functions.MINUS);
+ bbtSquare.assign(mC.transpose(), Functions.MINUS);
+
+ Matrix outerSq = sq.cross(sq);
+ outerSq.assign(Functions.mult(xisquaredlen));
+ bbtSquare.assign(outerSq, Functions.PLUS);
+
+ }
+
+ EigenDecomposition eigen = new EigenDecomposition(bbtSquare);
+
+ Matrix uHat = eigen.getV();
+ svalues = eigen.getRealEigenvalues().clone();
+
+ svalues.assign(Functions.SQRT);
+
+ // save/redistribute UHat
+ fs.mkdirs(uHatPath);
+ DistributedRowMatrixWriter.write(uHatPath =
+ new Path(uHatPath, "uhat.seq"), conf, uHat);
+
+ // save sigma.
+ SSVDHelper.saveVector(svalues,
+ svPath = new Path(svPath, "svalues.seq"),
+ conf);
+
+ UJob ujob = null;
+ if (computeU) {
+ ujob = new UJob();
+ ujob.run(conf,
+ new Path(btPath, BtJob.OUTPUT_Q + "-*"),
+ uHatPath,
+ svPath,
+ uPath,
+ k,
+ reduceTasks,
+ labelType,
+ OutputScalingEnum.NOSCALING);
+ // actually this is map-only job anyway
+ }
+
+ UJob uhsjob = null;
+ if (cUHalfSigma) {
+ uhsjob = new UJob();
+ uhsjob.run(conf,
+ new Path(btPath, BtJob.OUTPUT_Q + "-*"),
+ uHatPath,
+ svPath,
+ uHalfSigmaPath,
+ k,
+ reduceTasks,
+ labelType,
+ OutputScalingEnum.HALFSIGMA);
+ }
+
+ UJob usjob = null;
+ if (cUSigma) {
+ usjob = new UJob();
+ usjob.run(conf,
+ new Path(btPath, BtJob.OUTPUT_Q + "-*"),
+ uHatPath,
+ svPath,
+ uSigmaPath,
+ k,
+ reduceTasks,
+ labelType,
+ OutputScalingEnum.SIGMA);
+ }
+
+ VJob vjob = null;
+ if (computeV) {
+ vjob = new VJob();
+ vjob.run(conf,
+ new Path(btPath, BtJob.OUTPUT_BT + "-*"),
+ pcaMeanPath,
+ sqPath,
+ uHatPath,
+ svPath,
+ vPath,
+ k,
+ reduceTasks,
+ OutputScalingEnum.NOSCALING);
+ }
+
+ VJob vhsjob = null;
+ if (cVHalfSigma) {
+ vhsjob = new VJob();
+ vhsjob.run(conf,
+ new Path(btPath, BtJob.OUTPUT_BT + "-*"),
+ pcaMeanPath,
+ sqPath,
+ uHatPath,
+ svPath,
+ vHalfSigmaPath,
+ k,
+ reduceTasks,
+ OutputScalingEnum.HALFSIGMA);
+ }
+
+ VJob vsjob = null;
+ if (cVSigma) {
+ vsjob = new VJob();
+ vsjob.run(conf,
+ new Path(btPath, BtJob.OUTPUT_BT + "-*"),
+ pcaMeanPath,
+ sqPath,
+ uHatPath,
+ svPath,
+ vSigmaPath,
+ k,
+ reduceTasks,
+ OutputScalingEnum.SIGMA);
+ }
+
+ if (ujob != null) {
+ ujob.waitForCompletion();
+ this.uPath = uPath.toString();
+ }
+ if (uhsjob != null) {
+ uhsjob.waitForCompletion();
+ this.uHalfSigmaPath = uHalfSigmaPath.toString();
+ }
+ if (usjob != null) {
+ usjob.waitForCompletion();
+ this.uSigmaPath = uSigmaPath.toString();
+ }
+ if (vjob != null) {
+ vjob.waitForCompletion();
+ this.vPath = vPath.toString();
+ }
+ if (vhsjob != null) {
+ vhsjob.waitForCompletion();
+ this.vHalfSigmaPath = vHalfSigmaPath.toString();
+ }
+ if (vsjob != null) {
+ vsjob.waitForCompletion();
+ this.vSigmaPath = vSigmaPath.toString();
+ }
+
+ } catch (InterruptedException exc) {
+ throw new IOException("Interrupted", exc);
+ } catch (ClassNotFoundException exc) {
+ throw new IOException(exc);
+
+ } finally {
+ IOUtils.close(closeables);
+ }
+ }
+
+ enum OutputScalingEnum {
+ NOSCALING, SIGMA, HALFSIGMA
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockAccumulator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockAccumulator.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockAccumulator.java
new file mode 100644
index 0000000..081f55a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockAccumulator.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Aggregate incoming rows into blocks based on the row number (long). Rows can
+ * be sparse (meaning they come perhaps in big intervals) and don't even have to
+ * come in any order, but they should be coming in proximity, so when we output
+ * block key, we hopefully aggregate more than one row by then.
+ * <P>
+ *
+ * If block is sufficiently large to fit all rows that mapper may produce, it
+ * will not even ever hit a spill at all as we would already be plussing
+ * efficiently in the mapper.
+ * <P>
+ *
+ * Also, for sparse inputs it will also be working especially well if transposed
+ * columns of the left side matrix and corresponding rows of the right side
+ * matrix experience sparsity in same elements.
+ * <P>
+ *
+ */
+public class SparseRowBlockAccumulator implements
+ OutputCollector<Long, Vector>, Closeable {
+
+ private final int height;
+ private final OutputCollector<LongWritable, SparseRowBlockWritable> delegate;
+ private long currentBlockNum = -1;
+ private SparseRowBlockWritable block;
+ private final LongWritable blockKeyW = new LongWritable();
+
+ public SparseRowBlockAccumulator(int height,
+ OutputCollector<LongWritable, SparseRowBlockWritable> delegate) {
+ this.height = height;
+ this.delegate = delegate;
+ }
+
+ private void flushBlock() throws IOException {
+ if (block == null || block.getNumRows() == 0) {
+ return;
+ }
+ blockKeyW.set(currentBlockNum);
+ delegate.collect(blockKeyW, block);
+ block.clear();
+ }
+
+ @Override
+ public void collect(Long rowIndex, Vector v) throws IOException {
+
+ long blockKey = rowIndex / height;
+
+ if (blockKey != currentBlockNum) {
+ flushBlock();
+ if (block == null) {
+ block = new SparseRowBlockWritable(100);
+ }
+ currentBlockNum = blockKey;
+ }
+
+ block.plusRow((int) (rowIndex % height), v);
+ }
+
+ @Override
+ public void close() throws IOException {
+ flushBlock();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockWritable.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockWritable.java
new file mode 100644
index 0000000..b7f5b94
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SparseRowBlockWritable.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.PlusMult;
+
+/**
+ * block that supports accumulating rows and their sums , suitable for combiner
+ * and reducers of multiplication jobs.
+ */
+public class SparseRowBlockWritable implements Writable {
+
+ private int[] rowIndices;
+ private Vector[] rows;
+ private int numRows;
+
+ public SparseRowBlockWritable() {
+ this(10);
+ }
+
+ public SparseRowBlockWritable(int initialCapacity) {
+ rowIndices = new int[initialCapacity];
+ rows = new Vector[initialCapacity];
+ }
+
+ public int[] getRowIndices() {
+ return rowIndices;
+ }
+
+ public Vector[] getRows() {
+ return rows;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ numRows = Varint.readUnsignedVarInt(in);
+ if (rows == null || rows.length < numRows) {
+ rows = new Vector[numRows];
+ rowIndices = new int[numRows];
+ }
+ VectorWritable vw = new VectorWritable();
+ for (int i = 0; i < numRows; i++) {
+ rowIndices[i] = Varint.readUnsignedVarInt(in);
+ vw.readFields(in);
+ rows[i] = vw.get().clone();
+ }
+
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ Varint.writeUnsignedVarInt(numRows, out);
+ VectorWritable vw = new VectorWritable();
+ for (int i = 0; i < numRows; i++) {
+ Varint.writeUnsignedVarInt(rowIndices[i], out);
+ vw.set(rows[i]);
+ vw.write(out);
+ }
+ }
+
+ public void plusRow(int index, Vector row) {
+ /*
+ * often accumulation goes in row-increasing order, so check for this to
+ * avoid binary search (another log Height multiplier).
+ */
+
+ int pos =
+ numRows == 0 || rowIndices[numRows - 1] < index ? -numRows - 1 : Arrays
+ .binarySearch(rowIndices, 0, numRows, index);
+ if (pos >= 0) {
+ rows[pos].assign(row, PlusMult.plusMult(1));
+ } else {
+ insertIntoPos(-pos - 1, index, row);
+ }
+ }
+
+ private void insertIntoPos(int pos, int rowIndex, Vector row) {
+ // reallocate if needed
+ if (numRows == rows.length) {
+ rows = Arrays.copyOf(rows, numRows + 1 << 1);
+ rowIndices = Arrays.copyOf(rowIndices, numRows + 1 << 1);
+ }
+ // make a hole if needed
+ System.arraycopy(rows, pos, rows, pos + 1, numRows - pos);
+ System.arraycopy(rowIndices, pos, rowIndices, pos + 1, numRows - pos);
+ // put
+ rowIndices[pos] = rowIndex;
+ rows[pos] = row.clone();
+ numRows++;
+ }
+
+ /**
+ * pluses one block into another. Use it for accumulation of partial products in
+ * combiners and reducers.
+ *
+ * @param bOther
+ * block to add
+ */
+ public void plusBlock(SparseRowBlockWritable bOther) {
+ /*
+ * since we maintained row indices in a sorted order, we can run sort merge
+ * to expedite this operation
+ */
+ int i = 0;
+ int j = 0;
+ while (i < numRows && j < bOther.numRows) {
+ while (i < numRows && rowIndices[i] < bOther.rowIndices[j]) {
+ i++;
+ }
+ if (i < numRows) {
+ if (rowIndices[i] == bOther.rowIndices[j]) {
+ rows[i].assign(bOther.rows[j], PlusMult.plusMult(1));
+ } else {
+ // insert into i-th position
+ insertIntoPos(i, bOther.rowIndices[j], bOther.rows[j]);
+ }
+ // increment in either case
+ i++;
+ j++;
+ }
+ }
+ for (; j < bOther.numRows; j++) {
+ insertIntoPos(numRows, bOther.rowIndices[j], bOther.rows[j]);
+ }
+ }
+
+ public int getNumRows() {
+ return numRows;
+ }
+
+ public void clear() {
+ numRows = 0;
+ Arrays.fill(rows, null);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SplitPartitionedWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SplitPartitionedWritable.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SplitPartitionedWritable.java
new file mode 100644
index 0000000..7caeb4a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/SplitPartitionedWritable.java
@@ -0,0 +1,151 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.io.Serializable;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Varint;
+
+/**
+ * a key for vectors allowing to identify them by their coordinates in original
+ * split of A.
+ *
+ * We assume all passes over A results in the same splits, thus, we can always
+ * prepare side files that come into contact with A, sp that they are sorted and
+ * partitioned same way.
+ * <P>
+ *
+ * Hashcode is defined the way that all records of the same split go to the same
+ * reducer.
+ * <P>
+ *
+ * In addition, we are defining a grouping comparator allowing group one split
+ * into the same reducer group.
+ * <P>
+ *
+ */
+public class SplitPartitionedWritable implements
+ WritableComparable<SplitPartitionedWritable> {
+
+ private int taskId;
+ private long taskItemOrdinal;
+
+ public SplitPartitionedWritable(Mapper<?, ?, ?, ?>.Context mapperContext) {
+ // this is basically a split # if i understand it right
+ taskId = mapperContext.getTaskAttemptID().getTaskID().getId();
+ }
+
+ public SplitPartitionedWritable() {
+ }
+
+ public int getTaskId() {
+ return taskId;
+ }
+
+ public long getTaskItemOrdinal() {
+ return taskItemOrdinal;
+ }
+
+ public void incrementItemOrdinal() {
+ taskItemOrdinal++;
+ }
+
+ public void setTaskItemOrdinal(long taskItemOrdinal) {
+ this.taskItemOrdinal = taskItemOrdinal;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ taskId = Varint.readUnsignedVarInt(in);
+ taskItemOrdinal = Varint.readUnsignedVarLong(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ Varint.writeUnsignedVarInt(taskId, out);
+ Varint.writeUnsignedVarLong(taskItemOrdinal, out);
+ }
+
+ @Override
+ public int hashCode() {
+ int prime = 31;
+ int result = 1;
+ result = prime * result + taskId;
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ SplitPartitionedWritable other = (SplitPartitionedWritable) obj;
+ return taskId == other.taskId;
+ }
+
+ @Override
+ public int compareTo(SplitPartitionedWritable o) {
+ if (taskId < o.taskId) {
+ return -1;
+ }
+ if (taskId > o.taskId) {
+ return 1;
+ }
+ if (taskItemOrdinal < o.taskItemOrdinal) {
+ return -1;
+ }
+ if (taskItemOrdinal > o.taskItemOrdinal) {
+ return 1;
+ }
+ return 0;
+ }
+
+ public static final class SplitGroupingComparator extends WritableComparator implements Serializable {
+
+ public SplitGroupingComparator() {
+ super(SplitPartitionedWritable.class, true);
+ }
+
+ @Override
+ public int compare(Object a, Object b) {
+ SplitPartitionedWritable o1 = (SplitPartitionedWritable) a;
+ SplitPartitionedWritable o2 = (SplitPartitionedWritable) b;
+
+ if (o1.taskId < o2.taskId) {
+ return -1;
+ }
+ if (o1.taskId > o2.taskId) {
+ return 1;
+ }
+ return 0;
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UJob.java
new file mode 100644
index 0000000..a6db079
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/UJob.java
@@ -0,0 +1,170 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Computes U=Q*Uhat of SSVD (optionally adding x pow(Sigma, 0.5) )
+ *
+ */
+public class UJob {
+ private static final String OUTPUT_U = "u";
+ private static final String PROP_UHAT_PATH = "ssvd.uhat.path";
+ private static final String PROP_SIGMA_PATH = "ssvd.sigma.path";
+ private static final String PROP_OUTPUT_SCALING = "ssvd.u.output.scaling";
+ private static final String PROP_K = "ssvd.k";
+
+ private Job job;
+
+ public void run(Configuration conf, Path inputPathQ, Path inputUHatPath,
+ Path sigmaPath, Path outputPath, int k, int numReduceTasks,
+ Class<? extends Writable> labelClass, SSVDSolver.OutputScalingEnum outputScaling)
+ throws ClassNotFoundException, InterruptedException, IOException {
+
+ job = new Job(conf);
+ job.setJobName("U-job");
+ job.setJarByClass(UJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ FileInputFormat.setInputPaths(job, inputPathQ);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ // WARN: tight hadoop integration here:
+ job.getConfiguration().set("mapreduce.output.basename", OUTPUT_U);
+ FileOutputFormat.setCompressOutput(job, true);
+ FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
+ SequenceFileOutputFormat.setOutputCompressionType(job, CompressionType.BLOCK);
+
+ job.setMapperClass(UMapper.class);
+ job.setMapOutputKeyClass(IntWritable.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+
+ job.setOutputKeyClass(labelClass);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.getConfiguration().set(PROP_UHAT_PATH, inputUHatPath.toString());
+ job.getConfiguration().set(PROP_SIGMA_PATH, sigmaPath.toString());
+ job.getConfiguration().set(PROP_OUTPUT_SCALING, outputScaling.name());
+ job.getConfiguration().setInt(PROP_K, k);
+ job.setNumReduceTasks(0);
+ job.submit();
+
+ }
+
+ public void waitForCompletion() throws IOException, ClassNotFoundException,
+ InterruptedException {
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("U job unsuccessful.");
+ }
+
+ }
+
+ public static final class UMapper extends
+ Mapper<Writable, VectorWritable, Writable, VectorWritable> {
+
+ private Matrix uHat;
+ private DenseVector uRow;
+ private VectorWritable uRowWritable;
+ private int kp;
+ private int k;
+ private Vector sValues;
+
+ @Override
+ protected void map(Writable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ Vector qRow = value.get();
+ if (sValues != null) {
+ for (int i = 0; i < k; i++) {
+ uRow.setQuick(i,
+ qRow.dot(uHat.viewColumn(i)) * sValues.getQuick(i));
+ }
+ } else {
+ for (int i = 0; i < k; i++) {
+ uRow.setQuick(i, qRow.dot(uHat.viewColumn(i)));
+ }
+ }
+
+ /*
+ * MAHOUT-1067: inherit A names too.
+ */
+ if (qRow instanceof NamedVector) {
+ uRowWritable.set(new NamedVector(uRow, ((NamedVector) qRow).getName()));
+ } else {
+ uRowWritable.set(uRow);
+ }
+
+ context.write(key, uRowWritable); // U inherits original A row labels.
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ super.setup(context);
+ Path uHatPath = new Path(context.getConfiguration().get(PROP_UHAT_PATH));
+ Path sigmaPath = new Path(context.getConfiguration().get(PROP_SIGMA_PATH));
+ FileSystem fs = FileSystem.get(uHatPath.toUri(), context.getConfiguration());
+
+ uHat = SSVDHelper.drmLoadAsDense(fs, uHatPath, context.getConfiguration());
+ // since uHat is (k+p) x (k+p)
+ kp = uHat.columnSize();
+ k = context.getConfiguration().getInt(PROP_K, kp);
+ uRow = new DenseVector(k);
+ uRowWritable = new VectorWritable(uRow);
+
+ SSVDSolver.OutputScalingEnum outputScaling =
+ SSVDSolver.OutputScalingEnum.valueOf(context.getConfiguration()
+ .get(PROP_OUTPUT_SCALING));
+ switch (outputScaling) {
+ case SIGMA:
+ sValues = SSVDHelper.loadVector(sigmaPath, context.getConfiguration());
+ break;
+ case HALFSIGMA:
+ sValues = SSVDHelper.loadVector(sigmaPath, context.getConfiguration());
+ sValues.assign(Functions.SQRT);
+ break;
+ default:
+ }
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/VJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/VJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/VJob.java
new file mode 100644
index 0000000..daee93d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/VJob.java
@@ -0,0 +1,224 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.PlusMult;
+
+public class VJob {
+ private static final String OUTPUT_V = "v";
+ private static final String PROP_UHAT_PATH = "ssvd.uhat.path";
+ private static final String PROP_SIGMA_PATH = "ssvd.sigma.path";
+ private static final String PROP_OUTPUT_SCALING = "ssvd.v.output.scaling";
+ private static final String PROP_K = "ssvd.k";
+ public static final String PROP_SQ_PATH = "ssvdpca.sq.path";
+ public static final String PROP_XI_PATH = "ssvdpca.xi.path";
+
+ private Job job;
+
+ public static final class VMapper extends
+ Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+ private Matrix uHat;
+ private Vector vRow;
+ private Vector sValues;
+ private VectorWritable vRowWritable;
+ private int kp;
+ private int k;
+ /*
+ * xi and s_q are PCA-related corrections, per MAHOUT-817
+ */
+ private Vector xi;
+ private Vector sq;
+ private final PlusMult plusMult = new PlusMult(0);
+
+ @Override
+ protected void map(IntWritable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ Vector bCol = value.get();
+ /*
+ * MAHOUT-817: PCA correction for B': b_{col=i} -= s_q * xi_{i}
+ */
+ if (xi != null) {
+ /*
+ * code defensively against shortened xi which may be externally
+ * supplied
+ */
+ int btIndex = key.get();
+ double xii = xi.size() > btIndex ? xi.getQuick(btIndex) : 0.0;
+ plusMult.setMultiplicator(-xii);
+ bCol.assign(sq, plusMult);
+ }
+
+ for (int i = 0; i < k; i++) {
+ vRow.setQuick(i, bCol.dot(uHat.viewColumn(i)) / sValues.getQuick(i));
+ }
+ context.write(key, vRowWritable);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+ FileSystem fs = FileSystem.get(conf);
+ Path uHatPath = new Path(conf.get(PROP_UHAT_PATH));
+
+ Path sigmaPath = new Path(conf.get(PROP_SIGMA_PATH));
+
+ uHat = SSVDHelper.drmLoadAsDense(fs, uHatPath, conf);
+ // since uHat is (k+p) x (k+p)
+ kp = uHat.columnSize();
+ k = context.getConfiguration().getInt(PROP_K, kp);
+ vRow = new DenseVector(k);
+ vRowWritable = new VectorWritable(vRow);
+
+ sValues = SSVDHelper.loadVector(sigmaPath, conf);
+ SSVDSolver.OutputScalingEnum outputScaling =
+ SSVDSolver.OutputScalingEnum.valueOf(context.getConfiguration()
+ .get(PROP_OUTPUT_SCALING));
+ switch (outputScaling) {
+ case SIGMA:
+ sValues.assign(1.0);
+ break;
+ case HALFSIGMA:
+ sValues = SSVDHelper.loadVector(sigmaPath, context.getConfiguration());
+ sValues.assign(Functions.SQRT);
+ break;
+ default:
+ }
+
+ /*
+ * PCA -related corrections (MAHOUT-817)
+ */
+ String xiPathStr = conf.get(PROP_XI_PATH);
+ if (xiPathStr != null) {
+ xi = SSVDHelper.loadAndSumUpVectors(new Path(xiPathStr), conf);
+ sq =
+ SSVDHelper.loadAndSumUpVectors(new Path(conf.get(PROP_SQ_PATH)), conf);
+ }
+
+ }
+
+ }
+
+ /**
+ *
+ * @param conf
+ * @param inputPathBt
+ * @param xiPath
+ * PCA row mean (MAHOUT-817, to fix B')
+ * @param sqPath
+ * sq (MAHOUT-817, to fix B')
+ * @param inputUHatPath
+ * @param inputSigmaPath
+ * @param outputPath
+ * @param k
+ * @param numReduceTasks
+ * @param outputScaling output scaling: apply Sigma, or Sigma^0.5, or none
+ * @throws ClassNotFoundException
+ * @throws InterruptedException
+ * @throws IOException
+ */
+ public void run(Configuration conf,
+ Path inputPathBt,
+ Path xiPath,
+ Path sqPath,
+
+ Path inputUHatPath,
+ Path inputSigmaPath,
+
+ Path outputPath,
+ int k,
+ int numReduceTasks,
+ SSVDSolver.OutputScalingEnum outputScaling) throws ClassNotFoundException,
+ InterruptedException, IOException {
+
+ job = new Job(conf);
+ job.setJobName("V-job");
+ job.setJarByClass(VJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ FileInputFormat.setInputPaths(job, inputPathBt);
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ // Warn: tight hadoop integration here:
+ job.getConfiguration().set("mapreduce.output.basename", OUTPUT_V);
+ FileOutputFormat.setCompressOutput(job, true);
+ FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
+ SequenceFileOutputFormat.setOutputCompressionType(job,
+ CompressionType.BLOCK);
+
+ job.setMapOutputKeyClass(IntWritable.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.setMapperClass(VMapper.class);
+
+ job.getConfiguration().set(PROP_UHAT_PATH, inputUHatPath.toString());
+ job.getConfiguration().set(PROP_SIGMA_PATH, inputSigmaPath.toString());
+ job.getConfiguration().set(PROP_OUTPUT_SCALING, outputScaling.name());
+ job.getConfiguration().setInt(PROP_K, k);
+ job.setNumReduceTasks(0);
+
+ /*
+ * PCA-related options, MAHOUT-817
+ */
+ if (xiPath != null) {
+ job.getConfiguration().set(PROP_XI_PATH, xiPath.toString());
+ job.getConfiguration().set(PROP_SQ_PATH, sqPath.toString());
+ }
+
+ job.submit();
+
+ }
+
+ public void waitForCompletion() throws IOException, ClassNotFoundException,
+ InterruptedException {
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("V job unsuccessful.");
+ }
+
+ }
+
+}
[02/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java b/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java
new file mode 100644
index 0000000..7ccd6a7
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/iterator/TestStableFixedSizeSampler.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator;
+
+import java.util.Iterator;
+
+public final class TestStableFixedSizeSampler extends SamplerCase {
+
+ @Override
+ protected Iterator<Integer> createSampler(int n, Iterator<Integer> source) {
+ return new StableFixedSizeSamplingIterator<Integer>(n, source);
+ }
+
+ @Override
+ protected boolean isSorted() {
+ return true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java b/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java
new file mode 100644
index 0000000..f94d63e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/common/lucene/AnalyzerUtilsTest.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.lucene;
+
+import org.apache.lucene.analysis.cjk.CJKAnalyzer;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.junit.Test;
+
+import static org.junit.Assert.assertNotNull;
+
+public class AnalyzerUtilsTest {
+
+ @Test
+ public void createStandardAnalyzer() throws Exception {
+ assertNotNull(AnalyzerUtils.createAnalyzer(StandardAnalyzer.class.getName()));
+ }
+
+ @Test
+ public void createCJKAnalyzer() throws Exception {
+ assertNotNull(AnalyzerUtils.createAnalyzer(CJKAnalyzer.class.getName()));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java b/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java
new file mode 100644
index 0000000..e0bdc98
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/driver/MahoutDriverTest.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.driver;
+
+import org.junit.Test;
+
+/**
+ * Tests if MahoutDriver can be run directly through its main method.
+ */
+public final class MahoutDriverTest {
+
+ @Test
+ public void testMain() throws Throwable {
+ MahoutDriver.main(new String[] {"canopy", "help"});
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java b/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
new file mode 100644
index 0000000..e53db7e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/ep/EvolutionaryProcessTest.java
@@ -0,0 +1,81 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+public final class EvolutionaryProcessTest extends MahoutTestCase {
+
+ @Test
+ public void testConverges() throws Exception {
+ State<Foo, Double> s0 = new State<Foo, Double>(new double[5], 1);
+ s0.setPayload(new Foo());
+ EvolutionaryProcess<Foo, Double> ep = new EvolutionaryProcess<Foo, Double>(10, 100, s0);
+
+ State<Foo, Double> best = null;
+ for (int i = 0; i < 20; i++) {
+ best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<Double>>() {
+ @Override
+ public double apply(Payload<Double> payload, double[] params) {
+ int i = 1;
+ double sum = 0;
+ for (double x : params) {
+ sum += i * (x - i) * (x - i);
+ i++;
+ }
+ return -sum;
+ }
+ });
+
+ ep.mutatePopulation(3);
+
+ System.out.printf("%10.3f %.3f\n", best.getValue(), best.getOmni());
+ }
+
+ ep.close();
+ assertNotNull(best);
+ assertEquals(0.0, best.getValue(), 0.02);
+ }
+
+ private static class Foo implements Payload<Double> {
+ @Override
+ public Foo copy() {
+ return this;
+ }
+
+ @Override
+ public void update(double[] params) {
+ // ignore
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) throws IOException {
+ // no-op
+ }
+
+ @Override
+ public void readFields(DataInput dataInput) throws IOException {
+ // no-op
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java b/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
new file mode 100644
index 0000000..226d4b1
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/MatrixWritableTest.java
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Writable;
+import org.junit.Test;
+
+public final class MatrixWritableTest extends MahoutTestCase {
+
+ @Test
+ public void testSparseMatrixWritable() throws Exception {
+ Matrix m = new SparseMatrix(5, 5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = Maps.newHashMap();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
+ m.setColumnLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ @Test
+ public void testSparseRowMatrixWritable() throws Exception {
+ Matrix m = new SparseRowMatrix(5, 5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = Maps.newHashMap();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
+ m.setColumnLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ @Test
+ public void testDenseMatrixWritable() throws Exception {
+ Matrix m = new DenseMatrix(5,5);
+ m.set(1, 2, 3.0);
+ m.set(3, 4, 5.0);
+ Map<String, Integer> bindings = Maps.newHashMap();
+ bindings.put("A", 0);
+ bindings.put("B", 1);
+ bindings.put("C", 2);
+ bindings.put("D", 3);
+ bindings.put("default", 4);
+ m.setRowLabelBindings(bindings);
+ m.setColumnLabelBindings(bindings);
+ doTestMatrixWritableEquals(m);
+ }
+
+ private static void doTestMatrixWritableEquals(Matrix m) throws IOException {
+ Writable matrixWritable = new MatrixWritable(m);
+ MatrixWritable matrixWritable2 = new MatrixWritable();
+ writeAndRead(matrixWritable, matrixWritable2);
+ Matrix m2 = matrixWritable2.get();
+ compareMatrices(m, m2);
+ doCheckBindings(m2.getRowLabelBindings());
+ doCheckBindings(m2.getColumnLabelBindings());
+ }
+
+ private static void compareMatrices(Matrix m, Matrix m2) {
+ assertEquals(m.numRows(), m2.numRows());
+ assertEquals(m.numCols(), m2.numCols());
+ for (int r = 0; r < m.numRows(); r++) {
+ for (int c = 0; c < m.numCols(); c++) {
+ assertEquals(m.get(r, c), m2.get(r, c), EPSILON);
+ }
+ }
+ Map<String,Integer> bindings = m.getRowLabelBindings();
+ Map<String, Integer> bindings2 = m2.getRowLabelBindings();
+ assertEquals(bindings == null, bindings2 == null);
+ if (bindings != null) {
+ assertEquals(bindings.size(), m.numRows());
+ assertEquals(bindings.size(), bindings2.size());
+ for (Map.Entry<String,Integer> entry : bindings.entrySet()) {
+ assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+ }
+ }
+ bindings = m.getColumnLabelBindings();
+ bindings2 = m2.getColumnLabelBindings();
+ assertEquals(bindings == null, bindings2 == null);
+ if (bindings != null) {
+ assertEquals(bindings.size(), bindings2.size());
+ for (Map.Entry<String,Integer> entry : bindings.entrySet()) {
+ assertEquals(entry.getValue(), bindings2.get(entry.getKey()));
+ }
+ }
+ }
+
+ private static void doCheckBindings(Map<String,Integer> labels) {
+ assertTrue("Missing label", labels.keySet().contains("A"));
+ assertTrue("Missing label", labels.keySet().contains("B"));
+ assertTrue("Missing label", labels.keySet().contains("C"));
+ assertTrue("Missing label", labels.keySet().contains("D"));
+ assertTrue("Missing label", labels.keySet().contains("default"));
+ }
+
+ private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(baos);
+ try {
+ toWrite.write(dos);
+ } finally {
+ Closeables.close(dos, false);
+ }
+
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ DataInputStream dis = new DataInputStream(bais);
+ try {
+ toRead.readFields(dis);
+ } finally {
+ Closeables.close(dis, true);
+ }
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/VarintTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/VarintTest.java b/mr/src/test/java/org/apache/mahout/math/VarintTest.java
new file mode 100644
index 0000000..0b1a664
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/VarintTest.java
@@ -0,0 +1,189 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math;
+
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+
+/**
+ * Tests {@link Varint}.
+ */
+public final class VarintTest extends MahoutTestCase {
+
+ @Test
+ public void testUnsignedLong() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeUnsignedVarLong(0L, out);
+ for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) {
+ Varint.writeUnsignedVarLong(i-1, out);
+ Varint.writeUnsignedVarLong(i, out);
+ }
+ Varint.writeUnsignedVarLong(Long.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0L, Varint.readUnsignedVarLong(in));
+ for (long i = 1L; i > 0L && i <= (1L << 62); i <<= 1) {
+ assertEquals(i-1, Varint.readUnsignedVarLong(in));
+ assertEquals(i, Varint.readUnsignedVarLong(in));
+ }
+ assertEquals(Long.MAX_VALUE, Varint.readUnsignedVarLong(in));
+ }
+
+ @Test
+ public void testSignedPositiveLong() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeSignedVarLong(0L, out);
+ for (long i = 1L; i <= (1L << 61); i <<= 1) {
+ Varint.writeSignedVarLong(i-1, out);
+ Varint.writeSignedVarLong(i, out);
+ }
+ Varint.writeSignedVarLong((1L << 62) - 1, out);
+ Varint.writeSignedVarLong((1L << 62), out);
+ Varint.writeSignedVarLong(Long.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0L, Varint.readSignedVarLong(in));
+ for (long i = 1L; i <= (1L << 61); i <<= 1) {
+ assertEquals(i-1, Varint.readSignedVarLong(in));
+ assertEquals(i, Varint.readSignedVarLong(in));
+ }
+ assertEquals((1L << 62) - 1, Varint.readSignedVarLong(in));
+ assertEquals((1L << 62), Varint.readSignedVarLong(in));
+ assertEquals(Long.MAX_VALUE, Varint.readSignedVarLong(in));
+ }
+
+ @Test
+ public void testSignedNegativeLong() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ for (long i = -1L; i >= -(1L << 62); i <<= 1) {
+ Varint.writeSignedVarLong(i, out);
+ Varint.writeSignedVarLong(i+1, out);
+ }
+ Varint.writeSignedVarLong(Long.MIN_VALUE, out);
+ Varint.writeSignedVarLong(Long.MIN_VALUE+1, out);
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ for (long i = -1L; i >= -(1L << 62); i <<= 1) {
+ assertEquals(i, Varint.readSignedVarLong(in));
+ assertEquals(i+1, Varint.readSignedVarLong(in));
+ }
+ assertEquals(Long.MIN_VALUE, Varint.readSignedVarLong(in));
+ assertEquals(Long.MIN_VALUE+1, Varint.readSignedVarLong(in));
+ }
+
+ @Test
+ public void testUnsignedInt() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeUnsignedVarInt(0, out);
+ for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) {
+ Varint.writeUnsignedVarLong(i-1, out);
+ Varint.writeUnsignedVarLong(i, out);
+ }
+ Varint.writeUnsignedVarLong(Integer.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0, Varint.readUnsignedVarInt(in));
+ for (int i = 1; i > 0 && i <= (1 << 30); i <<= 1) {
+ assertEquals(i-1, Varint.readUnsignedVarInt(in));
+ assertEquals(i, Varint.readUnsignedVarInt(in));
+ }
+ assertEquals(Integer.MAX_VALUE, Varint.readUnsignedVarInt(in));
+ }
+
+ @Test
+ public void testSignedPositiveInt() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ Varint.writeSignedVarInt(0, out);
+ for (int i = 1; i <= (1 << 29); i <<= 1) {
+ Varint.writeSignedVarLong(i-1, out);
+ Varint.writeSignedVarLong(i, out);
+ }
+ Varint.writeSignedVarInt((1 << 30) - 1, out);
+ Varint.writeSignedVarInt((1 << 30), out);
+ Varint.writeSignedVarInt(Integer.MAX_VALUE, out);
+
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ assertEquals(0, Varint.readSignedVarInt(in));
+ for (int i = 1; i <= (1 << 29); i <<= 1) {
+ assertEquals(i-1, Varint.readSignedVarInt(in));
+ assertEquals(i, Varint.readSignedVarInt(in));
+ }
+ assertEquals((1L << 30) - 1, Varint.readSignedVarInt(in));
+ assertEquals((1L << 30), Varint.readSignedVarInt(in));
+ assertEquals(Integer.MAX_VALUE, Varint.readSignedVarInt(in));
+ }
+
+ @Test
+ public void testSignedNegativeInt() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ for (int i = -1; i >= -(1 << 30); i <<= 1) {
+ Varint.writeSignedVarInt(i, out);
+ Varint.writeSignedVarInt(i+1, out);
+ }
+ Varint.writeSignedVarInt(Integer.MIN_VALUE, out);
+ Varint.writeSignedVarInt(Integer.MIN_VALUE+1, out);
+ DataInput in = new DataInputStream(new ByteArrayInputStream(baos.toByteArray()));
+ for (int i = -1; i >= -(1 << 30); i <<= 1) {
+ assertEquals(i, Varint.readSignedVarInt(in));
+ assertEquals(i+1, Varint.readSignedVarInt(in));
+ }
+ assertEquals(Integer.MIN_VALUE, Varint.readSignedVarInt(in));
+ assertEquals(Integer.MIN_VALUE+1, Varint.readSignedVarInt(in));
+ }
+
+ @Test
+ public void testUnsignedSize() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ int expectedSize = 0;
+ for (int exponent = 0; exponent <= 62; exponent++) {
+ Varint.writeUnsignedVarLong(1L << exponent, out);
+ expectedSize += 1 + exponent / 7;
+ assertEquals(expectedSize, baos.size());
+ }
+ }
+
+ @Test
+ public void testSignedSize() throws Exception {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutput out = new DataOutputStream(baos);
+ int expectedSize = 0;
+ for (int exponent = 0; exponent <= 61; exponent++) {
+ Varint.writeSignedVarLong(1L << exponent, out);
+ expectedSize += 1 + ((exponent + 1) / 7);
+ assertEquals(expectedSize, baos.size());
+ }
+ for (int exponent = 0; exponent <= 61; exponent++) {
+ Varint.writeSignedVarLong(-(1L << exponent)-1, out);
+ expectedSize += 1 + ((exponent + 1) / 7);
+ assertEquals(expectedSize, baos.size());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java b/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java
new file mode 100644
index 0000000..60fb8b4
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/VectorWritableTest.java
@@ -0,0 +1,123 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more contributor license
+ * agreements. See the NOTICE file distributed with this work for additional information regarding
+ * copyright ownership. The ASF licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package org.apache.mahout.math;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.math.Vector.Element;
+import org.junit.Test;
+
+import com.carrotsearch.randomizedtesting.RandomizedTest;
+import com.carrotsearch.randomizedtesting.annotations.Repeat;
+import com.google.common.io.Closeables;
+
+public final class VectorWritableTest extends RandomizedTest {
+ private static final int MAX_VECTOR_SIZE = 100;
+
+ public void createRandom(Vector v) {
+ int size = randomInt(v.size() - 1);
+ for (int i = 0; i < size; ++i) {
+ v.set(randomInt(v.size() - 1), randomDouble());
+ }
+
+ int zeros = Math.max(2, size / 4);
+ for (Element e : v.nonZeroes()) {
+ if (e.index() % zeros == 0) {
+ e.set(0.0);
+ }
+ }
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testViewSequentialAccessSparseVectorWritable() throws Exception {
+ Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ Vector view = new VectorView(v, 0, v.size());
+ doTestVectorWritableEquals(view);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testSequentialAccessSparseVectorWritable() throws Exception {
+ Vector v = new SequentialAccessSparseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testRandomAccessSparseVectorWritable() throws Exception {
+ Vector v = new RandomAccessSparseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testDenseVectorWritable() throws Exception {
+ Vector v = new DenseVector(MAX_VECTOR_SIZE);
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ @Test
+ @Repeat(iterations = 20)
+ public void testNamedVectorWritable() throws Exception {
+ Vector v = new DenseVector(MAX_VECTOR_SIZE);
+ v = new NamedVector(v, "Victor");
+ createRandom(v);
+ doTestVectorWritableEquals(v);
+ }
+
+ private static void doTestVectorWritableEquals(Vector v) throws IOException {
+ Writable vectorWritable = new VectorWritable(v);
+ VectorWritable vectorWritable2 = new VectorWritable();
+ writeAndRead(vectorWritable, vectorWritable2);
+ Vector v2 = vectorWritable2.get();
+ if (v instanceof NamedVector) {
+ assertTrue(v2 instanceof NamedVector);
+ NamedVector nv = (NamedVector) v;
+ NamedVector nv2 = (NamedVector) v2;
+ assertEquals(nv.getName(), nv2.getName());
+ assertEquals("Victor", nv.getName());
+ }
+ assertEquals(v, v2);
+ }
+
+ private static void writeAndRead(Writable toWrite, Writable toRead) throws IOException {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ DataOutputStream dos = new DataOutputStream(baos);
+ try {
+ toWrite.write(dos);
+ } finally {
+ Closeables.close(dos, false);
+ }
+
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ DataInputStream dis = new DataInputStream(bais);
+ try {
+ toRead.readFields(dis);
+ } finally {
+ Closeables.close(dos, true);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java b/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java
new file mode 100644
index 0000000..a23f7b4
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/hadoop/MathHelper.java
@@ -0,0 +1,236 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop;
+
+import java.io.IOException;
+import java.text.DecimalFormat;
+import java.text.DecimalFormatSymbols;
+import java.util.Locale;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.easymock.EasyMock;
+import org.easymock.IArgumentMatcher;
+import org.junit.Assert;
+
+/**
+ * a collection of small helper methods useful for unit-testing mathematical operations
+ */
+public final class MathHelper {
+
+ private MathHelper() {}
+
+ /**
+ * convenience method to create a {@link Vector.Element}
+ */
+ public static Vector.Element elem(int index, double value) {
+ return new ElementToCheck(index, value);
+ }
+
+ /**
+ * a simple implementation of {@link Vector.Element}
+ */
+ static class ElementToCheck implements Vector.Element {
+ private final int index;
+ private double value;
+
+ ElementToCheck(int index, double value) {
+ this.index = index;
+ this.value = value;
+ }
+ @Override
+ public double get() {
+ return value;
+ }
+ @Override
+ public int index() {
+ return index;
+ }
+ @Override
+ public void set(double value) {
+ this.value = value;
+ }
+ }
+
+ /**
+ * applies an {@link IArgumentMatcher} to a {@link VectorWritable} that checks whether all elements are included
+ */
+ public static VectorWritable vectorMatches(final Vector.Element... elements) {
+ EasyMock.reportMatcher(new IArgumentMatcher() {
+ @Override
+ public boolean matches(Object argument) {
+ if (argument instanceof VectorWritable) {
+ Vector v = ((VectorWritable) argument).get();
+ return consistsOf(v, elements);
+ }
+ return false;
+ }
+
+ @Override
+ public void appendTo(StringBuffer buffer) {}
+ });
+ return null;
+ }
+
+ /**
+ * checks whether the {@link Vector} is equivalent to the set of {@link Vector.Element}s
+ */
+ public static boolean consistsOf(Vector vector, Vector.Element... elements) {
+ if (elements.length != numberOfNoNZeroNonNaNElements(vector)) {
+ return false;
+ }
+ for (Vector.Element element : elements) {
+ if (Math.abs(element.get() - vector.get(element.index())) > MahoutTestCase.EPSILON) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * returns the number of elements in the {@link Vector} that are neither 0 nor NaN
+ */
+ public static int numberOfNoNZeroNonNaNElements(Vector vector) {
+ int elementsInVector = 0;
+ for (Element currentElement : vector.nonZeroes()) {
+ if (!Double.isNaN(currentElement.get())) {
+ elementsInVector++;
+ }
+ }
+ return elementsInVector;
+ }
+
+ /**
+ * read a {@link Matrix} from a SequenceFile<IntWritable,VectorWritable>
+ */
+ public static Matrix readMatrix(Configuration conf, Path path, int rows, int columns) {
+ boolean readOneRow = false;
+ Matrix matrix = new DenseMatrix(rows, columns);
+ for (Pair<IntWritable,VectorWritable> record :
+ new SequenceFileIterable<IntWritable,VectorWritable>(path, true, conf)) {
+ IntWritable key = record.getFirst();
+ VectorWritable value = record.getSecond();
+ readOneRow = true;
+ int row = key.get();
+ for (Element element : value.get().nonZeroes()) {
+ matrix.set(row, element.index(), element.get());
+ }
+ }
+ if (!readOneRow) {
+ throw new IllegalStateException("Not a single row read!");
+ }
+ return matrix;
+ }
+
+ /**
+ * read a {@link Matrix} from a SequenceFile<IntWritable,VectorWritable>
+ */
+ public static OpenIntObjectHashMap<Vector> readMatrixRows(Configuration conf, Path path) {
+ boolean readOneRow = false;
+ OpenIntObjectHashMap<Vector> rows = new OpenIntObjectHashMap<Vector>();
+ for (Pair<IntWritable,VectorWritable> record :
+ new SequenceFileIterable<IntWritable,VectorWritable>(path, true, conf)) {
+ IntWritable key = record.getFirst();
+ readOneRow = true;
+ rows.put(key.get(), record.getSecond().get());
+ }
+ if (!readOneRow) {
+ throw new IllegalStateException("Not a single row read!");
+ }
+ return rows;
+ }
+
+ /**
+ * write a two-dimensional double array to an SequenceFile<IntWritable,VectorWritable>
+ */
+ public static void writeDistributedRowMatrix(double[][] entries, FileSystem fs, Configuration conf, Path path)
+ throws IOException {
+ SequenceFile.Writer writer = null;
+ try {
+ writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
+ for (int n = 0; n < entries.length; n++) {
+ Vector v = new RandomAccessSparseVector(entries[n].length);
+ for (int m = 0; m < entries[n].length; m++) {
+ v.setQuick(m, entries[n][m]);
+ }
+ writer.append(new IntWritable(n), new VectorWritable(v));
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ public static void assertMatrixEquals(Matrix expected, Matrix actual) {
+ Assert.assertEquals(expected.numRows(), actual.numRows());
+ Assert.assertEquals(actual.numCols(), actual.numCols());
+ for (int row = 0; row < expected.numRows(); row++) {
+ for (int col = 0; col < expected.numCols(); col ++) {
+ Assert.assertEquals("Non-matching values in [" + row + ',' + col + ']',
+ expected.get(row, col), actual.get(row, col), MahoutTestCase.EPSILON);
+ }
+ }
+ }
+
+ public static String nice(Vector v) {
+ if (!v.isSequentialAccess()) {
+ v = new DenseVector(v);
+ }
+
+ DecimalFormat df = new DecimalFormat("0.00", DecimalFormatSymbols.getInstance(Locale.ENGLISH));
+
+ StringBuilder buffer = new StringBuilder("[");
+ String separator = "";
+ for (Vector.Element e : v.all()) {
+ buffer.append(separator);
+ if (Double.isNaN(e.get())) {
+ buffer.append(" - ");
+ } else {
+ if (e.get() >= 0) {
+ buffer.append(' ');
+ }
+ buffer.append(df.format(e.get()));
+ }
+ separator = "\t";
+ }
+ buffer.append(" ]");
+ return buffer.toString();
+ }
+
+ public static String nice(Matrix matrix) {
+ StringBuilder info = new StringBuilder();
+ for (int n = 0; n < matrix.numRows(); n++) {
+ info.append(nice(matrix.viewRow(n))).append('\n');
+ }
+ return info.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java b/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java
new file mode 100644
index 0000000..13da38a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java
@@ -0,0 +1,395 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.decomposer.SolverTest;
+import org.apache.mahout.math.function.Functions;
+import org.junit.Test;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+import com.google.common.collect.Maps;
+
+public final class TestDistributedRowMatrix extends MahoutTestCase {
+ public static final String TEST_PROPERTY_KEY = "test.property.key";
+ public static final String TEST_PROPERTY_VALUE = "test.property.value";
+
+ private static void assertEquals(VectorIterable m, VectorIterable mtt, double errorTolerance) {
+ Iterator<MatrixSlice> mIt = m.iterateAll();
+ Iterator<MatrixSlice> mttIt = mtt.iterateAll();
+ Map<Integer, Vector> mMap = Maps.newHashMap();
+ Map<Integer, Vector> mttMap = Maps.newHashMap();
+ while (mIt.hasNext() && mttIt.hasNext()) {
+ MatrixSlice ms = mIt.next();
+ mMap.put(ms.index(), ms.vector());
+ MatrixSlice mtts = mttIt.next();
+ mttMap.put(mtts.index(), mtts.vector());
+ }
+ for (Map.Entry<Integer, Vector> entry : mMap.entrySet()) {
+ Integer key = entry.getKey();
+ Vector value = entry.getValue();
+ if (value == null || mttMap.get(key) == null) {
+ assertTrue(value == null || value.norm(2) == 0);
+ assertTrue(mttMap.get(key) == null || mttMap.get(key).norm(2) == 0);
+ } else {
+ assertTrue(
+ value.getDistanceSquared(mttMap.get(key)) < errorTolerance);
+ }
+ }
+ }
+
+ @Test
+ public void testTranspose() throws Exception {
+ DistributedRowMatrix m = randomDistributedMatrix(10, 9, 5, 4, 1.0, false);
+ m.setConf(getConfiguration());
+ DistributedRowMatrix mt = m.transpose();
+ mt.setConf(getConfiguration());
+
+ Path tmpPath = getTestTempDirPath();
+ m.setOutputTempPathString(tmpPath.toString());
+ Path tmpOutPath = new Path(tmpPath, "/tmpOutTranspose");
+ mt.setOutputTempPathString(tmpOutPath.toString());
+ HadoopUtil.delete(getConfiguration(), tmpOutPath);
+ DistributedRowMatrix mtt = mt.transpose();
+ assertEquals(m, mtt, EPSILON);
+ }
+
+ @Test
+ public void testMatrixColumnMeansJob() throws Exception {
+ Matrix m =
+ SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
+ DistributedRowMatrix dm =
+ randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
+ dm.setConf(getConfiguration());
+
+ Vector expected = new DenseVector(50);
+ for (int i = 0; i < m.numRows(); i++) {
+ expected.assign(m.viewRow(i), Functions.PLUS);
+ }
+ expected.assign(Functions.DIV, m.numRows());
+ Vector actual = dm.columnMeans("DenseVector");
+ assertEquals(0.0, expected.getDistanceSquared(actual), EPSILON);
+ }
+
+ @Test
+ public void testNullMatrixColumnMeansJob() throws Exception {
+ Matrix m =
+ SolverTest.randomSequentialAccessSparseMatrix(100, 90, 0, 0, 1.0);
+ DistributedRowMatrix dm =
+ randomDistributedMatrix(100, 90, 0, 0, 1.0, false);
+ dm.setConf(getConfiguration());
+
+ Vector expected = new DenseVector(0);
+ for (int i = 0; i < m.numRows(); i++) {
+ expected.assign(m.viewRow(i), Functions.PLUS);
+ }
+ expected.assign(Functions.DIV, m.numRows());
+ Vector actual = dm.columnMeans();
+ assertEquals(0.0, expected.getDistanceSquared(actual), EPSILON);
+ }
+
+ @Test
+ public void testMatrixTimesVector() throws Exception {
+ Vector v = new RandomAccessSparseVector(50);
+ v.assign(1.0);
+ Matrix m = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
+ DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
+ dm.setConf(getConfiguration());
+
+ Vector expected = m.times(v);
+ Vector actual = dm.times(v);
+ assertEquals(0.0, expected.getDistanceSquared(actual), EPSILON);
+ }
+
+ @Test
+ public void testMatrixTimesSquaredVector() throws Exception {
+ Vector v = new RandomAccessSparseVector(50);
+ v.assign(1.0);
+ Matrix m = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
+ DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
+ dm.setConf(getConfiguration());
+
+ Vector expected = m.timesSquared(v);
+ Vector actual = dm.timesSquared(v);
+ assertEquals(0.0, expected.getDistanceSquared(actual), 1.0e-9);
+ }
+
+ @Test
+ public void testMatrixTimesMatrix() throws Exception {
+ Matrix inputA = SolverTest.randomSequentialAccessSparseMatrix(20, 19, 15, 5, 10.0);
+ Matrix inputB = SolverTest.randomSequentialAccessSparseMatrix(20, 13, 25, 10, 5.0);
+ Matrix expected = inputA.transpose().times(inputB);
+
+ DistributedRowMatrix distA = randomDistributedMatrix(20, 19, 15, 5, 10.0, false, "distA");
+ distA.setConf(getConfiguration());
+ DistributedRowMatrix distB = randomDistributedMatrix(20, 13, 25, 10, 5.0, false, "distB");
+ distB.setConf(getConfiguration());
+ DistributedRowMatrix product = distA.times(distB);
+
+ assertEquals(expected, product, EPSILON);
+ }
+
+ @Test
+ public void testMatrixMultiplactionJobConfBuilder() throws Exception {
+ Configuration initialConf = createInitialConf();
+
+ Path baseTmpDirPath = getTestTempDirPath("testpaths");
+ Path aPath = new Path(baseTmpDirPath, "a");
+ Path bPath = new Path(baseTmpDirPath, "b");
+ Path outPath = new Path(baseTmpDirPath, "out");
+
+ Configuration mmJobConf = MatrixMultiplicationJob.createMatrixMultiplyJobConf(aPath, bPath, outPath, 10);
+ Configuration mmCustomJobConf = MatrixMultiplicationJob.createMatrixMultiplyJobConf(initialConf,
+ aPath,
+ bPath,
+ outPath,
+ 10);
+
+ assertNull(mmJobConf.get(TEST_PROPERTY_KEY));
+ assertEquals(TEST_PROPERTY_VALUE, mmCustomJobConf.get(TEST_PROPERTY_KEY));
+ }
+
+ @Test
+ public void testTransposeJobConfBuilder() throws Exception {
+ Configuration initialConf = createInitialConf();
+
+ Path baseTmpDirPath = getTestTempDirPath("testpaths");
+ Path inputPath = new Path(baseTmpDirPath, "input");
+ Path outputPath = new Path(baseTmpDirPath, "output");
+
+ Configuration transposeJobConf = TransposeJob.buildTransposeJob(inputPath, outputPath, 10).getConfiguration();
+
+ Configuration transposeCustomJobConf = TransposeJob.buildTransposeJob(initialConf, inputPath, outputPath, 10)
+ .getConfiguration();
+
+ assertNull(transposeJobConf.get(TEST_PROPERTY_KEY));
+ assertEquals(TEST_PROPERTY_VALUE, transposeCustomJobConf.get(TEST_PROPERTY_KEY));
+ }
+
+ @Test public void testTimesSquaredJobConfBuilders() throws Exception {
+ Configuration initialConf = createInitialConf();
+
+ Path baseTmpDirPath = getTestTempDirPath("testpaths");
+ Path inputPath = new Path(baseTmpDirPath, "input");
+ Path outputPath = new Path(baseTmpDirPath, "output");
+
+ Vector v = new RandomAccessSparseVector(50);
+ v.assign(1.0);
+
+ Job timesSquaredJob1 = TimesSquaredJob.createTimesSquaredJob(v, inputPath, outputPath);
+ Job customTimesSquaredJob1 = TimesSquaredJob.createTimesSquaredJob(initialConf, v, inputPath, outputPath);
+
+ assertNull(timesSquaredJob1.getConfiguration().get(TEST_PROPERTY_KEY));
+ assertEquals(TEST_PROPERTY_VALUE, customTimesSquaredJob1.getConfiguration().get(TEST_PROPERTY_KEY));
+
+ Job timesJob = TimesSquaredJob.createTimesJob(v, 50, inputPath, outputPath);
+ Job customTimesJob = TimesSquaredJob.createTimesJob(initialConf, v, 50, inputPath, outputPath);
+
+ assertNull(timesJob.getConfiguration().get(TEST_PROPERTY_KEY));
+ assertEquals(TEST_PROPERTY_VALUE, customTimesJob.getConfiguration().get(TEST_PROPERTY_KEY));
+
+ Job timesSquaredJob2 = TimesSquaredJob.createTimesSquaredJob(v, inputPath, outputPath,
+ TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);
+
+ Job customTimesSquaredJob2 = TimesSquaredJob.createTimesSquaredJob(initialConf, v, inputPath,
+ outputPath, TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);
+
+ assertNull(timesSquaredJob2.getConfiguration().get(TEST_PROPERTY_KEY));
+ assertEquals(TEST_PROPERTY_VALUE, customTimesSquaredJob2.getConfiguration().get(TEST_PROPERTY_KEY));
+
+ Job timesSquaredJob3 = TimesSquaredJob.createTimesSquaredJob(v, 50, inputPath, outputPath,
+ TimesSquaredJob.TimesSquaredMapper.class, TimesSquaredJob.VectorSummingReducer.class);
+
+ Job customTimesSquaredJob3 = TimesSquaredJob.createTimesSquaredJob(initialConf,
+ v, 50, inputPath, outputPath, TimesSquaredJob.TimesSquaredMapper.class,
+ TimesSquaredJob.VectorSummingReducer.class);
+
+ assertNull(timesSquaredJob3.getConfiguration().get(TEST_PROPERTY_KEY));
+ assertEquals(TEST_PROPERTY_VALUE, customTimesSquaredJob3.getConfiguration().get(TEST_PROPERTY_KEY));
+ }
+
+ @Test
+ public void testTimesVectorTempDirDeletion() throws Exception {
+ Configuration conf = getConfiguration();
+ Vector v = new RandomAccessSparseVector(50);
+ v.assign(1.0);
+ DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
+ dm.setConf(conf);
+
+ Path outputPath = dm.getOutputTempPath();
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ deleteContentsOfPath(conf, outputPath);
+
+ assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);
+
+ Vector result1 = dm.times(v);
+
+ assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);
+
+ deleteContentsOfPath(conf, outputPath);
+ assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);
+
+ conf.setBoolean(DistributedRowMatrix.KEEP_TEMP_FILES, true);
+ dm.setConf(conf);
+
+ Vector result2 = dm.times(v);
+
+ FileStatus[] outputStatuses = fs.listStatus(outputPath);
+ assertEquals(1, outputStatuses.length);
+ Path outputTempPath = outputStatuses[0].getPath();
+ Path inputVectorPath = new Path(outputTempPath, TimesSquaredJob.INPUT_VECTOR);
+ Path outputVectorPath = new Path(outputTempPath, TimesSquaredJob.OUTPUT_VECTOR_FILENAME);
+ assertEquals(1, fs.listStatus(inputVectorPath, PathFilters.logsCRCFilter()).length);
+ assertEquals(1, fs.listStatus(outputVectorPath, PathFilters.logsCRCFilter()).length);
+
+ assertEquals(0.0, result1.getDistanceSquared(result2), EPSILON);
+ }
+
+ @Test
+ public void testTimesSquaredVectorTempDirDeletion() throws Exception {
+ Configuration conf = getConfiguration();
+ Vector v = new RandomAccessSparseVector(50);
+ v.assign(1.0);
+ DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
+ dm.setConf(getConfiguration());
+
+ Path outputPath = dm.getOutputTempPath();
+ FileSystem fs = outputPath.getFileSystem(conf);
+
+ deleteContentsOfPath(conf, outputPath);
+
+ assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);
+
+ Vector result1 = dm.timesSquared(v);
+
+ assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);
+
+ deleteContentsOfPath(conf, outputPath);
+ assertEquals(0, HadoopUtil.listStatus(fs, outputPath).length);
+
+ conf.setBoolean(DistributedRowMatrix.KEEP_TEMP_FILES, true);
+ dm.setConf(conf);
+
+ Vector result2 = dm.timesSquared(v);
+
+ FileStatus[] outputStatuses = fs.listStatus(outputPath);
+ assertEquals(1, outputStatuses.length);
+ Path outputTempPath = outputStatuses[0].getPath();
+ Path inputVectorPath = new Path(outputTempPath, TimesSquaredJob.INPUT_VECTOR);
+ Path outputVectorPath = new Path(outputTempPath, TimesSquaredJob.OUTPUT_VECTOR_FILENAME);
+ assertEquals(1, fs.listStatus(inputVectorPath, PathFilters.logsCRCFilter()).length);
+ assertEquals(1, fs.listStatus(outputVectorPath, PathFilters.logsCRCFilter()).length);
+
+ assertEquals(0.0, result1.getDistanceSquared(result2), EPSILON);
+ }
+
+ public Configuration createInitialConf() throws IOException {
+ Configuration initialConf = getConfiguration();
+ initialConf.set(TEST_PROPERTY_KEY, TEST_PROPERTY_VALUE);
+ return initialConf;
+ }
+
+ private static void deleteContentsOfPath(Configuration conf, Path path) throws Exception {
+ FileSystem fs = path.getFileSystem(conf);
+
+ FileStatus[] statuses = HadoopUtil.listStatus(fs, path);
+ for (FileStatus status : statuses) {
+ fs.delete(status.getPath(), true);
+ }
+ }
+
+ public DistributedRowMatrix randomDistributedMatrix(int numRows,
+ int nonNullRows,
+ int numCols,
+ int entriesPerRow,
+ double entryMean,
+ boolean isSymmetric) throws IOException {
+ return randomDistributedMatrix(numRows, nonNullRows, numCols, entriesPerRow, entryMean, isSymmetric, "testdata");
+ }
+
+ public DistributedRowMatrix randomDenseHierarchicalDistributedMatrix(int numRows,
+ int numCols,
+ boolean isSymmetric,
+ String baseTmpDirSuffix)
+ throws IOException {
+ Path baseTmpDirPath = getTestTempDirPath(baseTmpDirSuffix);
+ Matrix c = SolverTest.randomHierarchicalMatrix(numRows, numCols, isSymmetric);
+ return saveToFs(c, baseTmpDirPath);
+ }
+
+ public DistributedRowMatrix randomDistributedMatrix(int numRows,
+ int nonNullRows,
+ int numCols,
+ int entriesPerRow,
+ double entryMean,
+ boolean isSymmetric,
+ String baseTmpDirSuffix) throws IOException {
+ Path baseTmpDirPath = getTestTempDirPath(baseTmpDirSuffix);
+ Matrix c = SolverTest.randomSequentialAccessSparseMatrix(numRows, nonNullRows, numCols, entriesPerRow, entryMean);
+ if (isSymmetric) {
+ c = c.times(c.transpose());
+ }
+ return saveToFs(c, baseTmpDirPath);
+ }
+
+ private DistributedRowMatrix saveToFs(final Matrix m, Path baseTmpDirPath) throws IOException {
+ Configuration conf = getConfiguration();
+ FileSystem fs = FileSystem.get(baseTmpDirPath.toUri(), conf);
+
+ ClusteringTestUtils.writePointsToFile(new Iterable<VectorWritable>() {
+ @Override
+ public Iterator<VectorWritable> iterator() {
+ return Iterators.transform(m.iterator(), new Function<MatrixSlice,VectorWritable>() {
+ @Override
+ public VectorWritable apply(MatrixSlice input) {
+ return new VectorWritable(input.vector());
+ }
+ });
+ }
+ }, true, new Path(baseTmpDirPath, "distMatrix/part-00000"), fs, conf);
+
+ DistributedRowMatrix distMatrix = new DistributedRowMatrix(new Path(baseTmpDirPath, "distMatrix"),
+ new Path(baseTmpDirPath, "tmpOut"),
+ m.numRows(),
+ m.numCols());
+ distMatrix.setConf(new Configuration(conf));
+
+ return distMatrix;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java
new file mode 100644
index 0000000..ac01c28
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolver.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.decomposer;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.decomposer.SolverTest;
+import org.apache.mahout.math.decomposer.lanczos.LanczosState;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
+import org.junit.Before;
+
+import java.io.File;
+import java.io.IOException;
+
+@Deprecated
+public final class TestDistributedLanczosSolver extends MahoutTestCase {
+
+ private int counter = 0;
+ private DistributedRowMatrix symCorpus;
+ private DistributedRowMatrix asymCorpus;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ File symTestData = getTestTempDir("symTestData");
+ File asymTestData = getTestTempDir("asymTestData");
+ symCorpus = new TestDistributedRowMatrix().randomDistributedMatrix(100,
+ 90, 80, 2, 10.0, true, symTestData.getAbsolutePath());
+ asymCorpus = new TestDistributedRowMatrix().randomDistributedMatrix(100,
+ 90, 80, 2, 10.0, false, asymTestData.getAbsolutePath());
+ }
+
+ private static String suf(boolean symmetric) {
+ return symmetric ? "_sym" : "_asym";
+ }
+
+ private DistributedRowMatrix getCorpus(boolean symmetric) {
+ return symmetric ? symCorpus : asymCorpus;
+ }
+
+ /*
+ private LanczosState doTestDistributedLanczosSolver(boolean symmetric,
+ int desiredRank) throws IOException {
+ return doTestDistributedLanczosSolver(symmetric, desiredRank, true);
+ }
+ */
+
+ private LanczosState doTestDistributedLanczosSolver(boolean symmetric,
+ int desiredRank, boolean hdfsBackedState)
+ throws IOException {
+ DistributedRowMatrix corpus = getCorpus(symmetric);
+ Configuration conf = getConfiguration();
+ corpus.setConf(conf);
+ DistributedLanczosSolver solver = new DistributedLanczosSolver();
+ Vector intitialVector = DistributedLanczosSolver.getInitialVector(corpus);
+ LanczosState state;
+ if (hdfsBackedState) {
+ HdfsBackedLanczosState hState = new HdfsBackedLanczosState(corpus,
+ desiredRank, intitialVector, new Path(getTestTempDirPath(),
+ "lanczosStateDir" + suf(symmetric) + counter));
+ hState.setConf(conf);
+ state = hState;
+ } else {
+ state = new LanczosState(corpus, desiredRank, intitialVector);
+ }
+ solver.solve(state, desiredRank, symmetric);
+ SolverTest.assertOrthonormal(state);
+ for (int i = 0; i < desiredRank/2; i++) {
+ SolverTest.assertEigen(i, state.getRightSingularVector(i), corpus, 0.1, symmetric);
+ }
+ counter++;
+ return state;
+ }
+
+ public void doTestResumeIteration(boolean symmetric) throws IOException {
+ DistributedRowMatrix corpus = getCorpus(symmetric);
+ Configuration conf = getConfiguration();
+ corpus.setConf(conf);
+ DistributedLanczosSolver solver = new DistributedLanczosSolver();
+ int rank = 10;
+ Vector intitialVector = DistributedLanczosSolver.getInitialVector(corpus);
+ HdfsBackedLanczosState state = new HdfsBackedLanczosState(corpus, rank,
+ intitialVector, new Path(getTestTempDirPath(), "lanczosStateDir" + suf(symmetric) + counter));
+ solver.solve(state, rank, symmetric);
+
+ rank *= 2;
+ state = new HdfsBackedLanczosState(corpus, rank,
+ intitialVector, new Path(getTestTempDirPath(), "lanczosStateDir" + suf(symmetric) + counter));
+ solver = new DistributedLanczosSolver();
+ solver.solve(state, rank, symmetric);
+
+ LanczosState allAtOnceState = doTestDistributedLanczosSolver(symmetric, rank, false);
+ for (int i=0; i<state.getIterationNumber(); i++) {
+ Vector v = state.getBasisVector(i).normalize();
+ Vector w = allAtOnceState.getBasisVector(i).normalize();
+ double diff = v.minus(w).norm(2);
+ assertTrue("basis " + i + " is too long: " + diff, diff < 0.1);
+ }
+ counter++;
+ }
+
+ // TODO when this can be made to run in under 20 minutes, re-enable
+ /*
+ @Test
+ public void testDistributedLanczosSolver() throws Exception {
+ doTestDistributedLanczosSolver(true, 30);
+ doTestDistributedLanczosSolver(false, 30);
+ doTestResumeIteration(true);
+ doTestResumeIteration(false);
+ }
+ */
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java
new file mode 100644
index 0000000..5dfb328
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/hadoop/decomposer/TestDistributedLanczosSolverCLI.java
@@ -0,0 +1,190 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.decomposer;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.hadoop.TestDistributedRowMatrix;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collection;
+import java.util.Arrays;
+
+@Deprecated
+public final class TestDistributedLanczosSolverCLI extends MahoutTestCase {
+ private static final Logger log = LoggerFactory.getLogger(TestDistributedLanczosSolverCLI.class);
+
+ @Test
+ public void testDistributedLanczosSolverCLI() throws Exception {
+ Path testData = getTestTempDirPath("testdata");
+ DistributedRowMatrix corpus =
+ new TestDistributedRowMatrix().randomDenseHierarchicalDistributedMatrix(10, 9, false,
+ testData.toString());
+ corpus.setConf(getConfiguration());
+ Path output = getTestTempDirPath("output");
+ Path tmp = getTestTempDirPath("tmp");
+ Path workingDir = getTestTempDirPath("working");
+ String[] args = {
+ "-i", new Path(testData, "distMatrix").toString(),
+ "-o", output.toString(),
+ "--tempDir", tmp.toString(),
+ "--numRows", "10",
+ "--numCols", "9",
+ "--rank", "6",
+ "--symmetric", "false",
+ "--workingDir", workingDir.toString()
+ };
+ ToolRunner.run(getConfiguration(), new DistributedLanczosSolver().new DistributedLanczosSolverJob(), args);
+
+ output = getTestTempDirPath("output2");
+ tmp = getTestTempDirPath("tmp2");
+ args = new String[] {
+ "-i", new Path(testData, "distMatrix").toString(),
+ "-o", output.toString(),
+ "--tempDir", tmp.toString(),
+ "--numRows", "10",
+ "--numCols", "9",
+ "--rank", "7",
+ "--symmetric", "false",
+ "--workingDir", workingDir.toString()
+ };
+ ToolRunner.run(getConfiguration(), new DistributedLanczosSolver().new DistributedLanczosSolverJob(), args);
+
+ Path rawEigenvectors = new Path(output, DistributedLanczosSolver.RAW_EIGENVECTORS);
+ Matrix eigenVectors = new DenseMatrix(7, corpus.numCols());
+ Configuration conf = getConfiguration();
+
+ int i = 0;
+ for (VectorWritable value : new SequenceFileValueIterable<VectorWritable>(rawEigenvectors, conf)) {
+ Vector v = value.get();
+ eigenVectors.assignRow(i, v);
+ i++;
+ }
+ assertEquals("number of eigenvectors", 7, i);
+ }
+
+ @Test
+ public void testDistributedLanczosSolverEVJCLI() throws Exception {
+ Path testData = getTestTempDirPath("testdata");
+ DistributedRowMatrix corpus = new TestDistributedRowMatrix()
+ .randomDenseHierarchicalDistributedMatrix(10, 9, false, testData.toString());
+ corpus.setConf(getConfiguration());
+ Path output = getTestTempDirPath("output");
+ Path tmp = getTestTempDirPath("tmp");
+ String[] args = {
+ "-i", new Path(testData, "distMatrix").toString(),
+ "-o", output.toString(),
+ "--tempDir", tmp.toString(),
+ "--numRows", "10",
+ "--numCols", "9",
+ "--rank", "6",
+ "--symmetric", "false",
+ "--cleansvd", "true"
+ };
+ ToolRunner.run(getConfiguration(), new DistributedLanczosSolver().new DistributedLanczosSolverJob(), args);
+
+ Path cleanEigenvectors = new Path(output, EigenVerificationJob.CLEAN_EIGENVECTORS);
+ Matrix eigenVectors = new DenseMatrix(6, corpus.numCols());
+ Collection<Double> eigenvalues = Lists.newArrayList();
+
+ output = getTestTempDirPath("output2");
+ tmp = getTestTempDirPath("tmp2");
+ args = new String[] {
+ "-i", new Path(testData, "distMatrix").toString(),
+ "-o", output.toString(),
+ "--tempDir", tmp.toString(),
+ "--numRows", "10",
+ "--numCols", "9",
+ "--rank", "7",
+ "--symmetric", "false",
+ "--cleansvd", "true"
+ };
+ ToolRunner.run(getConfiguration(), new DistributedLanczosSolver().new DistributedLanczosSolverJob(), args);
+ Path cleanEigenvectors2 = new Path(output, EigenVerificationJob.CLEAN_EIGENVECTORS);
+ Matrix eigenVectors2 = new DenseMatrix(7, corpus.numCols());
+ Configuration conf = getConfiguration();
+ Collection<Double> newEigenValues = Lists.newArrayList();
+
+ int i = 0;
+ for (VectorWritable value : new SequenceFileValueIterable<VectorWritable>(cleanEigenvectors, conf)) {
+ NamedVector v = (NamedVector) value.get();
+ eigenVectors.assignRow(i, v);
+ log.info(v.getName());
+ if (EigenVector.getCosAngleError(v.getName()) < 1.0e-3) {
+ eigenvalues.add(EigenVector.getEigenValue(v.getName()));
+ }
+ i++;
+ }
+ assertEquals("number of clean eigenvectors", 3, i);
+
+ i = 0;
+ for (VectorWritable value : new SequenceFileValueIterable<VectorWritable>(cleanEigenvectors2, conf)) {
+ NamedVector v = (NamedVector) value.get();
+ log.info(v.getName());
+ eigenVectors2.assignRow(i, v);
+ newEigenValues.add(EigenVector.getEigenValue(v.getName()));
+ i++;
+ }
+
+ Collection<Integer> oldEigensFound = Lists.newArrayList();
+ for (int row = 0; row < eigenVectors.numRows(); row++) {
+ Vector oldEigen = eigenVectors.viewRow(row);
+ if (oldEigen == null) {
+ break;
+ }
+ for (int newRow = 0; newRow < eigenVectors2.numRows(); newRow++) {
+ Vector newEigen = eigenVectors2.viewRow(newRow);
+ if (newEigen != null && oldEigen.dot(newEigen) > 0.9) {
+ oldEigensFound.add(row);
+ break;
+ }
+ }
+ }
+ assertEquals("the number of new eigenvectors", 5, i);
+
+ Collection<Double> oldEigenValuesNotFound = Lists.newArrayList();
+ for (double d : eigenvalues) {
+ boolean found = false;
+ for (double newD : newEigenValues) {
+ if (Math.abs((d - newD)/d) < 0.1) {
+ found = true;
+ }
+ }
+ if (!found) {
+ oldEigenValuesNotFound.add(d);
+ }
+ }
+ assertEquals("number of old eigenvalues not found: "
+ + Arrays.toString(oldEigenValuesNotFound.toArray(new Double[oldEigenValuesNotFound.size()])),
+ 0, oldEigenValuesNotFound.size());
+ assertEquals("did not find enough old eigenvectors", 3, oldEigensFound.size());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
new file mode 100644
index 0000000..bb2c373
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
@@ -0,0 +1,238 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.common.DummyOutputCollector;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+public class TestVectorDistanceSimilarityJob extends MahoutTestCase {
+
+ private FileSystem fs;
+
+ private static final double[][] REFERENCE = { { 1, 1 }, { 2, 1 }, { 1, 2 }, { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 },
+ { 4, 5 }, { 5, 5 } };
+
+ private static final double[][] SEEDS = { { 1, 1 }, { 10, 10 } };
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ fs = FileSystem.get(getConfiguration());
+ }
+
+ @Test
+ public void testVectorDistanceMapper() throws Exception {
+ Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+ StringTuple tuple = new StringTuple();
+ tuple.add("foo");
+ tuple.add("123");
+ context.write(tuple, new DoubleWritable(Math.sqrt(2.0)));
+ tuple = new StringTuple();
+ tuple.add("foo2");
+ tuple.add("123");
+ context.write(tuple, new DoubleWritable(1));
+
+ EasyMock.replay(context);
+
+ Vector vector = new RandomAccessSparseVector(2);
+ vector.set(0, 2);
+ vector.set(1, 2);
+
+ VectorDistanceMapper mapper = new VectorDistanceMapper();
+ setField(mapper, "measure", new EuclideanDistanceMeasure());
+ Collection<NamedVector> seedVectors = Lists.newArrayList();
+ Vector seed1 = new RandomAccessSparseVector(2);
+ seed1.set(0, 1);
+ seed1.set(1, 1);
+ Vector seed2 = new RandomAccessSparseVector(2);
+ seed2.set(0, 2);
+ seed2.set(1, 1);
+
+ seedVectors.add(new NamedVector(seed1, "foo"));
+ seedVectors.add(new NamedVector(seed2, "foo2"));
+ setField(mapper, "seedVectors", seedVectors);
+
+ mapper.map(new IntWritable(123), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+ }
+
+ @Test
+ public void testVectorDistanceInvertedMapper() throws Exception {
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+ Vector expectVec = new DenseVector(new double[]{Math.sqrt(2.0), 1.0});
+ context.write(new Text("other"), new VectorWritable(expectVec));
+ EasyMock.replay(context);
+ Vector vector = new NamedVector(new RandomAccessSparseVector(2), "other");
+ vector.set(0, 2);
+ vector.set(1, 2);
+
+ VectorDistanceInvertedMapper mapper = new VectorDistanceInvertedMapper();
+ setField(mapper, "measure", new EuclideanDistanceMeasure());
+ Collection<NamedVector> seedVectors = Lists.newArrayList();
+ Vector seed1 = new RandomAccessSparseVector(2);
+ seed1.set(0, 1);
+ seed1.set(1, 1);
+ Vector seed2 = new RandomAccessSparseVector(2);
+ seed2.set(0, 2);
+ seed2.set(1, 1);
+
+ seedVectors.add(new NamedVector(seed1, "foo"));
+ seedVectors.add(new NamedVector(seed2, "foo2"));
+ setField(mapper, "seedVectors", seedVectors);
+
+ mapper.map(new IntWritable(123), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+
+ }
+
+ @Test
+ public void testRun() throws Exception {
+ Path input = getTestTempDirPath("input");
+ Path output = getTestTempDirPath("output");
+ Path seedsPath = getTestTempDirPath("seeds");
+
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+ List<VectorWritable> seeds = getPointsWritable(SEEDS);
+
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
+
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName() };
+
+ ToolRunner.run(getConfiguration(), new VectorDistanceSimilarityJob(), args);
+
+ int expectedOutputSize = SEEDS.length * REFERENCE.length;
+ int outputSize = Iterables.size(new SequenceFileIterable<StringTuple, DoubleWritable>(new Path(output,
+ "part-m-00000"), conf));
+ assertEquals(expectedOutputSize, outputSize);
+ }
+
+ @Test
+ public void testMaxDistance() throws Exception {
+
+ Path input = getTestTempDirPath("input");
+ Path output = getTestTempDirPath("output");
+ Path seedsPath = getTestTempDirPath("seeds");
+
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+ List<VectorWritable> seeds = getPointsWritable(SEEDS);
+
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
+
+ double maxDistance = 10;
+
+ String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(VectorDistanceSimilarityJob.MAX_DISTANCE), String.valueOf(maxDistance) };
+
+ ToolRunner.run(getConfiguration(), new VectorDistanceSimilarityJob(), args);
+
+ int outputSize = 0;
+
+ for (Pair<StringTuple, DoubleWritable> record : new SequenceFileIterable<StringTuple, DoubleWritable>(
+ new Path(output, "part-m-00000"), conf)) {
+ outputSize++;
+ assertTrue(record.getSecond().get() <= maxDistance);
+ }
+
+ assertEquals(14, outputSize);
+ }
+
+ @Test
+ public void testRunInverted() throws Exception {
+ Path input = getTestTempDirPath("input");
+ Path output = getTestTempDirPath("output");
+ Path seedsPath = getTestTempDirPath("seeds");
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+ List<VectorWritable> seeds = getPointsWritable(SEEDS);
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
+ String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION),
+ EuclideanDistanceMeasure.class.getName(),
+ optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v"
+ };
+ ToolRunner.run(getConfiguration(), new VectorDistanceSimilarityJob(), args);
+
+ DummyOutputCollector<Text, VectorWritable> collector = new DummyOutputCollector<Text, VectorWritable>();
+
+ for (Pair<Text, VectorWritable> record : new SequenceFileIterable<Text, VectorWritable>(
+ new Path(output, "part-m-00000"), conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ assertEquals(REFERENCE.length, collector.getData().size());
+ for (Map.Entry<Text, List<VectorWritable>> entry : collector.getData().entrySet()) {
+ assertEquals(SEEDS.length, entry.getValue().iterator().next().get().size());
+ }
+ }
+
+ private static List<VectorWritable> getPointsWritable(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java
new file mode 100644
index 0000000..5d64f90
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJobTest.java
@@ -0,0 +1,214 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.TanimotoCoefficientSimilarity;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+import org.junit.Test;
+
+import java.io.File;
+
+public class RowSimilarityJobTest extends MahoutTestCase {
+
+ /**
+ * integration test with a tiny data set
+ *
+ * <pre>
+ *
+ * input matrix:
+ *
+ * 1, 0, 1, 1, 0
+ * 0, 0, 1, 1, 0
+ * 0, 0, 0, 0, 1
+ *
+ * similarity matrix (via tanimoto):
+ *
+ * 1, 0.666, 0
+ * 0.666, 1, 0
+ * 0, 0, 1
+ * </pre>
+ * @throws Exception
+ */
+ @Test
+ public void toyIntegration() throws Exception {
+
+ File inputFile = getTestTempFile("rows");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ Configuration conf = getConfiguration();
+ Path inputPath = new Path(inputFile.getAbsolutePath());
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+ MathHelper.writeDistributedRowMatrix(new double[][] {
+ new double[] { 1, 0, 1, 1, 0 },
+ new double[] { 0, 0, 1, 1, 0 },
+ new double[] { 0, 0, 0, 0, 1 } },
+ fs, conf, inputPath);
+
+ RowSimilarityJob rowSimilarityJob = new RowSimilarityJob();
+ rowSimilarityJob.setConf(conf);
+ rowSimilarityJob.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--numberOfColumns", String.valueOf(5), "--similarityClassname", TanimotoCoefficientSimilarity.class.getName(),
+ "--tempDir", tmpDir.getAbsolutePath() });
+
+
+ OpenIntIntHashMap observationsPerColumn =
+ Vectors.readAsIntMap(new Path(tmpDir.getAbsolutePath(), "observationsPerColumn.bin"), conf);
+ assertEquals(4, observationsPerColumn.size());
+ assertEquals(1, observationsPerColumn.get(0));
+ assertEquals(2, observationsPerColumn.get(2));
+ assertEquals(2, observationsPerColumn.get(3));
+ assertEquals(1, observationsPerColumn.get(4));
+
+ Matrix similarityMatrix = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "part-r-00000"), 3, 3);
+
+ assertNotNull(similarityMatrix);
+ assertEquals(3, similarityMatrix.numCols());
+ assertEquals(3, similarityMatrix.numRows());
+
+ assertEquals(1.0, similarityMatrix.get(0, 0), EPSILON);
+ assertEquals(1.0, similarityMatrix.get(1, 1), EPSILON);
+ assertEquals(1.0, similarityMatrix.get(2, 2), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(2, 0), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(2, 1), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(0, 2), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(1, 2), EPSILON);
+ assertEquals(0.666666, similarityMatrix.get(0, 1), EPSILON);
+ assertEquals(0.666666, similarityMatrix.get(1, 0), EPSILON);
+ }
+
+ @Test
+ public void toyIntegrationMaxSimilaritiesPerRow() throws Exception {
+
+ File inputFile = getTestTempFile("rows");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ Configuration conf = getConfiguration();
+ Path inputPath = new Path(inputFile.getAbsolutePath());
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+ MathHelper.writeDistributedRowMatrix(new double[][]{
+ new double[] { 1, 0, 1, 1, 0, 1 },
+ new double[] { 0, 1, 1, 1, 1, 1 },
+ new double[] { 1, 1, 0, 1, 0, 0 } },
+ fs, conf, inputPath);
+
+ RowSimilarityJob rowSimilarityJob = new RowSimilarityJob();
+ rowSimilarityJob.setConf(conf);
+ rowSimilarityJob.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--numberOfColumns", String.valueOf(6), "--similarityClassname", TanimotoCoefficientSimilarity.class.getName(),
+ "--maxSimilaritiesPerRow", String.valueOf(1), "--excludeSelfSimilarity", String.valueOf(true),
+ "--tempDir", tmpDir.getAbsolutePath() });
+
+ Matrix similarityMatrix = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "part-r-00000"), 3, 3);
+
+ assertNotNull(similarityMatrix);
+ assertEquals(3, similarityMatrix.numCols());
+ assertEquals(3, similarityMatrix.numRows());
+
+ assertEquals(0.0, similarityMatrix.get(0, 0), EPSILON);
+ assertEquals(0.5, similarityMatrix.get(0, 1), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(0, 2), EPSILON);
+
+ assertEquals(0.5, similarityMatrix.get(1, 0), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(1, 1), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(1, 2), EPSILON);
+
+ assertEquals(0.4, similarityMatrix.get(2, 0), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(2, 1), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(2, 2), EPSILON);
+ }
+
+ @Test
+ public void toyIntegrationWithThreshold() throws Exception {
+
+
+ File inputFile = getTestTempFile("rows");
+ File outputDir = getTestTempDir("output");
+ outputDir.delete();
+ File tmpDir = getTestTempDir("tmp");
+
+ Configuration conf = getConfiguration();
+ Path inputPath = new Path(inputFile.getAbsolutePath());
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+ MathHelper.writeDistributedRowMatrix(new double[][]{
+ new double[] { 1, 0, 1, 1, 0, 1 },
+ new double[] { 0, 1, 1, 1, 1, 1 },
+ new double[] { 1, 1, 0, 1, 0, 0 } },
+ fs, conf, inputPath);
+
+ RowSimilarityJob rowSimilarityJob = new RowSimilarityJob();
+ rowSimilarityJob.setConf(conf);
+ rowSimilarityJob.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--numberOfColumns", String.valueOf(6), "--similarityClassname", TanimotoCoefficientSimilarity.class.getName(),
+ "--excludeSelfSimilarity", String.valueOf(true), "--threshold", String.valueOf(0.5),
+ "--tempDir", tmpDir.getAbsolutePath() });
+
+ Matrix similarityMatrix = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "part-r-00000"), 3, 3);
+
+ assertNotNull(similarityMatrix);
+ assertEquals(3, similarityMatrix.numCols());
+ assertEquals(3, similarityMatrix.numRows());
+
+ assertEquals(0.0, similarityMatrix.get(0, 0), EPSILON);
+ assertEquals(0.5, similarityMatrix.get(0, 1), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(0, 2), EPSILON);
+
+ assertEquals(0.5, similarityMatrix.get(1, 0), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(1, 1), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(1, 2), EPSILON);
+
+ assertEquals(0.0, similarityMatrix.get(2, 0), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(2, 1), EPSILON);
+ assertEquals(0.0, similarityMatrix.get(2, 2), EPSILON);
+ }
+
+ @Test
+ public void testVectorDimensions() throws Exception {
+
+ File inputFile = getTestTempFile("rows");
+
+ Configuration conf = getConfiguration();
+ Path inputPath = new Path(inputFile.getAbsolutePath());
+ FileSystem fs = FileSystem.get(inputPath.toUri(), conf);
+
+ MathHelper.writeDistributedRowMatrix(new double[][] {
+ new double[] { 1, 0, 1, 1, 0, 1 },
+ new double[] { 0, 1, 1, 1, 1, 1 },
+ new double[] { 1, 1, 0, 1, 0, 0 } },
+ fs, conf, inputPath);
+
+ RowSimilarityJob rowSimilarityJob = new RowSimilarityJob();
+ rowSimilarityJob.setConf(conf);
+
+ int numberOfColumns = rowSimilarityJob.getDimensions(inputPath);
+
+ assertEquals(6, numberOfColumns);
+ }
+}
[04/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java b/mr/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java
new file mode 100644
index 0000000..dd4360a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/lda/cvb/TestCVBModelTrainer.java
@@ -0,0 +1,138 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.lda.cvb;
+
+import com.google.common.base.Joiner;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixUtils;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+public final class TestCVBModelTrainer extends MahoutTestCase {
+
+ private static final double ETA = 0.1;
+ private static final double ALPHA = 0.1;
+
+ @Test
+ public void testInMemoryCVB0() throws Exception {
+ String[] terms = new String[26];
+ for (int i=0; i<terms.length; i++) {
+ terms[i] = String.valueOf((char) (i + 'a'));
+ }
+ int numGeneratingTopics = 3;
+ int numTerms = 26;
+ Matrix matrix = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction() {
+ @Override public double apply(double d) {
+ return 1.0 / Math.pow(d + 1.0, 2);
+ }
+ });
+
+ int numDocs = 100;
+ int numSamples = 20;
+ int numTopicsPerDoc = 1;
+
+ Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(matrix, RandomUtils.getRandom(),
+ numDocs, numSamples, numTopicsPerDoc);
+
+ List<Double> perplexities = Lists.newArrayList();
+ int numTrials = 1;
+ for (int numTestTopics = 1; numTestTopics < 2 * numGeneratingTopics; numTestTopics++) {
+ double[] perps = new double[numTrials];
+ for (int trial = 0; trial < numTrials; trial++) {
+ InMemoryCollapsedVariationalBayes0 cvb =
+ new InMemoryCollapsedVariationalBayes0(sampledCorpus, terms, numTestTopics, ALPHA, ETA, 2, 1, 0);
+ cvb.setVerbose(true);
+ perps[trial] = cvb.iterateUntilConvergence(0, 5, 0, 0.2);
+ System.out.println(perps[trial]);
+ }
+ Arrays.sort(perps);
+ System.out.println(Arrays.toString(perps));
+ perplexities.add(perps[0]);
+ }
+ System.out.println(Joiner.on(",").join(perplexities));
+ }
+
+ @Test
+ public void testRandomStructuredModelViaMR() throws Exception {
+ int numGeneratingTopics = 3;
+ int numTerms = 9;
+ Matrix matrix = ClusteringTestUtils.randomStructuredModel(numGeneratingTopics, numTerms, new DoubleFunction() {
+ @Override
+ public double apply(double d) {
+ return 1.0 / Math.pow(d + 1.0, 3);
+ }
+ });
+
+ int numDocs = 500;
+ int numSamples = 10;
+ int numTopicsPerDoc = 1;
+
+ Matrix sampledCorpus = ClusteringTestUtils.sampledCorpus(matrix, RandomUtils.getRandom(1234),
+ numDocs, numSamples, numTopicsPerDoc);
+
+ Path sampleCorpusPath = getTestTempDirPath("corpus");
+ Configuration configuration = getConfiguration();
+ MatrixUtils.write(sampleCorpusPath, configuration, sampledCorpus);
+ int numIterations = 5;
+ List<Double> perplexities = Lists.newArrayList();
+ int startTopic = numGeneratingTopics - 1;
+ int numTestTopics = startTopic;
+ while (numTestTopics < numGeneratingTopics + 2) {
+ Path topicModelStateTempPath = getTestTempDirPath("topicTemp" + numTestTopics);
+ Configuration conf = getConfiguration();
+ CVB0Driver cvb0Driver = new CVB0Driver();
+ cvb0Driver.run(conf, sampleCorpusPath, null, numTestTopics, numTerms,
+ ALPHA, ETA, numIterations, 1, 0, null, null, topicModelStateTempPath, 1234, 0.2f, 2,
+ 1, 3, 1, false);
+ perplexities.add(lowestPerplexity(conf, topicModelStateTempPath));
+ numTestTopics++;
+ }
+ int bestTopic = -1;
+ double lowestPerplexity = Double.MAX_VALUE;
+ for (int t = 0; t < perplexities.size(); t++) {
+ if (perplexities.get(t) < lowestPerplexity) {
+ lowestPerplexity = perplexities.get(t);
+ bestTopic = t + startTopic;
+ }
+ }
+ assertEquals("The optimal number of topics is not that of the generating distribution", 4, bestTopic);
+ System.out.println("Perplexities: " + Joiner.on(", ").join(perplexities));
+ }
+
+ private static double lowestPerplexity(Configuration conf, Path topicModelTemp)
+ throws IOException {
+ double lowest = Double.MAX_VALUE;
+ double current;
+ int iteration = 2;
+ while (!Double.isNaN(current = CVB0Driver.readPerplexity(conf, topicModelTemp, iteration))) {
+ lowest = Math.min(current, lowest);
+ iteration++;
+ }
+ return lowest;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/spectral/TestAffinityMatrixInputJob.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/spectral/TestAffinityMatrixInputJob.java b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestAffinityMatrixInputJob.java
new file mode 100644
index 0000000..6e0cd18
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestAffinityMatrixInputJob.java
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.DummyRecordWriter;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix.MatrixEntryWritable;
+import org.junit.Test;
+
+/**
+ * <p>Tests the affinity matrix input M/R task.</p>
+ *
+ * <p>The tricky item with this task is that the format of the input
+ * must be correct; it must take the form of a graph input, and for the
+ * current implementation, the input must be symmetric, e.g. the weight
+ * from node A to B = the weight from node B to A. This is not explicitly
+ * enforced within the task itself (since, as of the time these tests were
+ * written, we have not yet decided on a final rule regarding the
+ * symmetry/non-symmetry of the affinity matrix, so we are unofficially
+ * enforcing symmetry). Input looks something like this:</p>
+ *
+ * <pre>0, 0, 0
+ * 0, 1, 10
+ * 0, 2, 20
+ * ...
+ * 1, 0, 10
+ * 2, 0, 20
+ * ...</pre>
+ *
+ * <p>The mapper's task is simply to convert each line of text into a
+ * DistributedRowMatrix entry, allowing the reducer to join each entry
+ * of the same row into a VectorWritable.</p>
+ *
+ * <p>Exceptions are thrown in cases of bad input format: if there are
+ * more or fewer than 3 numbers per line, or any of the numbers are missing.
+ */
+public class TestAffinityMatrixInputJob extends MahoutTestCase {
+
+ private static final String [] RAW = {"0,0,0", "0,1,5", "0,2,10", "1,0,5", "1,1,0",
+ "1,2,20", "2,0,10", "2,1,20", "2,2,0"};
+ private static final int RAW_DIMENSIONS = 3;
+
+ @Test
+ public void testAffinityMatrixInputMapper() throws Exception {
+ AffinityMatrixInputMapper mapper = new AffinityMatrixInputMapper();
+ Configuration conf = getConfiguration();
+ conf.setInt(Keys.AFFINITY_DIMENSIONS, RAW_DIMENSIONS);
+
+ // set up the dummy writer and the M/R context
+ DummyRecordWriter<IntWritable, MatrixEntryWritable> writer =
+ new DummyRecordWriter<IntWritable, MatrixEntryWritable>();
+ Mapper<LongWritable, Text, IntWritable, MatrixEntryWritable>.Context
+ context = DummyRecordWriter.build(mapper, conf, writer);
+
+ // loop through all the points and test each one is converted
+ // successfully to a DistributedRowMatrix.MatrixEntry
+ for (String s : RAW) {
+ mapper.map(new LongWritable(), new Text(s), context);
+ }
+
+ // test the data was successfully constructed
+ assertEquals("Number of map results", RAW_DIMENSIONS, writer.getData().size());
+ Set<IntWritable> keys = writer.getData().keySet();
+ for (IntWritable i : keys) {
+ List<MatrixEntryWritable> row = writer.getData().get(i);
+ assertEquals("Number of items in row", RAW_DIMENSIONS, row.size());
+ }
+ }
+
+ @Test
+ public void testAffinitymatrixInputReducer() throws Exception {
+ AffinityMatrixInputMapper mapper = new AffinityMatrixInputMapper();
+ Configuration conf = getConfiguration();
+ conf.setInt(Keys.AFFINITY_DIMENSIONS, RAW_DIMENSIONS);
+
+ // set up the dummy writer and the M/R context
+ DummyRecordWriter<IntWritable, MatrixEntryWritable> mapWriter =
+ new DummyRecordWriter<IntWritable, MatrixEntryWritable>();
+ Mapper<LongWritable, Text, IntWritable, MatrixEntryWritable>.Context
+ mapContext = DummyRecordWriter.build(mapper, conf, mapWriter);
+
+ // loop through all the points and test each one is converted
+ // successfully to a DistributedRowMatrix.MatrixEntry
+ for (String s : RAW) {
+ mapper.map(new LongWritable(), new Text(s), mapContext);
+ }
+ // store the data for checking later
+ Map<IntWritable, List<MatrixEntryWritable>> map = mapWriter.getData();
+
+ // now reduce the data
+ AffinityMatrixInputReducer reducer = new AffinityMatrixInputReducer();
+ DummyRecordWriter<IntWritable, VectorWritable> redWriter =
+ new DummyRecordWriter<IntWritable, VectorWritable>();
+ Reducer<IntWritable, MatrixEntryWritable,
+ IntWritable, VectorWritable>.Context redContext = DummyRecordWriter
+ .build(reducer, conf, redWriter, IntWritable.class, MatrixEntryWritable.class);
+ for (IntWritable key : mapWriter.getKeys()) {
+ reducer.reduce(key, mapWriter.getValue(key), redContext);
+ }
+
+ // check that all the elements are correctly ordered
+ assertEquals("Number of reduce results", RAW_DIMENSIONS, redWriter.getData().size());
+ for (IntWritable row : redWriter.getKeys()) {
+ List<VectorWritable> list = redWriter.getValue(row);
+ assertEquals("Should only be one vector", 1, list.size());
+ // check that the elements in the array are correctly ordered
+ Vector v = list.get(0).get();
+ for (Vector.Element e : v.all()) {
+ // find this value in the original map
+ MatrixEntryWritable toCompare = new MatrixEntryWritable();
+ toCompare.setRow(-1);
+ toCompare.setCol(e.index());
+ toCompare.setVal(e.get());
+ assertTrue("This entry was correctly placed in its row", map.get(row).contains(toCompare));
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/spectral/TestMatrixDiagonalizeJob.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/spectral/TestMatrixDiagonalizeJob.java b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestMatrixDiagonalizeJob.java
new file mode 100644
index 0000000..7d4ec1f
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestMatrixDiagonalizeJob.java
@@ -0,0 +1,116 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.spectral.MatrixDiagonalizeJob.MatrixDiagonalizeMapper;
+import org.apache.mahout.clustering.spectral.MatrixDiagonalizeJob.MatrixDiagonalizeReducer;
+import org.apache.mahout.common.DummyRecordWriter;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Test;
+
+/**
+ * <p>The MatrixDiagonalize task is pretty simple: given a matrix,
+ * it sums the elements of the row, and sticks the sum in position (i, i)
+ * of a new matrix of identical dimensions to the original.</p>
+ */
+public class TestMatrixDiagonalizeJob extends MahoutTestCase {
+
+ private static final double[][] RAW = { {1, 2, 3}, {4, 5, 6}, {7, 8, 9} };
+ private static final int RAW_DIMENSIONS = 3;
+
+ private static double rowSum(double [] row) {
+ double sum = 0;
+ for (double r : row) {
+ sum += r;
+ }
+ return sum;
+ }
+
+ @Test
+ public void testMatrixDiagonalizeMapper() throws Exception {
+ MatrixDiagonalizeMapper mapper = new MatrixDiagonalizeMapper();
+ Configuration conf = getConfiguration();
+ conf.setInt(Keys.AFFINITY_DIMENSIONS, RAW_DIMENSIONS);
+
+ // set up the dummy writers
+ DummyRecordWriter<NullWritable, IntDoublePairWritable> writer =
+ new DummyRecordWriter<NullWritable, IntDoublePairWritable>();
+ Mapper<IntWritable, VectorWritable, NullWritable, IntDoublePairWritable>.Context
+ context = DummyRecordWriter.build(mapper, conf, writer);
+
+ // perform the mapping
+ for (int i = 0; i < RAW_DIMENSIONS; i++) {
+ RandomAccessSparseVector toAdd = new RandomAccessSparseVector(RAW_DIMENSIONS);
+ toAdd.assign(RAW[i]);
+ mapper.map(new IntWritable(i), new VectorWritable(toAdd), context);
+ }
+
+ // check the number of the results
+ assertEquals("Number of map results", RAW_DIMENSIONS,
+ writer.getValue(NullWritable.get()).size());
+ }
+
+ @Test
+ public void testMatrixDiagonalizeReducer() throws Exception {
+ MatrixDiagonalizeMapper mapper = new MatrixDiagonalizeMapper();
+ Configuration conf = getConfiguration();
+ conf.setInt(Keys.AFFINITY_DIMENSIONS, RAW_DIMENSIONS);
+
+ // set up the dummy writers
+ DummyRecordWriter<NullWritable, IntDoublePairWritable> mapWriter =
+ new DummyRecordWriter<NullWritable, IntDoublePairWritable>();
+ Mapper<IntWritable, VectorWritable, NullWritable, IntDoublePairWritable>.Context
+ mapContext = DummyRecordWriter.build(mapper, conf, mapWriter);
+
+ // perform the mapping
+ for (int i = 0; i < RAW_DIMENSIONS; i++) {
+ RandomAccessSparseVector toAdd = new RandomAccessSparseVector(RAW_DIMENSIONS);
+ toAdd.assign(RAW[i]);
+ mapper.map(new IntWritable(i), new VectorWritable(toAdd), mapContext);
+ }
+
+ // now perform the reduction
+ MatrixDiagonalizeReducer reducer = new MatrixDiagonalizeReducer();
+ DummyRecordWriter<NullWritable, VectorWritable> redWriter = new
+ DummyRecordWriter<NullWritable, VectorWritable>();
+ Reducer<NullWritable, IntDoublePairWritable, NullWritable, VectorWritable>.Context
+ redContext = DummyRecordWriter.build(reducer, conf, redWriter,
+ NullWritable.class, IntDoublePairWritable.class);
+
+ // only need one reduction
+ reducer.reduce(NullWritable.get(), mapWriter.getValue(NullWritable.get()), redContext);
+
+ // first, make sure there's only one result
+ List<VectorWritable> list = redWriter.getValue(NullWritable.get());
+ assertEquals("Only a single resulting vector", 1, list.size());
+ Vector v = list.get(0).get();
+ for (int i = 0; i < v.size(); i++) {
+ assertEquals("Element sum is correct", rowSum(RAW[i]), v.get(i),0.01);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/spectral/TestUnitVectorizerJob.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/spectral/TestUnitVectorizerJob.java b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestUnitVectorizerJob.java
new file mode 100644
index 0000000..f317f6e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestUnitVectorizerJob.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.spectral.UnitVectorizerJob.UnitVectorizerMapper;
+import org.apache.mahout.common.DummyRecordWriter;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Test;
+
+public class TestUnitVectorizerJob extends MahoutTestCase {
+
+ private static final double [][] RAW = { {1, 2, 3}, {4, 5, 6}, {7, 8, 9} };
+
+ @Test
+ public void testUnitVectorizerMapper() throws Exception {
+ UnitVectorizerMapper mapper = new UnitVectorizerMapper();
+ Configuration conf = getConfiguration();
+
+ // set up the dummy writers
+ DummyRecordWriter<IntWritable, VectorWritable> writer = new
+ DummyRecordWriter<IntWritable, VectorWritable>();
+ Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context
+ context = DummyRecordWriter.build(mapper, conf, writer);
+
+ // perform the mapping
+ for (int i = 0; i < RAW.length; i++) {
+ Vector vector = new RandomAccessSparseVector(RAW[i].length);
+ vector.assign(RAW[i]);
+ mapper.map(new IntWritable(i), new VectorWritable(vector), context);
+ }
+
+ // check the results
+ assertEquals("Number of map results", RAW.length, writer.getData().size());
+ for (int i = 0; i < RAW.length; i++) {
+ IntWritable key = new IntWritable(i);
+ List<VectorWritable> list = writer.getValue(key);
+ assertEquals("Only one element per row", 1, list.size());
+ Vector v = list.get(0).get();
+ assertTrue("Unit vector sum is 1 or differs by 0.0001", Math.abs(v.norm(2) - 1) < 0.000001);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorCache.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorCache.java b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorCache.java
new file mode 100644
index 0000000..9091efe
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorCache.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.net.URI;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Test;
+
+public class TestVectorCache extends MahoutTestCase {
+
+ private static final double [] VECTOR = { 1, 2, 3, 4 };
+
+ @Test
+ public void testSave() throws Exception {
+ Configuration conf = getConfiguration();
+ Writable key = new IntWritable(0);
+ Vector value = new DenseVector(VECTOR);
+ Path path = getTestTempDirPath("output");
+
+ // write the vector out
+ VectorCache.save(key, value, path, conf, true, true);
+
+ // can we read it from here?
+ SequenceFileValueIterator<VectorWritable> iterator =
+ new SequenceFileValueIterator<VectorWritable>(path, true, conf);
+ try {
+ VectorWritable old = iterator.next();
+ // test if the values are identical
+ assertEquals("Saved vector is identical to original", old.get(), value);
+ } finally {
+ Closeables.close(iterator, true);
+ }
+ }
+
+ @Test
+ public void testLoad() throws Exception {
+ // save a vector manually
+ Configuration conf = getConfiguration();
+ Writable key = new IntWritable(0);
+ Vector value = new DenseVector(VECTOR);
+ Path path = getTestTempDirPath("output");
+
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ // write the vector
+ path = fs.makeQualified(path);
+ fs.deleteOnExit(path);
+ HadoopUtil.delete(conf, path);
+ SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class);
+ try {
+ writer.append(key, new VectorWritable(value));
+ } finally {
+ Closeables.close(writer, false);
+ }
+ DistributedCache.setCacheFiles(new URI[] {path.toUri()}, conf);
+
+ // load it
+ Vector result = VectorCache.load(conf);
+
+ // are they the same?
+ assertNotNull("Vector is null", result);
+ assertEquals("Loaded vector is not identical to original", result, value);
+ }
+
+ @Test
+ public void testAll() throws Exception {
+ Configuration conf = getConfiguration();
+ Vector v = new DenseVector(VECTOR);
+ Path toSave = getTestTempDirPath("output");
+ Writable key = new IntWritable(0);
+
+ // save it
+ VectorCache.save(key, v, toSave, conf);
+
+ // now, load it back
+ Vector v2 = VectorCache.load(conf);
+
+ // are they the same?
+ assertNotNull("Vector is null", v2);
+ assertEquals("Vectors are not identical", v2, v);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorMatrixMultiplicationJob.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorMatrixMultiplicationJob.java b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorMatrixMultiplicationJob.java
new file mode 100644
index 0000000..2fd83e2
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/spectral/TestVectorMatrixMultiplicationJob.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral;
+
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.spectral.VectorMatrixMultiplicationJob.VectorMatrixMultiplicationMapper;
+import org.apache.mahout.common.DummyRecordWriter;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Test;
+
+/**
+ * <p>This test ensures that a Vector can be successfully multiplied
+ * with a matrix.</p>
+ */
+public class TestVectorMatrixMultiplicationJob extends MahoutTestCase {
+
+ private static final double [][] MATRIX = { {1, 1}, {2, 3} };
+ private static final double [] VECTOR = {9, 16};
+
+ @Test
+ public void testVectorMatrixMultiplicationMapper() throws Exception {
+ VectorMatrixMultiplicationMapper mapper = new VectorMatrixMultiplicationMapper();
+ Configuration conf = getConfiguration();
+
+ // set up all the parameters for the job
+ Vector toSave = new DenseVector(VECTOR);
+ DummyRecordWriter<IntWritable, VectorWritable> writer = new
+ DummyRecordWriter<IntWritable, VectorWritable>();
+ Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable>.Context
+ context = DummyRecordWriter.build(mapper, conf, writer);
+ mapper.setup(toSave);
+
+ // run the job
+ for (int i = 0; i < MATRIX.length; i++) {
+ Vector v = new RandomAccessSparseVector(MATRIX[i].length);
+ v.assign(MATRIX[i]);
+ mapper.map(new IntWritable(i), new VectorWritable(v), context);
+ }
+
+ // check the results
+ assertEquals("Number of map results", MATRIX.length, writer.getData().size());
+ for (int i = 0; i < MATRIX.length; i++) {
+ List<VectorWritable> list = writer.getValue(new IntWritable(i));
+ assertEquals("Only one vector per key", 1, list.size());
+ Vector v = list.get(0).get();
+ for (int j = 0; j < MATRIX[i].length; j++) {
+ double total = Math.sqrt(VECTOR[i]) * Math.sqrt(VECTOR[j]) * MATRIX[i][j];
+ assertEquals("Product matrix elements", total, v.get(j),EPSILON);
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/spectral/kmeans/TestEigenSeedGenerator.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/spectral/kmeans/TestEigenSeedGenerator.java b/mr/src/test/java/org/apache/mahout/clustering/spectral/kmeans/TestEigenSeedGenerator.java
new file mode 100644
index 0000000..4075fe4
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/spectral/kmeans/TestEigenSeedGenerator.java
@@ -0,0 +1,100 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.spectral.kmeans;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.clustering.spectral.kmeans.EigenSeedGenerator;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public final class TestEigenSeedGenerator extends MahoutTestCase {
+
+ private
+ static final double[][] RAW = {{1, 0, 0}, {1, 0, 0}, {0, 1, 0}, {0, 1, 0},
+ {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
+
+ private FileSystem fs;
+
+ private static List<VectorWritable> getPoints() {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : RAW) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ }
+
+ @Test
+ public void testEigenSeedGenerator() throws Exception {
+ List<VectorWritable> points = getPoints();
+ Job job = new Job();
+ Configuration conf = job.getConfiguration();
+ job.setMapOutputValueClass(VectorWritable.class);
+ Path input = getTestTempFilePath("eigen-input");
+ Path output = getTestTempDirPath("eigen-output");
+ ClusteringTestUtils.writePointsToFile(points, input, fs, conf);
+
+ EigenSeedGenerator.buildFromEigens(conf, input, output, 3, new ManhattanDistanceMeasure());
+
+ int clusterCount = 0;
+ Collection<Integer> set = new HashSet<Integer>();
+ Vector v[] = new Vector[3];
+ for (ClusterWritable clusterWritable :
+ new SequenceFileValueIterable<ClusterWritable>(
+ new Path(output, "part-eigenSeed"), true, conf)) {
+ Cluster cluster = clusterWritable.getValue();
+ int id = cluster.getId();
+ assertTrue(set.add(id)); // validate unique id's
+ v[id] = cluster.getCenter();
+ clusterCount++;
+ }
+ assertEquals(3, clusterCount); // validate sample count
+ // validate pair-wise orthogonality
+ assertEquals(0, v[0].dot(v[1]), 1E-10);
+ assertEquals(0, v[1].dot(v[2]), 1E-10);
+ assertEquals(0, v[0].dot(v[2]), 1E-10);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java b/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java
new file mode 100644
index 0000000..340ca8e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.cluster;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.ConstantVector;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.SingularValueDecomposition;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.WeightedVector;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.VectorFunction;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.Searcher;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.MultiNormal;
+import org.apache.mahout.math.random.WeightedThing;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.apache.mahout.clustering.ClusteringUtils.totalWeight;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+public class BallKMeansTest {
+ private static final int NUM_DATA_POINTS = 10000;
+ private static final int NUM_DIMENSIONS = 4;
+ private static final int NUM_ITERATIONS = 20;
+ private static final double DISTRIBUTION_RADIUS = 0.01;
+
+ @BeforeClass
+ public static void setUp() {
+ RandomUtils.useTestSeed();
+ syntheticData = DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, DISTRIBUTION_RADIUS);
+
+ }
+
+ private static Pair<List<Centroid>, List<Centroid>> syntheticData;
+ private static final int K1 = 100;
+
+
+ @Test
+ public void testClusteringMultipleRuns() {
+ for (int i = 1; i <= 10; ++i) {
+ BallKMeans clusterer = new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()),
+ 1 << NUM_DIMENSIONS, NUM_ITERATIONS, true, i);
+ clusterer.cluster(syntheticData.getFirst());
+ double costKMeansPlusPlus = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), clusterer);
+
+ clusterer = new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()),
+ 1 << NUM_DIMENSIONS, NUM_ITERATIONS, false, i);
+ clusterer.cluster(syntheticData.getFirst());
+ double costKMeansRandom = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), clusterer);
+
+ System.out.printf("%d runs; kmeans++: %f; random: %f\n", i, costKMeansPlusPlus, costKMeansRandom);
+ assertTrue("kmeans++ cost should be less than random cost", costKMeansPlusPlus < costKMeansRandom);
+ }
+ }
+
+ @Test
+ public void testClustering() {
+ UpdatableSearcher searcher = new BruteSearch(new SquaredEuclideanDistanceMeasure());
+ BallKMeans clusterer = new BallKMeans(searcher, 1 << NUM_DIMENSIONS, NUM_ITERATIONS);
+
+ long startTime = System.currentTimeMillis();
+ Pair<List<Centroid>, List<Centroid>> data = syntheticData;
+ clusterer.cluster(data.getFirst());
+ long endTime = System.currentTimeMillis();
+
+ long hash = 0;
+ for (Centroid centroid : data.getFirst()) {
+ for (Vector.Element element : centroid.all()) {
+ hash = 31 * hash + 17 * element.index() + Double.toHexString(element.get()).hashCode();
+ }
+ }
+ System.out.printf("Hash = %08x\n", hash);
+
+ assertEquals("Total weight not preserved", totalWeight(syntheticData.getFirst()), totalWeight(clusterer), 1.0e-9);
+
+ // Verify that each corner of the cube has a centroid very nearby.
+ // This is probably FALSE for large-dimensional spaces!
+ OnlineSummarizer summarizer = new OnlineSummarizer();
+ for (Vector mean : syntheticData.getSecond()) {
+ WeightedThing<Vector> v = searcher.search(mean, 1).get(0);
+ summarizer.add(v.getWeight());
+ }
+ assertTrue(String.format("Median weight [%f] too large [>%f]", summarizer.getMedian(),
+ DISTRIBUTION_RADIUS), summarizer.getMedian() < DISTRIBUTION_RADIUS);
+
+ double clusterTime = (endTime - startTime) / 1000.0;
+ System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n",
+ searcher.getClass().getName(), clusterTime,
+ clusterTime / syntheticData.getFirst().size() * 1.0e6);
+
+ // Verify that the total weight of the centroids near each corner is correct.
+ double[] cornerWeights = new double[1 << NUM_DIMENSIONS];
+ Searcher trueFinder = new BruteSearch(new EuclideanDistanceMeasure());
+ for (Vector trueCluster : syntheticData.getSecond()) {
+ trueFinder.add(trueCluster);
+ }
+ for (Centroid centroid : clusterer) {
+ WeightedThing<Vector> closest = trueFinder.search(centroid, 1).get(0);
+ cornerWeights[((Centroid)closest.getValue()).getIndex()] += centroid.getWeight();
+ }
+ int expectedNumPoints = NUM_DATA_POINTS / (1 << NUM_DIMENSIONS);
+ for (double v : cornerWeights) {
+ System.out.printf("%f ", v);
+ }
+ System.out.println();
+ for (double v : cornerWeights) {
+ assertEquals(expectedNumPoints, v, 0);
+ }
+ }
+
+ @Test
+ public void testInitialization() {
+ // Start with super clusterable data.
+ List<? extends WeightedVector> data = cubishTestData(0.01);
+
+ // Just do initialization of ball k-means. This should drop a point into each of the clusters.
+ BallKMeans r = new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()), 6, 20);
+ r.cluster(data);
+
+ // Put the centroids into a matrix.
+ Matrix x = new DenseMatrix(6, 5);
+ int row = 0;
+ for (Centroid c : r) {
+ x.viewRow(row).assign(c.viewPart(0, 5));
+ row++;
+ }
+
+ // Verify that each column looks right. Should contain zeros except for a single 6.
+ final Vector columnNorms = x.aggregateColumns(new VectorFunction() {
+ @Override
+ public double apply(Vector f) {
+ // Return the sum of three discrepancy measures.
+ return Math.abs(f.minValue()) + Math.abs(f.maxValue() - 6) + Math.abs(f.norm(1) - 6);
+ }
+ });
+ // Verify all errors are nearly zero.
+ assertEquals(0, columnNorms.norm(1) / columnNorms.size(), 0.1);
+
+ // Verify that the centroids are a permutation of the original ones.
+ SingularValueDecomposition svd = new SingularValueDecomposition(x);
+ Vector s = svd.getS().viewDiagonal().assign(Functions.div(6));
+ assertEquals(5, s.getLengthSquared(), 0.05);
+ assertEquals(5, s.norm(1), 0.05);
+ }
+
+ private static List<? extends WeightedVector> cubishTestData(double radius) {
+ List<WeightedVector> data = Lists.newArrayListWithCapacity(K1 + 5000);
+ int row = 0;
+
+ MultiNormal g = new MultiNormal(radius, new ConstantVector(0, 10));
+ for (int i = 0; i < K1; i++) {
+ data.add(new WeightedVector(g.sample(), 1, row++));
+ }
+
+ for (int i = 0; i < 5; i++) {
+ Vector m = new DenseVector(10);
+ m.set(i, 6); // This was originally i == 0 ? 6 : 6 which can't be right
+ MultiNormal gx = new MultiNormal(radius, m);
+ for (int j = 0; j < 1000; j++) {
+ data.add(new WeightedVector(gx.sample(), 1, row++));
+ }
+ }
+ return data;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java b/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java
new file mode 100644
index 0000000..2257541
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.cluster;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.random.MultiNormal;
+
+/**
+ * A collection of miscellaneous utility functions for working with data to be clustered.
+ * Includes methods for generating synthetic data and estimating distance cutoff.
+ */
+public final class DataUtils {
+ private DataUtils() {
+ }
+
+ /**
+ * Samples numDatapoints vectors of numDimensions cardinality centered around the vertices of a
+ * numDimensions order hypercube. The distribution of points around these vertices is
+ * multinormal with a radius of distributionRadius.
+ * A hypercube of numDimensions has 2^numDimensions vertices. Keep this in mind when clustering
+ * the data.
+ *
+ * Note that it is almost always the case that you want to call RandomUtils.useTestSeed() before
+ * generating test data. This means that you can't generate data in the declaration of a static
+ * variable because such initializations happen before any @BeforeClass or @Before setup methods
+ * are called.
+ *
+ *
+ * @param numDimensions number of dimensions of the vectors to be generated.
+ * @param numDatapoints number of data points to be generated.
+ * @param distributionRadius radius of the distribution around the hypercube vertices.
+ * @return a pair of lists, whose first element is the sampled points and whose second element
+ * is the list of hypercube vertices that are the means of each distribution.
+ */
+ public static Pair<List<Centroid>, List<Centroid>> sampleMultiNormalHypercube(
+ int numDimensions, int numDatapoints, double distributionRadius) {
+ int pow2N = 1 << numDimensions;
+ // Construct data samplers centered on the corners of a unit hypercube.
+ // Additionally, keep the means of the distributions that will be generated so we can compare
+ // these to the ideal cluster centers.
+ List<Centroid> mean = Lists.newArrayListWithCapacity(pow2N);
+ List<MultiNormal> rowSamplers = Lists.newArrayList();
+ for (int i = 0; i < pow2N; i++) {
+ Vector v = new DenseVector(numDimensions);
+ // Select each of the num
+ int pow2J = 1 << (numDimensions - 1);
+ for (int j = 0; j < numDimensions; ++j) {
+ v.set(j, 1.0 / pow2J * (i & pow2J));
+ pow2J >>= 1;
+ }
+ mean.add(new Centroid(i, v, 1));
+ rowSamplers.add(new MultiNormal(distributionRadius, v));
+ }
+
+ // Sample the requested number of data points.
+ List<Centroid> data = Lists.newArrayListWithCapacity(numDatapoints);
+ for (int i = 0; i < numDatapoints; ++i) {
+ data.add(new Centroid(i, rowSamplers.get(i % pow2N).sample(), 1));
+ }
+ return new Pair<List<Centroid>, List<Centroid>>(data, mean);
+ }
+
+ /**
+ * Calls sampleMultinormalHypercube(numDimension, numDataPoints, 0.01).
+ * @see DataUtils#sampleMultiNormalHypercube(int, int, double)
+ */
+ public static Pair<List<Centroid>, List<Centroid>> sampleMultiNormalHypercube(int numDimensions,
+ int numDatapoints) {
+ return sampleMultiNormalHypercube(numDimensions, numDatapoints, 0.01);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java b/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
new file mode 100644
index 0000000..cf9263c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.cluster;
+
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.apache.mahout.math.neighborhood.Searcher;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.apache.mahout.math.random.WeightedThing;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.runners.Parameterized.Parameters;
+
+
+@RunWith(Parameterized.class)
+public class StreamingKMeansTest {
+ private static final int NUM_DATA_POINTS = 1 << 16;
+ private static final int NUM_DIMENSIONS = 6;
+ private static final int NUM_PROJECTIONS = 2;
+ private static final int SEARCH_SIZE = 10;
+
+ private static Pair<List<Centroid>, List<Centroid>> syntheticData ;
+
+ @Before
+ public void setUp() {
+ RandomUtils.useTestSeed();
+ syntheticData =
+ DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS);
+ }
+
+ private UpdatableSearcher searcher;
+ private boolean allAtOnce;
+
+ public StreamingKMeansTest(UpdatableSearcher searcher, boolean allAtOnce) {
+ this.searcher = searcher;
+ this.allAtOnce = allAtOnce;
+ }
+
+ @Parameters
+ public static List<Object[]> generateData() {
+ return Arrays.asList(new Object[][] {
+ {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), true},
+ {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE),
+ true},
+ {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), false},
+ {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE),
+ false},
+ });
+ }
+
+ @Test
+ public void testAverageDistanceCutoff() {
+ double avgDistanceCutoff = 0;
+ double avgNumClusters = 0;
+ int numTests = 1;
+ System.out.printf("Distance cutoff for %s\n", searcher.getClass().getName());
+ for (int i = 0; i < numTests; ++i) {
+ searcher.clear();
+ int numStreamingClusters = (int)Math.log(syntheticData.getFirst().size()) * (1 <<
+ NUM_DIMENSIONS);
+ double distanceCutoff = 1.0e-6;
+ double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff(syntheticData.getFirst(),
+ searcher.getDistanceMeasure(), 100);
+ System.out.printf("[%d] Generated synthetic data [magic] %f [estimate] %f\n", i, distanceCutoff, estimatedCutoff);
+ StreamingKMeans clusterer =
+ new StreamingKMeans(searcher, numStreamingClusters, estimatedCutoff);
+ clusterer.cluster(syntheticData.getFirst());
+ avgDistanceCutoff += clusterer.getDistanceCutoff();
+ avgNumClusters += clusterer.getNumClusters();
+ System.out.printf("[%d] %f\n", i, clusterer.getDistanceCutoff());
+ }
+ avgDistanceCutoff /= numTests;
+ avgNumClusters /= numTests;
+ System.out.printf("Final: distanceCutoff: %f estNumClusters: %f\n", avgDistanceCutoff, avgNumClusters);
+ }
+
+ @Test
+ public void testClustering() {
+ searcher.clear();
+ int numStreamingClusters = (int)Math.log(syntheticData.getFirst().size()) * (1 << NUM_DIMENSIONS);
+ System.out.printf("k log n = %d\n", numStreamingClusters);
+ double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff(syntheticData.getFirst(),
+ searcher.getDistanceMeasure(), 100);
+ StreamingKMeans clusterer =
+ new StreamingKMeans(searcher, numStreamingClusters, estimatedCutoff);
+
+ long startTime = System.currentTimeMillis();
+ if (allAtOnce) {
+ clusterer.cluster(syntheticData.getFirst());
+ } else {
+ for (Centroid datapoint : syntheticData.getFirst()) {
+ clusterer.cluster(datapoint);
+ }
+ }
+ long endTime = System.currentTimeMillis();
+
+ System.out.printf("%s %s\n", searcher.getClass().getName(), searcher.getDistanceMeasure()
+ .getClass().getName());
+ System.out.printf("Total number of clusters %d\n", clusterer.getNumClusters());
+
+ System.out.printf("Weights: %f %f\n", ClusteringUtils.totalWeight(syntheticData.getFirst()),
+ ClusteringUtils.totalWeight(clusterer));
+ assertEquals("Total weight not preserved", ClusteringUtils.totalWeight(syntheticData.getFirst()),
+ ClusteringUtils.totalWeight(clusterer), 1.0e-9);
+
+ // and verify that each corner of the cube has a centroid very nearby
+ double maxWeight = 0;
+ for (Vector mean : syntheticData.getSecond()) {
+ WeightedThing<Vector> v = searcher.search(mean, 1).get(0);
+ maxWeight = Math.max(v.getWeight(), maxWeight);
+ }
+ assertTrue("Maximum weight too large " + maxWeight, maxWeight < 0.05);
+ double clusterTime = (endTime - startTime) / 1000.0;
+ System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n",
+ searcher.getClass().getName(), clusterTime,
+ clusterTime / syntheticData.getFirst().size() * 1.0e6);
+
+ // verify that the total weight of the centroids near each corner is correct
+ double[] cornerWeights = new double[1 << NUM_DIMENSIONS];
+ Searcher trueFinder = new BruteSearch(new EuclideanDistanceMeasure());
+ for (Vector trueCluster : syntheticData.getSecond()) {
+ trueFinder.add(trueCluster);
+ }
+ for (Centroid centroid : clusterer) {
+ WeightedThing<Vector> closest = trueFinder.search(centroid, 1).get(0);
+ cornerWeights[((Centroid)closest.getValue()).getIndex()] += centroid.getWeight();
+ }
+ int expectedNumPoints = NUM_DATA_POINTS / (1 << NUM_DIMENSIONS);
+ for (double v : cornerWeights) {
+ System.out.printf("%f ", v);
+ }
+ System.out.println();
+ for (double v : cornerWeights) {
+ assertEquals(expectedNumPoints, v, 0);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java b/mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
new file mode 100644
index 0000000..9b582b4
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mrunit.mapreduce.MapDriver;
+import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver;
+import org.apache.hadoop.mrunit.mapreduce.ReduceDriver;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.DataUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.BruteSearch;
+import org.apache.mahout.math.neighborhood.FastProjectionSearch;
+import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch;
+import org.apache.mahout.math.neighborhood.ProjectionSearch;
+import org.apache.mahout.math.random.WeightedThing;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+@RunWith(Parameterized.class)
+public class StreamingKMeansTestMR extends MahoutTestCase {
+ private static final int NUM_DATA_POINTS = 1 << 15;
+ private static final int NUM_DIMENSIONS = 8;
+ private static final int NUM_PROJECTIONS = 3;
+ private static final int SEARCH_SIZE = 5;
+ private static final int MAX_NUM_ITERATIONS = 10;
+ private static final double DISTANCE_CUTOFF = 1.0e-6;
+
+ private static Pair<List<Centroid>, List<Centroid>> syntheticData;
+
+ @Before
+ public void setUp() {
+ RandomUtils.useTestSeed();
+ syntheticData =
+ DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, 1.0e-4);
+ }
+
+ private final String searcherClassName;
+ private final String distanceMeasureClassName;
+
+ public StreamingKMeansTestMR(String searcherClassName, String distanceMeasureClassName) {
+ this.searcherClassName = searcherClassName;
+ this.distanceMeasureClassName = distanceMeasureClassName;
+ }
+
+ private void configure(Configuration configuration) {
+ configuration.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, distanceMeasureClassName);
+ configuration.setInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, SEARCH_SIZE);
+ configuration.setInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, NUM_PROJECTIONS);
+ configuration.set(StreamingKMeansDriver.SEARCHER_CLASS_OPTION, searcherClassName);
+ configuration.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1 << NUM_DIMENSIONS);
+ configuration.setInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS,
+ (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS));
+ configuration.setFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, (float) DISTANCE_CUTOFF);
+ configuration.setInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, MAX_NUM_ITERATIONS);
+
+ // Collapse the Centroids in the reducer.
+ configuration.setBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, true);
+ }
+
+ @Parameterized.Parameters
+ public static List<Object[]> generateData() {
+ return Arrays.asList(new Object[][]{
+ {ProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()},
+ {FastProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()},
+ {LocalitySensitiveHashSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()},
+ });
+ }
+
+ @Test
+ public void testHypercubeMapper() throws IOException {
+ MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver =
+ MapDriver.newMapDriver(new StreamingKMeansMapper());
+ configure(mapDriver.getConfiguration());
+ System.out.printf("%s mapper test\n",
+ mapDriver.getConfiguration().get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+ for (Centroid datapoint : syntheticData.getFirst()) {
+ mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
+ }
+ List<org.apache.hadoop.mrunit.types.Pair<IntWritable,CentroidWritable>> results = mapDriver.run();
+ BruteSearch resultSearcher = new BruteSearch(new SquaredEuclideanDistanceMeasure());
+ for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
+ resultSearcher.add(result.getSecond().getCentroid());
+ }
+ System.out.printf("Clustered the data into %d clusters\n", results.size());
+ for (Vector mean : syntheticData.getSecond()) {
+ WeightedThing<Vector> closest = resultSearcher.search(mean, 1).get(0);
+ assertTrue("Weight " + closest.getWeight() + " not less than 0.5", closest.getWeight() < 0.5);
+ }
+ }
+
+ @Test
+ public void testMapperVsLocal() throws IOException {
+ // Clusters the data using the StreamingKMeansMapper.
+ MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver =
+ MapDriver.newMapDriver(new StreamingKMeansMapper());
+ Configuration configuration = mapDriver.getConfiguration();
+ configure(configuration);
+ System.out.printf("%s mapper vs local test\n",
+ mapDriver.getConfiguration().get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+
+ for (Centroid datapoint : syntheticData.getFirst()) {
+ mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
+ }
+ List<Centroid> mapperCentroids = Lists.newArrayList();
+ for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> pair : mapDriver.run()) {
+ mapperCentroids.add(pair.getSecond().getCentroid());
+ }
+
+ // Clusters the data using local batch StreamingKMeans.
+ StreamingKMeans batchClusterer =
+ new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
+ mapDriver.getConfiguration().getInt("estimatedNumMapClusters", -1), DISTANCE_CUTOFF);
+ batchClusterer.cluster(syntheticData.getFirst());
+ List<Centroid> batchCentroids = Lists.newArrayList();
+ for (Vector v : batchClusterer) {
+ batchCentroids.add((Centroid) v);
+ }
+
+ // Clusters the data using point by point StreamingKMeans.
+ StreamingKMeans perPointClusterer =
+ new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(configuration),
+ (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS), DISTANCE_CUTOFF);
+ for (Centroid datapoint : syntheticData.getFirst()) {
+ perPointClusterer.cluster(datapoint);
+ }
+ List<Centroid> perPointCentroids = Lists.newArrayList();
+ for (Vector v : perPointClusterer) {
+ perPointCentroids.add((Centroid) v);
+ }
+
+ // Computes the cost (total sum of distances) of these different clusterings.
+ double mapperCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), mapperCentroids);
+ double localCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), batchCentroids);
+ double perPointCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), perPointCentroids);
+ System.out.printf("[Total cost] Mapper %f [%d] Local %f [%d] Perpoint local %f [%d];" +
+ "[ratio m-vs-l %f] [ratio pp-vs-l %f]\n", mapperCost, mapperCentroids.size(),
+ localCost, batchCentroids.size(), perPointCost, perPointCentroids.size(),
+ mapperCost / localCost, perPointCost / localCost);
+
+ // These ratios should be close to 1.0 and have been observed to be go as low as 0.6 and as low as 1.5.
+ // A buffer of [0.2, 1.8] seems appropriate.
+ assertEquals("Mapper StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1",
+ 1.0, mapperCost / localCost, 0.8);
+ assertEquals("One by one local StreamingKMeans / Batch local StreamingKMeans total cost ratio too high",
+ 1.0, perPointCost / localCost, 0.8);
+ }
+
+ @Test
+ public void testHypercubeReducer() throws IOException {
+ ReduceDriver<IntWritable, CentroidWritable, IntWritable, CentroidWritable> reduceDriver =
+ ReduceDriver.newReduceDriver(new StreamingKMeansReducer());
+ Configuration configuration = reduceDriver.getConfiguration();
+ configure(configuration);
+
+ System.out.printf("%s reducer test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+ StreamingKMeans clusterer =
+ new StreamingKMeans(StreamingKMeansUtilsMR .searcherFromConfiguration(configuration),
+ (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS), DISTANCE_CUTOFF);
+
+ long start = System.currentTimeMillis();
+ clusterer.cluster(syntheticData.getFirst());
+ long end = System.currentTimeMillis();
+
+ System.out.printf("%f [s]\n", (end - start) / 1000.0);
+ List<CentroidWritable> reducerInputs = Lists.newArrayList();
+ int postMapperTotalWeight = 0;
+ for (Centroid intermediateCentroid : clusterer) {
+ reducerInputs.add(new CentroidWritable(intermediateCentroid));
+ postMapperTotalWeight += intermediateCentroid.getWeight();
+ }
+
+ reduceDriver.addInput(new IntWritable(0), reducerInputs);
+ List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results =
+ reduceDriver.run();
+ testReducerResults(postMapperTotalWeight, results);
+ }
+
+ @Test
+ public void testHypercubeMapReduce() throws IOException {
+ MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable, IntWritable, CentroidWritable>
+ mapReduceDriver = new MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable,
+ IntWritable, CentroidWritable>(new StreamingKMeansMapper(), new StreamingKMeansReducer());
+ Configuration configuration = mapReduceDriver.getConfiguration();
+ configure(configuration);
+
+ System.out.printf("%s full test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION));
+ for (Centroid datapoint : syntheticData.getFirst()) {
+ mapReduceDriver.addInput(new IntWritable(0), new VectorWritable(datapoint));
+ }
+ List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = mapReduceDriver.run();
+ testReducerResults(syntheticData.getFirst().size(), results);
+ }
+
+ @Test
+ public void testHypercubeMapReduceRunSequentially() throws Exception {
+ Configuration configuration = getConfiguration();
+ configure(configuration);
+ configuration.set(DefaultOptionCreator.METHOD_OPTION, DefaultOptionCreator.SEQUENTIAL_METHOD);
+
+ Path inputPath = new Path("testInput");
+ Path outputPath = new Path("testOutput");
+ StreamingKMeansUtilsMR.writeVectorsToSequenceFile(syntheticData.getFirst(), inputPath, configuration);
+
+ StreamingKMeansDriver.run(configuration, inputPath, outputPath);
+
+ testReducerResults(syntheticData.getFirst().size(),
+ Lists.newArrayList(Iterables.transform(
+ new SequenceFileIterable<IntWritable, CentroidWritable>(outputPath, configuration),
+ new Function<
+ Pair<IntWritable, CentroidWritable>,
+ org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>>() {
+ @Override
+ public org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> apply(
+ org.apache.mahout.common.Pair<IntWritable, CentroidWritable> input) {
+ return new org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>(
+ input.getFirst(), input.getSecond());
+ }
+ })));
+ }
+
+ private static void testReducerResults(int totalWeight, List<org.apache.hadoop.mrunit.types.Pair<IntWritable,
+ CentroidWritable>> results) {
+ int expectedNumClusters = 1 << NUM_DIMENSIONS;
+ double expectedWeight = (double) totalWeight / expectedNumClusters;
+ int numClusters = 0;
+ int numUnbalancedClusters = 0;
+ int totalReducerWeight = 0;
+ for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) {
+ if (result.getSecond().getCentroid().getWeight() != expectedWeight) {
+ System.out.printf("Unbalanced weight %f in centroid %d\n", result.getSecond().getCentroid().getWeight(),
+ result.getSecond().getCentroid().getIndex());
+ ++numUnbalancedClusters;
+ }
+ assertEquals("Final centroid index is invalid", numClusters, result.getFirst().get());
+ totalReducerWeight += result.getSecond().getCentroid().getWeight();
+ ++numClusters;
+ }
+ System.out.printf("%d clusters are unbalanced\n", numUnbalancedClusters);
+ assertEquals("Invalid total weight", totalWeight, totalReducerWeight);
+ assertEquals("Invalid number of clusters", 1 << NUM_DIMENSIONS, numClusters);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java b/mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java
new file mode 100644
index 0000000..2d790e5
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.tools;
+
+import com.google.common.collect.Iterables;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.LocalFileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.junit.Test;
+
+public class ResplitSequenceFilesTest extends MahoutTestCase {
+
+ @Test
+ public void testSplitting() throws Exception {
+
+ Path inputFile = new Path(getTestTempDirPath("input"), "test.seq");
+ Path output = getTestTempDirPath("output");
+ Configuration conf = new Configuration();
+ LocalFileSystem fs = FileSystem.getLocal(conf);
+
+ SequenceFile.Writer writer = null;
+ try {
+ writer = SequenceFile.createWriter(fs, conf, inputFile, IntWritable.class, IntWritable.class);
+ writer.append(new IntWritable(1), new IntWritable(1));
+ writer.append(new IntWritable(2), new IntWritable(2));
+ writer.append(new IntWritable(3), new IntWritable(3));
+ writer.append(new IntWritable(4), new IntWritable(4));
+ writer.append(new IntWritable(5), new IntWritable(5));
+ writer.append(new IntWritable(6), new IntWritable(6));
+ writer.append(new IntWritable(7), new IntWritable(7));
+ writer.append(new IntWritable(8), new IntWritable(8));
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ String splitPattern = "split";
+ int numSplits = 4;
+
+ ResplitSequenceFiles.main(new String[] { "--input", inputFile.toString(),
+ "--output", output.toString() + "/" + splitPattern, "--numSplits", String.valueOf(numSplits) });
+
+ FileStatus[] statuses = HadoopUtil.getFileStatus(output, PathType.LIST, PathFilters.logsCRCFilter(), null, conf);
+
+ for (FileStatus status : statuses) {
+ String name = status.getPath().getName();
+ assertTrue(name.startsWith(splitPattern));
+ assertEquals(2, numEntries(status, conf));
+ }
+ assertEquals(numSplits, statuses.length);
+ }
+
+ private int numEntries(FileStatus status, Configuration conf) {
+ return Iterables.size(new SequenceFileIterable(status.getPath(), conf));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java b/mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java
new file mode 100644
index 0000000..66b66e3
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.mahout.clustering.topdown;
+
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+import java.io.File;
+
+public final class PathDirectoryTest extends MahoutTestCase {
+
+ private final Path output = new Path("output");
+
+ @Test
+ public void shouldReturnTopLevelClusterPath() {
+ Path expectedPath = new Path(output, PathDirectory.TOP_LEVEL_CLUSTER_DIRECTORY);
+ assertEquals(expectedPath, PathDirectory.getTopLevelClusterPath(output));
+ }
+
+ @Test
+ public void shouldReturnClusterPostProcessorOutputDirectory() {
+ Path expectedPath = new Path(output, PathDirectory.POST_PROCESS_DIRECTORY);
+ assertEquals(expectedPath, PathDirectory.getClusterPostProcessorOutputDirectory(output));
+ }
+
+ @Test
+ public void shouldReturnClusterOutputClusteredPoints() {
+ Path expectedPath = new Path(output, PathDirectory.CLUSTERED_POINTS_DIRECTORY + File.separator + '*');
+ assertEquals(expectedPath, PathDirectory.getClusterOutputClusteredPoints(output));
+ }
+
+ @Test
+ public void shouldReturnBottomLevelClusterPath() {
+ Path expectedPath = new Path(output + File.separator
+ + PathDirectory.BOTTOM_LEVEL_CLUSTER_DIRECTORY + File.separator
+ + '1');
+ assertEquals(expectedPath, PathDirectory.getBottomLevelClusterPath(output, "1"));
+ }
+
+ @Test
+ public void shouldReturnClusterPathForClusterId() {
+ Path expectedPath = new Path(PathDirectory.getClusterPostProcessorOutputDirectory(output), new Path("1"));
+ assertEquals(expectedPath, PathDirectory.getClusterPathForClusterId(
+ PathDirectory.getClusterPostProcessorOutputDirectory(output), "1"));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java b/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java
new file mode 100644
index 0000000..0934ff7
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.canopy.CanopyDriver;
+import org.apache.mahout.clustering.classify.WeightedVectorWritable;
+import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.common.DummyOutputCollector;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public final class ClusterCountReaderTest extends MahoutTestCase {
+
+ public static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, {5, 4}, {4, 5}, {5, 5}};
+
+ private FileSystem fs;
+ private Path outputPathForCanopy;
+ private Path outputPathForKMeans;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ Configuration conf = getConfiguration();
+ fs = FileSystem.get(conf);
+ }
+
+ public static List<VectorWritable> getPointsWritable(double[][] raw) {
+ List<VectorWritable> points = Lists.newArrayList();
+ for (double[] fr : raw) {
+ Vector vec = new RandomAccessSparseVector(fr.length);
+ vec.assign(fr);
+ points.add(new VectorWritable(vec));
+ }
+ return points;
+ }
+
+ /**
+ * Story: User wants to use cluster post processor after canopy clustering and then run clustering on the
+ * output clusters
+ */
+ @Test
+ public void testGetNumberOfClusters() throws Exception {
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+
+ Path pointsPath = getTestTempDirPath("points");
+ Configuration conf = getConfiguration();
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), fs, conf);
+
+ outputPathForCanopy = getTestTempDirPath("canopy");
+ outputPathForKMeans = getTestTempDirPath("kmeans");
+
+ topLevelClustering(pointsPath, conf);
+
+ int numberOfClusters = ClusterCountReader.getNumberOfClusters(outputPathForKMeans, conf);
+ Assert.assertEquals(2, numberOfClusters);
+ verifyThatNumberOfClustersIsCorrect(conf, new Path(outputPathForKMeans, new Path("clusteredPoints")));
+
+ }
+
+ private void topLevelClustering(Path pointsPath, Configuration conf) throws IOException,
+ InterruptedException,
+ ClassNotFoundException {
+ DistanceMeasure measure = new ManhattanDistanceMeasure();
+ CanopyDriver.run(conf, pointsPath, outputPathForCanopy, measure, 4.0, 3.0, true, 0.0, true);
+ Path clustersIn = new Path(outputPathForCanopy, new Path(Cluster.CLUSTERS_DIR + '0'
+ + Cluster.FINAL_ITERATION_SUFFIX));
+ KMeansDriver.run(conf, pointsPath, clustersIn, outputPathForKMeans, 1, 1, true, 0.0, true);
+ }
+
+ private static void verifyThatNumberOfClustersIsCorrect(Configuration conf, Path clusteredPointsPath) {
+ DummyOutputCollector<IntWritable,WeightedVectorWritable> collector =
+ new DummyOutputCollector<IntWritable,WeightedVectorWritable>();
+
+ // The key is the clusterId, the value is the weighted vector
+ for (Pair<IntWritable,WeightedVectorWritable> record :
+ new SequenceFileIterable<IntWritable,WeightedVectorWritable>(new Path(clusteredPointsPath, "part-m-0"),
+ conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ int clusterSize = collector.getKeys().size();
+ assertEquals(2, clusterSize);
+ }
+
+}
[20/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java
new file mode 100644
index 0000000..f1874a8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.decomposer;
+
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configurable;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.decomposer.lanczos.LanczosState;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class HdfsBackedLanczosState extends LanczosState implements Configurable {
+
+ private static final Logger log = LoggerFactory.getLogger(HdfsBackedLanczosState.class);
+
+ public static final String BASIS_PREFIX = "basis";
+ public static final String SINGULAR_PREFIX = "singular";
+ //public static final String METADATA_FILE = "metadata";
+
+ private Configuration conf;
+ private final Path baseDir;
+ private final Path basisPath;
+ private final Path singularVectorPath;
+ private FileSystem fs;
+
+ public HdfsBackedLanczosState(VectorIterable corpus, int desiredRank, Vector initialVector, Path dir) {
+ super(corpus, desiredRank, initialVector);
+ baseDir = dir;
+ //Path metadataPath = new Path(dir, METADATA_FILE);
+ basisPath = new Path(dir, BASIS_PREFIX);
+ singularVectorPath = new Path(dir, SINGULAR_PREFIX);
+ if (corpus instanceof Configurable) {
+ setConf(((Configurable)corpus).getConf());
+ }
+ }
+
+ @Override public void setConf(Configuration configuration) {
+ conf = configuration;
+ try {
+ setupDirs();
+ updateHdfsState();
+ } catch (IOException e) {
+ log.error("Could not retrieve filesystem: {}", conf, e);
+ }
+ }
+
+ @Override public Configuration getConf() {
+ return conf;
+ }
+
+ private void setupDirs() throws IOException {
+ fs = baseDir.getFileSystem(conf);
+ createDirIfNotExist(baseDir);
+ createDirIfNotExist(basisPath);
+ createDirIfNotExist(singularVectorPath);
+ }
+
+ private void createDirIfNotExist(Path path) throws IOException {
+ if (!fs.exists(path) && !fs.mkdirs(path)) {
+ throw new IOException("Unable to create: " + path);
+ }
+ }
+
+ @Override
+ public void setIterationNumber(int i) {
+ super.setIterationNumber(i);
+ try {
+ updateHdfsState();
+ } catch (IOException e) {
+ log.error("Could not update HDFS state: ", e);
+ }
+ }
+
+ protected void updateHdfsState() throws IOException {
+ if (conf == null) {
+ return;
+ }
+ int numBasisVectorsOnDisk = 0;
+ Path nextBasisVectorPath = new Path(basisPath, BASIS_PREFIX + '_' + numBasisVectorsOnDisk);
+ while (fs.exists(nextBasisVectorPath)) {
+ nextBasisVectorPath = new Path(basisPath, BASIS_PREFIX + '_' + ++numBasisVectorsOnDisk);
+ }
+ Vector nextVector;
+ while (numBasisVectorsOnDisk < iterationNumber
+ && (nextVector = getBasisVector(numBasisVectorsOnDisk)) != null) {
+ persistVector(nextBasisVectorPath, numBasisVectorsOnDisk, nextVector);
+ nextBasisVectorPath = new Path(basisPath, BASIS_PREFIX + '_' + ++numBasisVectorsOnDisk);
+ }
+ if (scaleFactor <= 0) {
+ scaleFactor = getScaleFactor(); // load from disk if possible
+ }
+ diagonalMatrix = getDiagonalMatrix(); // load from disk if possible
+ Vector norms = new DenseVector(diagonalMatrix.numCols() - 1);
+ Vector projections = new DenseVector(diagonalMatrix.numCols());
+ int i = 0;
+ while (i < diagonalMatrix.numCols() - 1) {
+ norms.set(i, diagonalMatrix.get(i, i + 1));
+ projections.set(i, diagonalMatrix.get(i, i));
+ i++;
+ }
+ projections.set(i, diagonalMatrix.get(i, i));
+ persistVector(new Path(baseDir, "projections"), 0, projections);
+ persistVector(new Path(baseDir, "norms"), 0, norms);
+ persistVector(new Path(baseDir, "scaleFactor"), 0, new DenseVector(new double[] {scaleFactor}));
+ for (Map.Entry<Integer, Vector> entry : singularVectors.entrySet()) {
+ persistVector(new Path(singularVectorPath, SINGULAR_PREFIX + '_' + entry.getKey()),
+ entry.getKey(), entry.getValue());
+ }
+ super.setIterationNumber(numBasisVectorsOnDisk);
+ }
+
+ protected void persistVector(Path p, int key, Vector vector) throws IOException {
+ SequenceFile.Writer writer = null;
+ try {
+ if (fs.exists(p)) {
+ log.warn("{} exists, will overwrite", p);
+ fs.delete(p, true);
+ }
+ writer = new SequenceFile.Writer(fs, conf, p,
+ IntWritable.class, VectorWritable.class);
+ writer.append(new IntWritable(key), new VectorWritable(vector));
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ protected Vector fetchVector(Path p, int keyIndex) throws IOException {
+ if (!fs.exists(p)) {
+ return null;
+ }
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, p, conf);
+ IntWritable key = new IntWritable();
+ VectorWritable vw = new VectorWritable();
+ while (reader.next(key, vw)) {
+ if (key.get() == keyIndex) {
+ return vw.get();
+ }
+ }
+ return null;
+ }
+
+ @Override
+ public Vector getBasisVector(int i) {
+ if (!basis.containsKey(i)) {
+ try {
+ Vector v = fetchVector(new Path(basisPath, BASIS_PREFIX + '_' + i), i);
+ basis.put(i, v);
+ } catch (IOException e) {
+ log.error("Could not load basis vector: {}", i, e);
+ }
+ }
+ return super.getBasisVector(i);
+ }
+
+ @Override
+ public Vector getRightSingularVector(int i) {
+ if (!singularVectors.containsKey(i)) {
+ try {
+ Vector v = fetchVector(new Path(singularVectorPath, BASIS_PREFIX + '_' + i), i);
+ singularVectors.put(i, v);
+ } catch (IOException e) {
+ log.error("Could not load singular vector: {}", i, e);
+ }
+ }
+ return super.getRightSingularVector(i);
+ }
+
+ @Override
+ public double getScaleFactor() {
+ if (scaleFactor <= 0) {
+ try {
+ Vector v = fetchVector(new Path(baseDir, "scaleFactor"), 0);
+ if (v != null && v.size() > 0) {
+ scaleFactor = v.get(0);
+ }
+ } catch (IOException e) {
+ log.error("could not load scaleFactor:", e);
+ }
+ }
+ return scaleFactor;
+ }
+
+ @Override
+ public Matrix getDiagonalMatrix() {
+ if (diagonalMatrix == null) {
+ diagonalMatrix = new DenseMatrix(desiredRank, desiredRank);
+ }
+ if (diagonalMatrix.get(0, 1) <= 0) {
+ try {
+ Vector norms = fetchVector(new Path(baseDir, "norms"), 0);
+ Vector projections = fetchVector(new Path(baseDir, "projections"), 0);
+ if (norms != null && projections != null) {
+ int i = 0;
+ while (i < projections.size() - 1) {
+ diagonalMatrix.set(i, i, projections.get(i));
+ diagonalMatrix.set(i, i + 1, norms.get(i));
+ diagonalMatrix.set(i + 1, i, norms.get(i));
+ i++;
+ }
+ diagonalMatrix.set(i, i, projections.get(i));
+ }
+ } catch (IOException e) {
+ log.error("Could not load diagonal matrix of norms and projections: ", e);
+ }
+ }
+ return diagonalMatrix;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
new file mode 100644
index 0000000..9119f69
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.kmeans.Kluster;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collections;
+import java.util.List;
+
+final class SeedVectorUtil {
+
+ private static final Logger log = LoggerFactory.getLogger(SeedVectorUtil.class);
+
+ private SeedVectorUtil() {
+ }
+
+ public static List<NamedVector> loadSeedVectors(Configuration conf) {
+
+ String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
+ if (seedPathStr == null || seedPathStr.isEmpty()) {
+ return Collections.emptyList();
+ }
+
+ List<NamedVector> seedVectors = Lists.newArrayList();
+ long item = 0;
+ for (Writable value
+ : new SequenceFileDirValueIterable<>(new Path(seedPathStr),
+ PathType.LIST,
+ PathFilters.partFilter(),
+ conf)) {
+ Class<? extends Writable> valueClass = value.getClass();
+ if (valueClass.equals(Kluster.class)) {
+ // get the cluster info
+ Kluster cluster = (Kluster) value;
+ Vector vector = cluster.getCenter();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
+ }
+ } else if (valueClass.equals(Canopy.class)) {
+ // get the cluster info
+ Canopy canopy = (Canopy) value;
+ Vector vector = canopy.getCenter();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
+ }
+ } else if (valueClass.equals(Vector.class)) {
+ Vector vector = (Vector) value;
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, seedPathStr + '.' + item++));
+ }
+ } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class)) {
+ VectorWritable vw = (VectorWritable) value;
+ Vector vector = vw.get();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, seedPathStr + '.' + item++));
+ }
+ } else {
+ throw new IllegalStateException("Bad value class: " + valueClass);
+ }
+ }
+ if (seedVectors.isEmpty()) {
+ throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
+ }
+ log.info("Seed Vectors size: {}", seedVectors.size());
+ return seedVectors;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
new file mode 100644
index 0000000..c45d55a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Similar to {@link VectorDistanceMapper}, except it outputs
+ * <input, Vector>, where the vector is a dense vector contain one entry for every seed vector
+ */
+public final class VectorDistanceInvertedMapper
+ extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
+
+ private DistanceMeasure measure;
+ private List<NamedVector> seedVectors;
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ String keyName;
+ Vector valVec = value.get();
+ if (valVec instanceof NamedVector) {
+ keyName = ((NamedVector) valVec).getName();
+ } else {
+ keyName = key.toString();
+ }
+ Vector outVec = new DenseVector(new double[seedVectors.size()]);
+ int i = 0;
+ for (NamedVector seedVector : seedVectors) {
+ outVec.setQuick(i++, measure.distance(seedVector, valVec));
+ }
+ context.write(new Text(keyName), new VectorWritable(outVec));
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ measure =
+ ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY), DistanceMeasure.class);
+ measure.configure(conf);
+ seedVectors = SeedVectorUtil.loadSeedVectors(conf);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
new file mode 100644
index 0000000..9fccd8e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.List;
+
+public final class VectorDistanceMapper
+ extends Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable> {
+
+ private DistanceMeasure measure;
+ private List<NamedVector> seedVectors;
+ private boolean usesThreshold = false;
+ private double maxDistance;
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ String keyName;
+ Vector valVec = value.get();
+ if (valVec instanceof NamedVector) {
+ keyName = ((NamedVector) valVec).getName();
+ } else {
+ keyName = key.toString();
+ }
+
+ for (NamedVector seedVector : seedVectors) {
+ double distance = measure.distance(seedVector, valVec);
+ if (!usesThreshold || distance <= maxDistance) {
+ StringTuple outKey = new StringTuple();
+ outKey.add(seedVector.getName());
+ outKey.add(keyName);
+ context.write(outKey, new DoubleWritable(distance));
+ }
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+
+ String maxDistanceParam = conf.get(VectorDistanceSimilarityJob.MAX_DISTANCE);
+ if (maxDistanceParam != null) {
+ usesThreshold = true;
+ maxDistance = Double.parseDouble(maxDistanceParam);
+ }
+
+ measure = ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY),
+ DistanceMeasure.class);
+ measure.configure(conf);
+ seedVectors = SeedVectorUtil.loadSeedVectors(conf);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
new file mode 100644
index 0000000..9f58f1e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
@@ -0,0 +1,153 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Preconditions;
+
+import java.io.IOException;
+
+/**
+ * This class does a Map-side join between seed vectors (the map side can also be a Cluster) and a list of other vectors
+ * and emits the a tuple of seed id, other id, distance. It is a more generic version of KMean's mapper
+ */
+public class VectorDistanceSimilarityJob extends AbstractJob {
+
+ public static final String SEEDS = "seeds";
+ public static final String SEEDS_PATH_KEY = "seedsPath";
+ public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
+ public static final String OUT_TYPE_KEY = "outType";
+ public static final String MAX_DISTANCE = "maxDistance";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+ addOption(SEEDS, "s", "The set of vectors to compute distances against. Must fit in memory on the mapper");
+ addOption(MAX_DISTANCE, "mx", "set an upper-bound on distance (double) such that any pair of vectors with a"
+ + " distance greater than this value is ignored in the output. Ignored for non pairwise output!");
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption(OUT_TYPE_KEY, "ot", "[pw|v] -- Define the output style: pairwise, the default, (pw) or vector (v). "
+ + "Pairwise is a tuple of <seed, other, distance>, vector is <other, <Vector of size the number of seeds>>.",
+ "pw");
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+ Path seeds = new Path(getOption(SEEDS));
+ String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ if (measureClass == null) {
+ measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+ }
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
+ String outType = getOption(OUT_TYPE_KEY, "pw");
+
+ Double maxDistance = null;
+
+ if ("pw".equals(outType)) {
+ String maxDistanceArg = getOption(MAX_DISTANCE);
+ if (maxDistanceArg != null) {
+ maxDistance = Double.parseDouble(maxDistanceArg);
+ Preconditions.checkArgument(maxDistance > 0.0d, "value for " + MAX_DISTANCE + " must be greater than zero");
+ }
+ }
+
+ run(getConf(), input, seeds, output, measure, outType, maxDistance);
+ return 0;
+ }
+
+ public static void run(Configuration conf,
+ Path input,
+ Path seeds,
+ Path output,
+ DistanceMeasure measure, String outType)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ run(conf, input, seeds, output, measure, outType, null);
+ }
+
+ public static void run(Configuration conf,
+ Path input,
+ Path seeds,
+ Path output,
+ DistanceMeasure measure, String outType, Double maxDistance)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ if (maxDistance != null) {
+ conf.set(MAX_DISTANCE, String.valueOf(maxDistance));
+ }
+ conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
+ conf.set(SEEDS_PATH_KEY, seeds.toString());
+ Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + " input: " + input);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ if ("pw".equalsIgnoreCase(outType)) {
+ job.setMapOutputKeyClass(StringTuple.class);
+ job.setOutputKeyClass(StringTuple.class);
+ job.setMapOutputValueClass(DoubleWritable.class);
+ job.setOutputValueClass(DoubleWritable.class);
+ job.setMapperClass(VectorDistanceMapper.class);
+ } else if ("v".equalsIgnoreCase(outType)) {
+ job.setMapOutputKeyClass(Text.class);
+ job.setOutputKeyClass(Text.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setMapperClass(VectorDistanceInvertedMapper.class);
+ } else {
+ throw new IllegalArgumentException("Invalid outType specified: " + outType);
+ }
+
+ job.setNumReduceTasks(0);
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setJarByClass(VectorDistanceSimilarityJob.class);
+ HadoopUtil.delete(conf, output);
+ if (!job.waitForCompletion(true)) {
+ throw new IllegalStateException("VectorDistance Similarity failed processing " + seeds);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java
new file mode 100644
index 0000000..ecd0d94
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import org.apache.mahout.math.Vector;
+
+public class MutableElement implements Vector.Element {
+
+ private int index;
+ private double value;
+
+ MutableElement(int index, double value) {
+ this.index = index;
+ this.value = value;
+ }
+
+ @Override
+ public double get() {
+ return value;
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ public void setIndex(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public void set(double value) {
+ this.value = value;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java
new file mode 100644
index 0000000..fb28821
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java
@@ -0,0 +1,562 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Ints;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.mapreduce.VectorSumCombiner;
+import org.apache.mahout.common.mapreduce.VectorSumReducer;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class RowSimilarityJob extends AbstractJob {
+
+ public static final double NO_THRESHOLD = Double.MIN_VALUE;
+ public static final long NO_FIXED_RANDOM_SEED = Long.MIN_VALUE;
+
+ private static final String SIMILARITY_CLASSNAME = RowSimilarityJob.class + ".distributedSimilarityClassname";
+ private static final String NUMBER_OF_COLUMNS = RowSimilarityJob.class + ".numberOfColumns";
+ private static final String MAX_SIMILARITIES_PER_ROW = RowSimilarityJob.class + ".maxSimilaritiesPerRow";
+ private static final String EXCLUDE_SELF_SIMILARITY = RowSimilarityJob.class + ".excludeSelfSimilarity";
+
+ private static final String THRESHOLD = RowSimilarityJob.class + ".threshold";
+ private static final String NORMS_PATH = RowSimilarityJob.class + ".normsPath";
+ private static final String MAXVALUES_PATH = RowSimilarityJob.class + ".maxWeightsPath";
+
+ private static final String NUM_NON_ZERO_ENTRIES_PATH = RowSimilarityJob.class + ".nonZeroEntriesPath";
+ private static final int DEFAULT_MAX_SIMILARITIES_PER_ROW = 100;
+
+ private static final String OBSERVATIONS_PER_COLUMN_PATH = RowSimilarityJob.class + ".observationsPerColumnPath";
+
+ private static final String MAX_OBSERVATIONS_PER_ROW = RowSimilarityJob.class + ".maxObservationsPerRow";
+ private static final String MAX_OBSERVATIONS_PER_COLUMN = RowSimilarityJob.class + ".maxObservationsPerColumn";
+ private static final String RANDOM_SEED = RowSimilarityJob.class + ".randomSeed";
+
+ private static final int DEFAULT_MAX_OBSERVATIONS_PER_ROW = 500;
+ private static final int DEFAULT_MAX_OBSERVATIONS_PER_COLUMN = 500;
+
+ private static final int NORM_VECTOR_MARKER = Integer.MIN_VALUE;
+ private static final int MAXVALUE_VECTOR_MARKER = Integer.MIN_VALUE + 1;
+ private static final int NUM_NON_ZERO_ENTRIES_VECTOR_MARKER = Integer.MIN_VALUE + 2;
+
+ enum Counters { ROWS, USED_OBSERVATIONS, NEGLECTED_OBSERVATIONS, COOCCURRENCES, PRUNED_COOCCURRENCES }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new RowSimilarityJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption("numberOfColumns", "r", "Number of columns in the input matrix", false);
+ addOption("similarityClassname", "s", "Name of distributed similarity class to instantiate, alternatively use "
+ + "one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')');
+ addOption("maxSimilaritiesPerRow", "m", "Number of maximum similarities per row (default: "
+ + DEFAULT_MAX_SIMILARITIES_PER_ROW + ')', String.valueOf(DEFAULT_MAX_SIMILARITIES_PER_ROW));
+ addOption("excludeSelfSimilarity", "ess", "compute similarity of rows to themselves?", String.valueOf(false));
+ addOption("threshold", "tr", "discard row pairs with a similarity value below this", false);
+ addOption("maxObservationsPerRow", null, "sample rows down to this number of entries",
+ String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_ROW));
+ addOption("maxObservationsPerColumn", null, "sample columns down to this number of entries",
+ String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_COLUMN));
+ addOption("randomSeed", null, "use this seed for sampling", false);
+ addOption(DefaultOptionCreator.overwriteOption().create());
+
+ Map<String,List<String>> parsedArgs = parseArguments(args);
+ if (parsedArgs == null) {
+ return -1;
+ }
+
+ int numberOfColumns;
+
+ if (hasOption("numberOfColumns")) {
+ // Number of columns explicitly specified via CLI
+ numberOfColumns = Integer.parseInt(getOption("numberOfColumns"));
+ } else {
+ // else get the number of columns by determining the cardinality of a vector in the input matrix
+ numberOfColumns = getDimensions(getInputPath());
+ }
+
+ String similarityClassnameArg = getOption("similarityClassname");
+ String similarityClassname;
+ try {
+ similarityClassname = VectorSimilarityMeasures.valueOf(similarityClassnameArg).getClassname();
+ } catch (IllegalArgumentException iae) {
+ similarityClassname = similarityClassnameArg;
+ }
+
+ // Clear the output and temp paths if the overwrite option has been set
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ // Clear the temp path
+ HadoopUtil.delete(getConf(), getTempPath());
+ // Clear the output path
+ HadoopUtil.delete(getConf(), getOutputPath());
+ }
+
+ int maxSimilaritiesPerRow = Integer.parseInt(getOption("maxSimilaritiesPerRow"));
+ boolean excludeSelfSimilarity = Boolean.parseBoolean(getOption("excludeSelfSimilarity"));
+ double threshold = hasOption("threshold")
+ ? Double.parseDouble(getOption("threshold")) : NO_THRESHOLD;
+ long randomSeed = hasOption("randomSeed")
+ ? Long.parseLong(getOption("randomSeed")) : NO_FIXED_RANDOM_SEED;
+
+ int maxObservationsPerRow = Integer.parseInt(getOption("maxObservationsPerRow"));
+ int maxObservationsPerColumn = Integer.parseInt(getOption("maxObservationsPerColumn"));
+
+ Path weightsPath = getTempPath("weights");
+ Path normsPath = getTempPath("norms.bin");
+ Path numNonZeroEntriesPath = getTempPath("numNonZeroEntries.bin");
+ Path maxValuesPath = getTempPath("maxValues.bin");
+ Path pairwiseSimilarityPath = getTempPath("pairwiseSimilarity");
+
+ Path observationsPerColumnPath = getTempPath("observationsPerColumn.bin");
+
+ AtomicInteger currentPhase = new AtomicInteger();
+
+ Job countObservations = prepareJob(getInputPath(), getTempPath("notUsed"), CountObservationsMapper.class,
+ NullWritable.class, VectorWritable.class, SumObservationsReducer.class, NullWritable.class,
+ VectorWritable.class);
+ countObservations.setCombinerClass(VectorSumCombiner.class);
+ countObservations.getConfiguration().set(OBSERVATIONS_PER_COLUMN_PATH, observationsPerColumnPath.toString());
+ countObservations.setNumReduceTasks(1);
+ countObservations.waitForCompletion(true);
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ Job normsAndTranspose = prepareJob(getInputPath(), weightsPath, VectorNormMapper.class, IntWritable.class,
+ VectorWritable.class, MergeVectorsReducer.class, IntWritable.class, VectorWritable.class);
+ normsAndTranspose.setCombinerClass(MergeVectorsCombiner.class);
+ Configuration normsAndTransposeConf = normsAndTranspose.getConfiguration();
+ normsAndTransposeConf.set(THRESHOLD, String.valueOf(threshold));
+ normsAndTransposeConf.set(NORMS_PATH, normsPath.toString());
+ normsAndTransposeConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
+ normsAndTransposeConf.set(MAXVALUES_PATH, maxValuesPath.toString());
+ normsAndTransposeConf.set(SIMILARITY_CLASSNAME, similarityClassname);
+ normsAndTransposeConf.set(OBSERVATIONS_PER_COLUMN_PATH, observationsPerColumnPath.toString());
+ normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_ROW, String.valueOf(maxObservationsPerRow));
+ normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_COLUMN, String.valueOf(maxObservationsPerColumn));
+ normsAndTransposeConf.set(RANDOM_SEED, String.valueOf(randomSeed));
+
+ boolean succeeded = normsAndTranspose.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ Job pairwiseSimilarity = prepareJob(weightsPath, pairwiseSimilarityPath, CooccurrencesMapper.class,
+ IntWritable.class, VectorWritable.class, SimilarityReducer.class, IntWritable.class, VectorWritable.class);
+ pairwiseSimilarity.setCombinerClass(VectorSumReducer.class);
+ Configuration pairwiseConf = pairwiseSimilarity.getConfiguration();
+ pairwiseConf.set(THRESHOLD, String.valueOf(threshold));
+ pairwiseConf.set(NORMS_PATH, normsPath.toString());
+ pairwiseConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
+ pairwiseConf.set(MAXVALUES_PATH, maxValuesPath.toString());
+ pairwiseConf.set(SIMILARITY_CLASSNAME, similarityClassname);
+ pairwiseConf.setInt(NUMBER_OF_COLUMNS, numberOfColumns);
+ pairwiseConf.setBoolean(EXCLUDE_SELF_SIMILARITY, excludeSelfSimilarity);
+ boolean succeeded = pairwiseSimilarity.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ if (shouldRunNextPhase(parsedArgs, currentPhase)) {
+ Job asMatrix = prepareJob(pairwiseSimilarityPath, getOutputPath(), UnsymmetrifyMapper.class,
+ IntWritable.class, VectorWritable.class, MergeToTopKSimilaritiesReducer.class, IntWritable.class,
+ VectorWritable.class);
+ asMatrix.setCombinerClass(MergeToTopKSimilaritiesReducer.class);
+ asMatrix.getConfiguration().setInt(MAX_SIMILARITIES_PER_ROW, maxSimilaritiesPerRow);
+ boolean succeeded = asMatrix.waitForCompletion(true);
+ if (!succeeded) {
+ return -1;
+ }
+ }
+
+ return 0;
+ }
+
+ public static class CountObservationsMapper extends Mapper<IntWritable,VectorWritable,NullWritable,VectorWritable> {
+
+ private Vector columnCounts = new RandomAccessSparseVector(Integer.MAX_VALUE);
+
+ @Override
+ protected void map(IntWritable rowIndex, VectorWritable rowVectorWritable, Context ctx)
+ throws IOException, InterruptedException {
+
+ Vector row = rowVectorWritable.get();
+ for (Vector.Element elem : row.nonZeroes()) {
+ columnCounts.setQuick(elem.index(), columnCounts.getQuick(elem.index()) + 1);
+ }
+ }
+
+ @Override
+ protected void cleanup(Context ctx) throws IOException, InterruptedException {
+ ctx.write(NullWritable.get(), new VectorWritable(columnCounts));
+ }
+ }
+
+ public static class SumObservationsReducer extends Reducer<NullWritable,VectorWritable,NullWritable,VectorWritable> {
+ @Override
+ protected void reduce(NullWritable nullWritable, Iterable<VectorWritable> partialVectors, Context ctx)
+ throws IOException, InterruptedException {
+ Vector counts = Vectors.sum(partialVectors.iterator());
+ Vectors.write(counts, new Path(ctx.getConfiguration().get(OBSERVATIONS_PER_COLUMN_PATH)), ctx.getConfiguration());
+ }
+ }
+
+ public static class VectorNormMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private VectorSimilarityMeasure similarity;
+ private Vector norms;
+ private Vector nonZeroEntries;
+ private Vector maxValues;
+ private double threshold;
+
+ private OpenIntIntHashMap observationsPerColumn;
+ private int maxObservationsPerRow;
+ private int maxObservationsPerColumn;
+
+ private Random random;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+
+ Configuration conf = ctx.getConfiguration();
+
+ similarity = ClassUtils.instantiateAs(conf.get(SIMILARITY_CLASSNAME), VectorSimilarityMeasure.class);
+ norms = new RandomAccessSparseVector(Integer.MAX_VALUE);
+ nonZeroEntries = new RandomAccessSparseVector(Integer.MAX_VALUE);
+ maxValues = new RandomAccessSparseVector(Integer.MAX_VALUE);
+ threshold = Double.parseDouble(conf.get(THRESHOLD));
+
+ observationsPerColumn = Vectors.readAsIntMap(new Path(conf.get(OBSERVATIONS_PER_COLUMN_PATH)), conf);
+ maxObservationsPerRow = conf.getInt(MAX_OBSERVATIONS_PER_ROW, DEFAULT_MAX_OBSERVATIONS_PER_ROW);
+ maxObservationsPerColumn = conf.getInt(MAX_OBSERVATIONS_PER_COLUMN, DEFAULT_MAX_OBSERVATIONS_PER_COLUMN);
+
+ long seed = Long.parseLong(conf.get(RANDOM_SEED));
+ if (seed == NO_FIXED_RANDOM_SEED) {
+ random = RandomUtils.getRandom();
+ } else {
+ random = RandomUtils.getRandom(seed);
+ }
+ }
+
+ private Vector sampleDown(Vector rowVector, Context ctx) {
+
+ int observationsPerRow = rowVector.getNumNondefaultElements();
+ double rowSampleRate = (double) Math.min(maxObservationsPerRow, observationsPerRow) / (double) observationsPerRow;
+
+ Vector downsampledRow = rowVector.like();
+ long usedObservations = 0;
+ long neglectedObservations = 0;
+
+ for (Vector.Element elem : rowVector.nonZeroes()) {
+
+ int columnCount = observationsPerColumn.get(elem.index());
+ double columnSampleRate = (double) Math.min(maxObservationsPerColumn, columnCount) / (double) columnCount;
+
+ if (random.nextDouble() <= Math.min(rowSampleRate, columnSampleRate)) {
+ downsampledRow.setQuick(elem.index(), elem.get());
+ usedObservations++;
+ } else {
+ neglectedObservations++;
+ }
+
+ }
+
+ ctx.getCounter(Counters.USED_OBSERVATIONS).increment(usedObservations);
+ ctx.getCounter(Counters.NEGLECTED_OBSERVATIONS).increment(neglectedObservations);
+
+ return downsampledRow;
+ }
+
+ @Override
+ protected void map(IntWritable row, VectorWritable vectorWritable, Context ctx)
+ throws IOException, InterruptedException {
+
+ Vector sampledRowVector = sampleDown(vectorWritable.get(), ctx);
+
+ Vector rowVector = similarity.normalize(sampledRowVector);
+
+ int numNonZeroEntries = 0;
+ double maxValue = Double.MIN_VALUE;
+
+ for (Vector.Element element : rowVector.nonZeroes()) {
+ RandomAccessSparseVector partialColumnVector = new RandomAccessSparseVector(Integer.MAX_VALUE);
+ partialColumnVector.setQuick(row.get(), element.get());
+ ctx.write(new IntWritable(element.index()), new VectorWritable(partialColumnVector));
+
+ numNonZeroEntries++;
+ if (maxValue < element.get()) {
+ maxValue = element.get();
+ }
+ }
+
+ if (threshold != NO_THRESHOLD) {
+ nonZeroEntries.setQuick(row.get(), numNonZeroEntries);
+ maxValues.setQuick(row.get(), maxValue);
+ }
+ norms.setQuick(row.get(), similarity.norm(rowVector));
+
+ ctx.getCounter(Counters.ROWS).increment(1);
+ }
+
+ @Override
+ protected void cleanup(Context ctx) throws IOException, InterruptedException {
+ ctx.write(new IntWritable(NORM_VECTOR_MARKER), new VectorWritable(norms));
+ ctx.write(new IntWritable(NUM_NON_ZERO_ENTRIES_VECTOR_MARKER), new VectorWritable(nonZeroEntries));
+ ctx.write(new IntWritable(MAXVALUE_VECTOR_MARKER), new VectorWritable(maxValues));
+ }
+ }
+
+ private static class MergeVectorsCombiner extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+ @Override
+ protected void reduce(IntWritable row, Iterable<VectorWritable> partialVectors, Context ctx)
+ throws IOException, InterruptedException {
+ ctx.write(row, new VectorWritable(Vectors.merge(partialVectors)));
+ }
+ }
+
+ public static class MergeVectorsReducer extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private Path normsPath;
+ private Path numNonZeroEntriesPath;
+ private Path maxValuesPath;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ normsPath = new Path(ctx.getConfiguration().get(NORMS_PATH));
+ numNonZeroEntriesPath = new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH));
+ maxValuesPath = new Path(ctx.getConfiguration().get(MAXVALUES_PATH));
+ }
+
+ @Override
+ protected void reduce(IntWritable row, Iterable<VectorWritable> partialVectors, Context ctx)
+ throws IOException, InterruptedException {
+ Vector partialVector = Vectors.merge(partialVectors);
+
+ if (row.get() == NORM_VECTOR_MARKER) {
+ Vectors.write(partialVector, normsPath, ctx.getConfiguration());
+ } else if (row.get() == MAXVALUE_VECTOR_MARKER) {
+ Vectors.write(partialVector, maxValuesPath, ctx.getConfiguration());
+ } else if (row.get() == NUM_NON_ZERO_ENTRIES_VECTOR_MARKER) {
+ Vectors.write(partialVector, numNonZeroEntriesPath, ctx.getConfiguration(), true);
+ } else {
+ ctx.write(row, new VectorWritable(partialVector));
+ }
+ }
+ }
+
+
+ public static class CooccurrencesMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private VectorSimilarityMeasure similarity;
+
+ private OpenIntIntHashMap numNonZeroEntries;
+ private Vector maxValues;
+ private double threshold;
+
+ private static final Comparator<Vector.Element> BY_INDEX = new Comparator<Vector.Element>() {
+ @Override
+ public int compare(Vector.Element one, Vector.Element two) {
+ return Ints.compare(one.index(), two.index());
+ }
+ };
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
+ VectorSimilarityMeasure.class);
+ numNonZeroEntries = Vectors.readAsIntMap(new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH)),
+ ctx.getConfiguration());
+ maxValues = Vectors.read(new Path(ctx.getConfiguration().get(MAXVALUES_PATH)), ctx.getConfiguration());
+ threshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
+ }
+
+ private boolean consider(Vector.Element occurrenceA, Vector.Element occurrenceB) {
+ int numNonZeroEntriesA = numNonZeroEntries.get(occurrenceA.index());
+ int numNonZeroEntriesB = numNonZeroEntries.get(occurrenceB.index());
+
+ double maxValueA = maxValues.get(occurrenceA.index());
+ double maxValueB = maxValues.get(occurrenceB.index());
+
+ return similarity.consider(numNonZeroEntriesA, numNonZeroEntriesB, maxValueA, maxValueB, threshold);
+ }
+
+ @Override
+ protected void map(IntWritable column, VectorWritable occurrenceVector, Context ctx)
+ throws IOException, InterruptedException {
+ Vector.Element[] occurrences = Vectors.toArray(occurrenceVector);
+ Arrays.sort(occurrences, BY_INDEX);
+
+ int cooccurrences = 0;
+ int prunedCooccurrences = 0;
+ for (int n = 0; n < occurrences.length; n++) {
+ Vector.Element occurrenceA = occurrences[n];
+ Vector dots = new RandomAccessSparseVector(Integer.MAX_VALUE);
+ for (int m = n; m < occurrences.length; m++) {
+ Vector.Element occurrenceB = occurrences[m];
+ if (threshold == NO_THRESHOLD || consider(occurrenceA, occurrenceB)) {
+ dots.setQuick(occurrenceB.index(), similarity.aggregate(occurrenceA.get(), occurrenceB.get()));
+ cooccurrences++;
+ } else {
+ prunedCooccurrences++;
+ }
+ }
+ ctx.write(new IntWritable(occurrenceA.index()), new VectorWritable(dots));
+ }
+ ctx.getCounter(Counters.COOCCURRENCES).increment(cooccurrences);
+ ctx.getCounter(Counters.PRUNED_COOCCURRENCES).increment(prunedCooccurrences);
+ }
+ }
+
+
+ public static class SimilarityReducer extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private VectorSimilarityMeasure similarity;
+ private int numberOfColumns;
+ private boolean excludeSelfSimilarity;
+ private Vector norms;
+ private double treshold;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
+ VectorSimilarityMeasure.class);
+ numberOfColumns = ctx.getConfiguration().getInt(NUMBER_OF_COLUMNS, -1);
+ Preconditions.checkArgument(numberOfColumns > 0, "Number of columns must be greater then 0! But numberOfColumns = " + numberOfColumns);
+ excludeSelfSimilarity = ctx.getConfiguration().getBoolean(EXCLUDE_SELF_SIMILARITY, false);
+ norms = Vectors.read(new Path(ctx.getConfiguration().get(NORMS_PATH)), ctx.getConfiguration());
+ treshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
+ }
+
+ @Override
+ protected void reduce(IntWritable row, Iterable<VectorWritable> partialDots, Context ctx)
+ throws IOException, InterruptedException {
+ Iterator<VectorWritable> partialDotsIterator = partialDots.iterator();
+ Vector dots = partialDotsIterator.next().get();
+ while (partialDotsIterator.hasNext()) {
+ Vector toAdd = partialDotsIterator.next().get();
+ for (Element nonZeroElement : toAdd.nonZeroes()) {
+ dots.setQuick(nonZeroElement.index(), dots.getQuick(nonZeroElement.index()) + nonZeroElement.get());
+ }
+ }
+
+ Vector similarities = dots.like();
+ double normA = norms.getQuick(row.get());
+ for (Element b : dots.nonZeroes()) {
+ double similarityValue = similarity.similarity(b.get(), normA, norms.getQuick(b.index()), numberOfColumns);
+ if (similarityValue >= treshold) {
+ similarities.set(b.index(), similarityValue);
+ }
+ }
+ if (excludeSelfSimilarity) {
+ similarities.setQuick(row.get(), 0);
+ }
+ ctx.write(row, new VectorWritable(similarities));
+ }
+ }
+
+ public static class UnsymmetrifyMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private int maxSimilaritiesPerRow;
+
+ @Override
+ protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
+ maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
+ Preconditions.checkArgument(maxSimilaritiesPerRow > 0, "Maximum number of similarities per row must be greater then 0!");
+ }
+
+ @Override
+ protected void map(IntWritable row, VectorWritable similaritiesWritable, Context ctx)
+ throws IOException, InterruptedException {
+ Vector similarities = similaritiesWritable.get();
+ // For performance, the creation of transposedPartial is moved out of the while loop and it is reused inside
+ Vector transposedPartial = new RandomAccessSparseVector(similarities.size(), 1);
+ TopElementsQueue topKQueue = new TopElementsQueue(maxSimilaritiesPerRow);
+ for (Element nonZeroElement : similarities.nonZeroes()) {
+ MutableElement top = topKQueue.top();
+ double candidateValue = nonZeroElement.get();
+ if (candidateValue > top.get()) {
+ top.setIndex(nonZeroElement.index());
+ top.set(candidateValue);
+ topKQueue.updateTop();
+ }
+
+ transposedPartial.setQuick(row.get(), candidateValue);
+ ctx.write(new IntWritable(nonZeroElement.index()), new VectorWritable(transposedPartial));
+ transposedPartial.setQuick(row.get(), 0.0);
+ }
+ Vector topKSimilarities = new RandomAccessSparseVector(similarities.size(), maxSimilaritiesPerRow);
+ for (Vector.Element topKSimilarity : topKQueue.getTopElements()) {
+ topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
+ }
+ ctx.write(row, new VectorWritable(topKSimilarities));
+ }
+ }
+
+ public static class MergeToTopKSimilaritiesReducer
+ extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ private int maxSimilaritiesPerRow;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
+ Preconditions.checkArgument(maxSimilaritiesPerRow > 0, "Maximum number of similarities per row must be greater then 0!");
+ }
+
+ @Override
+ protected void reduce(IntWritable row, Iterable<VectorWritable> partials, Context ctx)
+ throws IOException, InterruptedException {
+ Vector allSimilarities = Vectors.merge(partials);
+ Vector topKSimilarities = Vectors.topKElements(maxSimilaritiesPerRow, allSimilarities);
+ ctx.write(row, new VectorWritable(topKSimilarities));
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java
new file mode 100644
index 0000000..34135ac
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import com.google.common.collect.Lists;
+import org.apache.lucene.util.PriorityQueue;
+
+import java.util.Collections;
+import java.util.List;
+
+public class TopElementsQueue extends PriorityQueue<MutableElement> {
+
+ private final int maxSize;
+
+ private static final int SENTINEL_INDEX = Integer.MIN_VALUE;
+
+ public TopElementsQueue(int maxSize) {
+ super(maxSize);
+ this.maxSize = maxSize;
+ }
+
+ public List<MutableElement> getTopElements() {
+ List<MutableElement> topElements = Lists.newArrayListWithCapacity(maxSize);
+ while (size() > 0) {
+ MutableElement top = pop();
+ // filter out "sentinel" objects necessary for maintaining an efficient priority queue
+ if (top.index() != SENTINEL_INDEX) {
+ topElements.add(top);
+ }
+ }
+ Collections.reverse(topElements);
+ return topElements;
+ }
+
+ @Override
+ protected MutableElement getSentinelObject() {
+ return new MutableElement(SENTINEL_INDEX, Double.MIN_VALUE);
+ }
+
+ @Override
+ protected boolean lessThan(MutableElement e1, MutableElement e2) {
+ return e1.get() < e2.get();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java
new file mode 100644
index 0000000..66fb0ae
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import java.io.DataInput;
+import java.io.IOException;
+import java.util.Iterator;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.iterator.FixedSizeSamplingIterator;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+
+public final class Vectors {
+
+ private Vectors() {}
+
+ public static Vector maybeSample(Vector original, int sampleSize) {
+ if (original.getNumNondefaultElements() <= sampleSize) {
+ return original;
+ }
+ Vector sample = new RandomAccessSparseVector(original.size(), sampleSize);
+ Iterator<Element> sampledElements =
+ new FixedSizeSamplingIterator<>(sampleSize, original.nonZeroes().iterator());
+ while (sampledElements.hasNext()) {
+ Element elem = sampledElements.next();
+ sample.setQuick(elem.index(), elem.get());
+ }
+ return sample;
+ }
+
+ public static Vector topKElements(int k, Vector original) {
+ if (original.getNumNondefaultElements() <= k) {
+ return original;
+ }
+
+ TopElementsQueue topKQueue = new TopElementsQueue(k);
+ for (Element nonZeroElement : original.nonZeroes()) {
+ MutableElement top = topKQueue.top();
+ double candidateValue = nonZeroElement.get();
+ if (candidateValue > top.get()) {
+ top.setIndex(nonZeroElement.index());
+ top.set(candidateValue);
+ topKQueue.updateTop();
+ }
+ }
+
+ Vector topKSimilarities = new RandomAccessSparseVector(original.size(), k);
+ for (Vector.Element topKSimilarity : topKQueue.getTopElements()) {
+ topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
+ }
+ return topKSimilarities;
+ }
+
+ public static Vector merge(Iterable<VectorWritable> partialVectors) {
+ Iterator<VectorWritable> vectors = partialVectors.iterator();
+ Vector accumulator = vectors.next().get();
+ while (vectors.hasNext()) {
+ VectorWritable v = vectors.next();
+ if (v != null) {
+ for (Element nonZeroElement : v.get().nonZeroes()) {
+ accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
+ }
+ }
+ }
+ return accumulator;
+ }
+
+ public static Vector sum(Iterator<VectorWritable> vectors) {
+ Vector sum = vectors.next().get();
+ while (vectors.hasNext()) {
+ sum.assign(vectors.next().get(), Functions.PLUS);
+ }
+ return sum;
+ }
+
+ static class TemporaryElement implements Vector.Element {
+
+ private final int index;
+ private double value;
+
+ TemporaryElement(int index, double value) {
+ this.index = index;
+ this.value = value;
+ }
+
+ TemporaryElement(Vector.Element toClone) {
+ this(toClone.index(), toClone.get());
+ }
+
+ @Override
+ public double get() {
+ return value;
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ @Override
+ public void set(double value) {
+ this.value = value;
+ }
+ }
+
+ public static Vector.Element[] toArray(VectorWritable vectorWritable) {
+ Vector.Element[] elements = new Vector.Element[vectorWritable.get().getNumNondefaultElements()];
+ int k = 0;
+ for (Element nonZeroElement : vectorWritable.get().nonZeroes()) {
+ elements[k++] = new TemporaryElement(nonZeroElement.index(), nonZeroElement.get());
+ }
+ return elements;
+ }
+
+ public static void write(Vector vector, Path path, Configuration conf) throws IOException {
+ write(vector, path, conf, false);
+ }
+
+ public static void write(Vector vector, Path path, Configuration conf, boolean laxPrecision) throws IOException {
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ FSDataOutputStream out = fs.create(path);
+ try {
+ VectorWritable vectorWritable = new VectorWritable(vector);
+ vectorWritable.setWritesLaxPrecision(laxPrecision);
+ vectorWritable.write(out);
+ } finally {
+ Closeables.close(out, false);
+ }
+ }
+
+ public static OpenIntIntHashMap readAsIntMap(Path path, Configuration conf) throws IOException {
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ FSDataInputStream in = fs.open(path);
+ try {
+ return readAsIntMap(in);
+ } finally {
+ Closeables.close(in, true);
+ }
+ }
+
+ /* ugly optimization for loading sparse vectors containing ints only */
+ private static OpenIntIntHashMap readAsIntMap(DataInput in) throws IOException {
+ int flags = in.readByte();
+ Preconditions.checkArgument(flags >> VectorWritable.NUM_FLAGS == 0,
+ "Unknown flags set: %d", Integer.toString(flags, 2));
+ boolean dense = (flags & VectorWritable.FLAG_DENSE) != 0;
+ boolean sequential = (flags & VectorWritable.FLAG_SEQUENTIAL) != 0;
+ boolean laxPrecision = (flags & VectorWritable.FLAG_LAX_PRECISION) != 0;
+ Preconditions.checkState(!dense && !sequential, "Only for reading sparse vectors!");
+
+ Varint.readUnsignedVarInt(in);
+
+ OpenIntIntHashMap values = new OpenIntIntHashMap();
+ int numNonDefaultElements = Varint.readUnsignedVarInt(in);
+ for (int i = 0; i < numNonDefaultElements; i++) {
+ int index = Varint.readUnsignedVarInt(in);
+ double value = laxPrecision ? in.readFloat() : in.readDouble();
+ values.put(index, (int) value);
+ }
+ return values;
+ }
+
+ public static Vector read(Path path, Configuration conf) throws IOException {
+ FileSystem fs = FileSystem.get(path.toUri(), conf);
+ FSDataInputStream in = fs.open(path);
+ try {
+ return VectorWritable.readVector(in);
+ } finally {
+ Closeables.close(in, true);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java
new file mode 100644
index 0000000..0435d84
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+public class CityBlockSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ return 1.0 / (1.0 + normA + normB - 2 * dots);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java
new file mode 100644
index 0000000..61d071f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+public class CooccurrenceCountSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ return dots;
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return numNonZeroEntriesA >= threshold && numNonZeroEntriesB >= threshold;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java
new file mode 100644
index 0000000..3f4946b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public class CosineSimilarity implements VectorSimilarityMeasure {
+
+ @Override
+ public Vector normalize(Vector vector) {
+ return vector.normalize();
+ }
+
+ @Override
+ public double norm(Vector vector) {
+ return VectorSimilarityMeasure.NO_NORM;
+ }
+
+ @Override
+ public double aggregate(double valueA, double nonZeroValueB) {
+ return valueA * nonZeroValueB;
+ }
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ return dots;
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return numNonZeroEntriesB >= threshold / maxValueA
+ && numNonZeroEntriesA >= threshold / maxValueB;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java
new file mode 100644
index 0000000..105df2b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public abstract class CountbasedMeasure implements VectorSimilarityMeasure {
+
+ @Override
+ public Vector normalize(Vector vector) {
+ return vector;
+ }
+
+ @Override
+ public double norm(Vector vector) {
+ return vector.norm(0);
+ }
+
+ @Override
+ public double aggregate(double valueA, double nonZeroValueB) {
+ return 1;
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java
new file mode 100644
index 0000000..e61c3eb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public class EuclideanDistanceSimilarity implements VectorSimilarityMeasure {
+
+ @Override
+ public Vector normalize(Vector vector) {
+ return vector;
+ }
+
+ @Override
+ public double norm(Vector vector) {
+ double norm = 0;
+ for (Vector.Element e : vector.nonZeroes()) {
+ double value = e.get();
+ norm += value * value;
+ }
+ return norm;
+ }
+
+ @Override
+ public double aggregate(double valueA, double nonZeroValueB) {
+ return valueA * nonZeroValueB;
+ }
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ // Arg can't be negative in theory, but can in practice due to rounding, so cap it.
+ // Also note that normA / normB are actually the squares of the norms.
+ double euclideanDistance = Math.sqrt(Math.max(0.0, normA - 2 * dots + normB));
+ return 1.0 / (1.0 + euclideanDistance);
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java
new file mode 100644
index 0000000..7544b5d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.stats.LogLikelihood;
+
+public class LoglikelihoodSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double summedAggregations, double normA, double normB, int numberOfColumns) {
+ double logLikelihood =
+ LogLikelihood.logLikelihoodRatio((long) summedAggregations,
+ (long) (normB - summedAggregations),
+ (long) (normA - summedAggregations),
+ (long) (numberOfColumns - normA - normB + summedAggregations));
+ return 1.0 - 1.0 / (1.0 + logLikelihood);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java
new file mode 100644
index 0000000..c650d8f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public class PearsonCorrelationSimilarity extends CosineSimilarity {
+
+ @Override
+ public Vector normalize(Vector vector) {
+ if (vector.getNumNondefaultElements() == 0) {
+ return vector;
+ }
+
+ // center non-zero elements
+ double average = vector.norm(1) / vector.getNumNonZeroElements();
+ for (Vector.Element e : vector.nonZeroes()) {
+ e.set(e.get() - average);
+ }
+ return super.normalize(vector);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java
new file mode 100644
index 0000000..e000579
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+public class TanimotoCoefficientSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ // Return 0 even when dots == 0 since this will cause it to be ignored -- not NaN
+ return dots / (normA + normB - dots);
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return numNonZeroEntriesA >= numNonZeroEntriesB * threshold
+ && numNonZeroEntriesB >= numNonZeroEntriesA * threshold;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java
new file mode 100644
index 0000000..77125c2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public interface VectorSimilarityMeasure {
+
+ double NO_NORM = 0.0;
+
+ Vector normalize(Vector vector);
+ double norm(Vector vector);
+ double aggregate(double nonZeroValueA, double nonZeroValueB);
+ double similarity(double summedAggregations, double normA, double normB, int numberOfColumns);
+ boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java
new file mode 100644
index 0000000..9d1160e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import java.util.Arrays;
+
+public enum VectorSimilarityMeasures {
+
+ SIMILARITY_COOCCURRENCE(CooccurrenceCountSimilarity.class),
+ SIMILARITY_LOGLIKELIHOOD(LoglikelihoodSimilarity.class),
+ SIMILARITY_TANIMOTO_COEFFICIENT(TanimotoCoefficientSimilarity.class),
+ SIMILARITY_CITY_BLOCK(CityBlockSimilarity.class),
+ SIMILARITY_COSINE(CosineSimilarity.class),
+ SIMILARITY_PEARSON_CORRELATION(PearsonCorrelationSimilarity.class),
+ SIMILARITY_EUCLIDEAN_DISTANCE(EuclideanDistanceSimilarity.class);
+
+ private final Class<? extends VectorSimilarityMeasure> implementingClass;
+
+ VectorSimilarityMeasures(Class<? extends VectorSimilarityMeasure> impl) {
+ this.implementingClass = impl;
+ }
+
+ public String getClassname() {
+ return implementingClass.getName();
+ }
+
+ public static String list() {
+ return Arrays.toString(values());
+ }
+
+}
[13/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java
new file mode 100644
index 0000000..4c63333
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountCombiner.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.term;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+
+import java.io.IOException;
+
+/**
+ * @see TermCountReducer
+ */
+public class TermCountCombiner extends Reducer<Text, LongWritable, Text, LongWritable> {
+
+ @Override
+ protected void reduce(Text key, Iterable<LongWritable> values, Context context)
+ throws IOException, InterruptedException {
+ long sum = 0;
+ for (LongWritable value : values) {
+ sum += value.get();
+ }
+ context.write(key, new LongWritable(sum));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java
new file mode 100644
index 0000000..9af3d57
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountMapper.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.term;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.math.function.ObjectLongProcedure;
+import org.apache.mahout.math.map.OpenObjectLongHashMap;
+
+/**
+ * TextVectorizer Term Count Mapper. Tokenizes a text document and outputs the count of the words
+ */
+public class TermCountMapper extends Mapper<Text, StringTuple, Text, LongWritable> {
+
+ @Override
+ protected void map(Text key, StringTuple value, final Context context) throws IOException, InterruptedException {
+ OpenObjectLongHashMap<String> wordCount = new OpenObjectLongHashMap<>();
+ for (String word : value.getEntries()) {
+ if (wordCount.containsKey(word)) {
+ wordCount.put(word, wordCount.get(word) + 1);
+ } else {
+ wordCount.put(word, 1);
+ }
+ }
+ wordCount.forEachPair(new ObjectLongProcedure<String>() {
+ @Override
+ public boolean apply(String first, long second) {
+ try {
+ context.write(new Text(first), new LongWritable(second));
+ } catch (IOException e) {
+ context.getCounter("Exception", "Output IO Exception").increment(1);
+ } catch (InterruptedException e) {
+ context.getCounter("Exception", "Interrupted Exception").increment(1);
+ }
+ return true;
+ }
+ });
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java
new file mode 100644
index 0000000..388bfc2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermCountReducer.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.term;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.vectorizer.DictionaryVectorizer;
+
+/**
+ * This accumulates all the words and the weights and sums them up.
+ *
+ * @see TermCountCombiner
+ */
+public class TermCountReducer extends Reducer<Text, LongWritable, Text, LongWritable> {
+
+ private int minSupport;
+
+ @Override
+ protected void reduce(Text key, Iterable<LongWritable> values, Context context)
+ throws IOException, InterruptedException {
+ long sum = 0;
+ for (LongWritable value : values) {
+ sum += value.get();
+ }
+ if (sum >= minSupport) {
+ context.write(key, new LongWritable(sum));
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ minSupport = context.getConfiguration().getInt(DictionaryVectorizer.MIN_SUPPORT,
+ DictionaryVectorizer.DEFAULT_MIN_SUPPORT);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java
new file mode 100644
index 0000000..30828bf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountMapper.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.term;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * TextVectorizer Document Frequency Count Mapper. Outputs 1 for each feature
+ */
+public class TermDocumentCountMapper extends Mapper<WritableComparable<?>, VectorWritable, IntWritable, LongWritable> {
+
+ private static final LongWritable ONE = new LongWritable(1);
+
+ private static final IntWritable TOTAL_COUNT = new IntWritable(-1);
+
+ private final IntWritable out = new IntWritable();
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ Vector vector = value.get();
+ for (Vector.Element e : vector.nonZeroes()) {
+ out.set(e.index());
+ context.write(out, ONE);
+ }
+ context.write(TOTAL_COUNT, ONE);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java
new file mode 100644
index 0000000..c815692
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/term/TermDocumentCountReducer.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.term;
+
+import java.io.IOException;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+
+/**
+ * Can also be used as a local Combiner. This accumulates all the features and the weights and sums them up.
+ */
+public class TermDocumentCountReducer extends Reducer<IntWritable, LongWritable, IntWritable, LongWritable> {
+
+ @Override
+ protected void reduce(IntWritable key, Iterable<LongWritable> values, Context context)
+ throws IOException, InterruptedException {
+ long sum = 0;
+ for (LongWritable value : values) {
+ sum += value.get();
+ }
+ context.write(key, new LongWritable(sum));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java b/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java
new file mode 100644
index 0000000..5f9d666
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFConverter.java
@@ -0,0 +1,361 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.tfidf;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.common.PartialVectorMerger;
+import org.apache.mahout.vectorizer.term.TermDocumentCountMapper;
+import org.apache.mahout.vectorizer.term.TermDocumentCountReducer;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * This class converts a set of input vectors with term frequencies to TfIdf vectors. The Sequence file input
+ * should have a {@link org.apache.hadoop.io.WritableComparable} key containing and a
+ * {@link VectorWritable} value containing the
+ * term frequency vector. This is conversion class uses multiple map/reduces to convert the vectors to TfIdf
+ * format
+ *
+ */
+public final class TFIDFConverter {
+
+ public static final String VECTOR_COUNT = "vector.count";
+ public static final String FEATURE_COUNT = "feature.count";
+ public static final String MIN_DF = "min.df";
+ public static final String MAX_DF = "max.df";
+ //public static final String TFIDF_OUTPUT_FOLDER = "tfidf";
+
+ private static final String DOCUMENT_VECTOR_OUTPUT_FOLDER = "tfidf-vectors";
+ public static final String FREQUENCY_FILE = "frequency.file-";
+ private static final int MAX_CHUNKSIZE = 10000;
+ private static final int MIN_CHUNKSIZE = 100;
+ private static final String OUTPUT_FILES_PATTERN = "part-*";
+ private static final int SEQUENCEFILE_BYTE_OVERHEAD = 45;
+ private static final String VECTOR_OUTPUT_FOLDER = "partial-vectors-";
+ public static final String WORDCOUNT_OUTPUT_FOLDER = "df-count";
+
+ /**
+ * Cannot be initialized. Use the static functions
+ */
+ private TFIDFConverter() {}
+
+ /**
+ * Create Term Frequency-Inverse Document Frequency (Tf-Idf) Vectors from the input set of vectors in
+ * {@link SequenceFile} format. This job uses a fixed limit on the maximum memory used by the feature chunk
+ * per node thereby splitting the process across multiple map/reduces.
+ * Before using this method calculateDF should be called
+ *
+ * @param input
+ * input directory of the vectors in {@link SequenceFile} format
+ * @param output
+ * output directory where {@link org.apache.mahout.math.RandomAccessSparseVector}'s of the document
+ * are generated
+ * @param datasetFeatures
+ * Document frequencies information calculated by calculateDF
+ * @param minDf
+ * The minimum document frequency. Default 1
+ * @param maxDF
+ * The max percentage of vectors for the DF. Can be used to remove really high frequency features.
+ * Expressed as an integer between 0 and 100. Default 99
+ * @param numReducers
+ * The number of reducers to spawn. This also affects the possible parallelism since each reducer
+ * will typically produce a single output file containing tf-idf vectors for a subset of the
+ * documents in the corpus.
+ */
+ public static void processTfIdf(Path input,
+ Path output,
+ Configuration baseConf,
+ Pair<Long[], List<Path>> datasetFeatures,
+ int minDf,
+ long maxDF,
+ float normPower,
+ boolean logNormalize,
+ boolean sequentialAccessOutput,
+ boolean namedVector,
+ int numReducers) throws IOException, InterruptedException, ClassNotFoundException {
+ Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING || normPower >= 0,
+ "If specified normPower must be nonnegative", normPower);
+ Preconditions.checkArgument(normPower == PartialVectorMerger.NO_NORMALIZING
+ || (normPower > 1 && !Double.isInfinite(normPower))
+ || !logNormalize,
+ "normPower must be > 1 and not infinite if log normalization is chosen", normPower);
+
+ int partialVectorIndex = 0;
+ List<Path> partialVectorPaths = Lists.newArrayList();
+ List<Path> dictionaryChunks = datasetFeatures.getSecond();
+ for (Path dictionaryChunk : dictionaryChunks) {
+ Path partialVectorOutputPath = new Path(output, VECTOR_OUTPUT_FOLDER + partialVectorIndex++);
+ partialVectorPaths.add(partialVectorOutputPath);
+ makePartialVectors(input,
+ baseConf,
+ datasetFeatures.getFirst()[0],
+ datasetFeatures.getFirst()[1],
+ minDf,
+ maxDF,
+ dictionaryChunk,
+ partialVectorOutputPath,
+ sequentialAccessOutput,
+ namedVector);
+ }
+
+ Configuration conf = new Configuration(baseConf);
+
+ Path outputDir = new Path(output, DOCUMENT_VECTOR_OUTPUT_FOLDER);
+
+ PartialVectorMerger.mergePartialVectors(partialVectorPaths,
+ outputDir,
+ baseConf,
+ normPower,
+ logNormalize,
+ datasetFeatures.getFirst()[0].intValue(),
+ sequentialAccessOutput,
+ namedVector,
+ numReducers);
+ HadoopUtil.delete(conf, partialVectorPaths);
+
+ }
+
+ /**
+ * Calculates the document frequencies of all terms from the input set of vectors in
+ * {@link SequenceFile} format. This job uses a fixed limit on the maximum memory used by the feature chunk
+ * per node thereby splitting the process across multiple map/reduces.
+ *
+ * @param input
+ * input directory of the vectors in {@link SequenceFile} format
+ * @param output
+ * output directory where document frequencies will be stored
+ * @param chunkSizeInMegabytes
+ * the size in MB of the feature => id chunk to be kept in memory at each node during Map/Reduce
+ * stage. Its recommended you calculated this based on the number of cores and the free memory
+ * available to you per node. Say, you have 2 cores and around 1GB extra memory to spare we
+ * recommend you use a split size of around 400-500MB so that two simultaneous reducers can create
+ * partial vectors without thrashing the system due to increased swapping
+ */
+ public static Pair<Long[],List<Path>> calculateDF(Path input,
+ Path output,
+ Configuration baseConf,
+ int chunkSizeInMegabytes)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ if (chunkSizeInMegabytes < MIN_CHUNKSIZE) {
+ chunkSizeInMegabytes = MIN_CHUNKSIZE;
+ } else if (chunkSizeInMegabytes > MAX_CHUNKSIZE) { // 10GB
+ chunkSizeInMegabytes = MAX_CHUNKSIZE;
+ }
+
+ Path wordCountPath = new Path(output, WORDCOUNT_OUTPUT_FOLDER);
+
+ startDFCounting(input, wordCountPath, baseConf);
+
+ return createDictionaryChunks(wordCountPath, output, baseConf, chunkSizeInMegabytes);
+ }
+
+ /**
+ * Read the document frequency List which is built at the end of the DF Count Job. This will use constant
+ * memory and will run at the speed of your disk read
+ */
+ private static Pair<Long[], List<Path>> createDictionaryChunks(Path featureCountPath,
+ Path dictionaryPathBase,
+ Configuration baseConf,
+ int chunkSizeInMegabytes) throws IOException {
+ List<Path> chunkPaths = Lists.newArrayList();
+ Configuration conf = new Configuration(baseConf);
+
+ FileSystem fs = FileSystem.get(featureCountPath.toUri(), conf);
+
+ long chunkSizeLimit = chunkSizeInMegabytes * 1024L * 1024L;
+ int chunkIndex = 0;
+ Path chunkPath = new Path(dictionaryPathBase, FREQUENCY_FILE + chunkIndex);
+ chunkPaths.add(chunkPath);
+ SequenceFile.Writer freqWriter =
+ new SequenceFile.Writer(fs, conf, chunkPath, IntWritable.class, LongWritable.class);
+
+ try {
+ long currentChunkSize = 0;
+ long featureCount = 0;
+ long vectorCount = Long.MAX_VALUE;
+ Path filesPattern = new Path(featureCountPath, OUTPUT_FILES_PATTERN);
+ for (Pair<IntWritable,LongWritable> record
+ : new SequenceFileDirIterable<IntWritable,LongWritable>(filesPattern,
+ PathType.GLOB,
+ null,
+ null,
+ true,
+ conf)) {
+
+ if (currentChunkSize > chunkSizeLimit) {
+ Closeables.close(freqWriter, false);
+ chunkIndex++;
+
+ chunkPath = new Path(dictionaryPathBase, FREQUENCY_FILE + chunkIndex);
+ chunkPaths.add(chunkPath);
+
+ freqWriter = new SequenceFile.Writer(fs, conf, chunkPath, IntWritable.class, LongWritable.class);
+ currentChunkSize = 0;
+ }
+
+ int fieldSize = SEQUENCEFILE_BYTE_OVERHEAD + Integer.SIZE / 8 + Long.SIZE / 8;
+ currentChunkSize += fieldSize;
+ IntWritable key = record.getFirst();
+ LongWritable value = record.getSecond();
+ if (key.get() >= 0) {
+ freqWriter.append(key, value);
+ } else if (key.get() == -1) {
+ vectorCount = value.get();
+ }
+ featureCount = Math.max(key.get(), featureCount);
+
+ }
+ featureCount++;
+ Long[] counts = {featureCount, vectorCount};
+ return new Pair<>(counts, chunkPaths);
+ } finally {
+ Closeables.close(freqWriter, false);
+ }
+ }
+
+ /**
+ * Create a partial tfidf vector using a chunk of features from the input vectors. The input vectors has to
+ * be in the {@link SequenceFile} format
+ *
+ * @param input
+ * input directory of the vectors in {@link SequenceFile} format
+ * @param featureCount
+ * Number of unique features in the dataset
+ * @param vectorCount
+ * Number of vectors in the dataset
+ * @param minDf
+ * The minimum document frequency. Default 1
+ * @param maxDF
+ * The max percentage of vectors for the DF. Can be used to remove really high frequency features.
+ * Expressed as an integer between 0 and 100. Default 99
+ * @param dictionaryFilePath
+ * location of the chunk of features and the id's
+ * @param output
+ * output directory were the partial vectors have to be created
+ * @param sequentialAccess
+ * output vectors should be optimized for sequential access
+ * @param namedVector
+ * output vectors should be named, retaining key (doc id) as a label
+ */
+ private static void makePartialVectors(Path input,
+ Configuration baseConf,
+ Long featureCount,
+ Long vectorCount,
+ int minDf,
+ long maxDF,
+ Path dictionaryFilePath,
+ Path output,
+ boolean sequentialAccess,
+ boolean namedVector)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf values
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ conf.setLong(FEATURE_COUNT, featureCount);
+ conf.setLong(VECTOR_COUNT, vectorCount);
+ conf.setInt(MIN_DF, minDf);
+ conf.setLong(MAX_DF, maxDF);
+ conf.setBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, sequentialAccess);
+ conf.setBoolean(PartialVectorMerger.NAMED_VECTOR, namedVector);
+ DistributedCache.addCacheFile(dictionaryFilePath.toUri(), conf);
+
+ Job job = new Job(conf);
+ job.setJobName(": MakePartialVectors: input-folder: " + input + ", dictionary-file: "
+ + dictionaryFilePath.toString());
+ job.setJarByClass(TFIDFConverter.class);
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(VectorWritable.class);
+ FileInputFormat.setInputPaths(job, input);
+
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setMapperClass(Mapper.class);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setReducerClass(TFIDFPartialVectorReducer.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ /**
+ * Count the document frequencies of features in parallel using Map/Reduce. The input documents have to be
+ * in {@link SequenceFile} format
+ */
+ private static void startDFCounting(Path input, Path output, Configuration baseConf)
+ throws IOException, InterruptedException, ClassNotFoundException {
+
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf values
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+
+ Job job = new Job(conf);
+ job.setJobName("VectorTfIdf Document Frequency Count running over input: " + input);
+ job.setJarByClass(TFIDFConverter.class);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(LongWritable.class);
+
+ FileInputFormat.setInputPaths(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setMapperClass(TermDocumentCountMapper.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setCombinerClass(TermDocumentCountReducer.class);
+ job.setReducerClass(TermDocumentCountReducer.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java
new file mode 100644
index 0000000..1e71ed8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/tfidf/TFIDFPartialVectorReducer.java
@@ -0,0 +1,114 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.tfidf;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+import org.apache.mahout.vectorizer.TFIDF;
+import org.apache.mahout.vectorizer.common.PartialVectorMerger;
+
+/**
+ * Converts a document into a sparse vector
+ */
+public class TFIDFPartialVectorReducer extends
+ Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ private final OpenIntLongHashMap dictionary = new OpenIntLongHashMap();
+ private final TFIDF tfidf = new TFIDF();
+
+ private int minDf = 1;
+ private long maxDf = -1;
+ private long vectorCount = 1;
+ private long featureCount;
+ private boolean sequentialAccess;
+ private boolean namedVector;
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context)
+ throws IOException, InterruptedException {
+ Iterator<VectorWritable> it = values.iterator();
+ if (!it.hasNext()) {
+ return;
+ }
+ Vector value = it.next().get();
+ Vector vector = new RandomAccessSparseVector((int) featureCount, value.getNumNondefaultElements());
+ for (Vector.Element e : value.nonZeroes()) {
+ if (!dictionary.containsKey(e.index())) {
+ continue;
+ }
+ long df = dictionary.get(e.index());
+ if (maxDf > -1 && (100.0 * df) / vectorCount > maxDf) {
+ continue;
+ }
+ if (df < minDf) {
+ df = minDf;
+ }
+ vector.setQuick(e.index(), tfidf.calculate((int) e.get(), (int) df, (int) featureCount, (int) vectorCount));
+ }
+ if (sequentialAccess) {
+ vector = new SequentialAccessSparseVector(vector);
+ }
+
+ if (namedVector) {
+ vector = new NamedVector(vector, key.toString());
+ }
+
+ VectorWritable vectorWritable = new VectorWritable(vector);
+ context.write(key, vectorWritable);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+
+ vectorCount = conf.getLong(TFIDFConverter.VECTOR_COUNT, 1);
+ featureCount = conf.getLong(TFIDFConverter.FEATURE_COUNT, 1);
+ minDf = conf.getInt(TFIDFConverter.MIN_DF, 1);
+ maxDf = conf.getLong(TFIDFConverter.MAX_DF, -1);
+ sequentialAccess = conf.getBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, false);
+ namedVector = conf.getBoolean(PartialVectorMerger.NAMED_VECTOR, false);
+
+ URI[] localFiles = DistributedCache.getCacheFiles(conf);
+ Path dictionaryFile = HadoopUtil.findInCacheByPartOfFilename(TFIDFConverter.FREQUENCY_FILE, localFiles);
+ // key is feature, value is the document frequency
+ for (Pair<IntWritable,LongWritable> record
+ : new SequenceFileIterable<IntWritable,LongWritable>(dictionaryFile, true, conf)) {
+ dictionary.put(record.getFirst().get(), record.getSecond().get());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/resources/version
----------------------------------------------------------------------
diff --git a/mr/src/main/resources/version b/mr/src/main/resources/version
new file mode 100644
index 0000000..f2ab45c
--- /dev/null
+++ b/mr/src/main/resources/version
@@ -0,0 +1 @@
+${project.version}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java
new file mode 100644
index 0000000..c37bcd3
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/common/CommonTest.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.io.OutputStreamWriter;
+import java.io.PrintStream;
+import java.io.PrintWriter;
+
+/** <p>Tests common classes.</p> */
+public final class CommonTest extends TasteTestCase {
+
+ @Test
+ public void testTasteException() {
+ // Just make sure this all doesn't, ah, throw an exception
+ TasteException te1 = new TasteException();
+ TasteException te2 = new TasteException(te1);
+ TasteException te3 = new TasteException(te2.toString(), te2);
+ TasteException te4 = new TasteException(te3.toString());
+ te4.printStackTrace(new PrintStream(new ByteArrayOutputStream()));
+ te4.printStackTrace(new PrintWriter(new OutputStreamWriter(new ByteArrayOutputStream())));
+ }
+
+ @Test
+ public void testNSUException() {
+ // Just make sure this all doesn't, ah, throw an exception
+ TasteException te1 = new NoSuchUserException();
+ TasteException te4 = new NoSuchUserException(te1.toString());
+ te4.printStackTrace(new PrintStream(new ByteArrayOutputStream()));
+ te4.printStackTrace(new PrintWriter(new OutputStreamWriter(new ByteArrayOutputStream())));
+ }
+
+ @Test
+ public void testNSIException() {
+ // Just make sure this all doesn't, ah, throw an exception
+ TasteException te1 = new NoSuchItemException();
+ TasteException te4 = new NoSuchItemException(te1.toString());
+ te4.printStackTrace(new PrintStream(new ByteArrayOutputStream()));
+ te4.printStackTrace(new PrintWriter(new OutputStreamWriter(new ByteArrayOutputStream())));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java
new file mode 100644
index 0000000..b299b35
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TasteHadoopUtilsTest.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+/** <p>Tests {@link TasteHadoopUtils}.</p> */
+public class TasteHadoopUtilsTest extends TasteTestCase {
+
+ @Test
+ public void testWithinRange() {
+ assertTrue(TasteHadoopUtils.idToIndex(0) >= 0);
+ assertTrue(TasteHadoopUtils.idToIndex(0) < Integer.MAX_VALUE);
+
+ assertTrue(TasteHadoopUtils.idToIndex(1) >= 0);
+ assertTrue(TasteHadoopUtils.idToIndex(1) < Integer.MAX_VALUE);
+
+ assertTrue(TasteHadoopUtils.idToIndex(Long.MAX_VALUE) >= 0);
+ assertTrue(TasteHadoopUtils.idToIndex(Long.MAX_VALUE) < Integer.MAX_VALUE);
+
+ assertTrue(TasteHadoopUtils.idToIndex(Integer.MAX_VALUE) >= 0);
+ assertTrue(TasteHadoopUtils.idToIndex(Integer.MAX_VALUE) < Integer.MAX_VALUE);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java
new file mode 100644
index 0000000..9465def
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/TopItemsQueueTest.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+import java.util.List;
+
+public class TopItemsQueueTest extends TasteTestCase {
+
+ @Test
+ public void topK() {
+
+ float[] ratings = {0.5f, 0.6f, 0.7f, 2.0f, 0.0f};
+
+ List<RecommendedItem> topItems = findTop(ratings, 2);
+
+ assertEquals(2, topItems.size());
+ assertEquals(3L, topItems.get(0).getItemID());
+ assertEquals(2.0f, topItems.get(0).getValue(), MahoutTestCase.EPSILON);
+ assertEquals(2L, topItems.get(1).getItemID());
+ assertEquals(0.7f, topItems.get(1).getValue(), MahoutTestCase.EPSILON);
+ }
+
+ @Test
+ public void topKInputSmallerThanK() {
+
+ float[] ratings = {0.7f, 2.0f};
+
+ List<RecommendedItem> topItems = findTop(ratings, 3);
+
+ assertEquals(2, topItems.size());
+ assertEquals(1L, topItems.get(0).getItemID());
+ assertEquals(2.0f, topItems.get(0).getValue(), MahoutTestCase.EPSILON);
+ assertEquals(0L, topItems.get(1).getItemID());
+ assertEquals(0.7f, topItems.get(1).getValue(), MahoutTestCase.EPSILON);
+ }
+
+
+ private static List<RecommendedItem> findTop(float[] ratings, int k) {
+ TopItemsQueue queue = new TopItemsQueue(k);
+
+ for (int item = 0; item < ratings.length; item++) {
+ MutableRecommendedItem top = queue.top();
+ if (ratings[item] > top.getValue()) {
+ top.set(item, ratings[item]);
+ queue.updateTop();
+ }
+ }
+
+ return queue.getTopItems();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
new file mode 100644
index 0000000..9d37da2
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/als/ParallelALSFactorizationJobTest.java
@@ -0,0 +1,379 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.als;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.TasteHadoopUtils;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+import org.apache.mahout.math.map.OpenIntObjectHashMap;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.File;
+
+public class ParallelALSFactorizationJobTest extends TasteTestCase {
+
+ private static final Logger log = LoggerFactory.getLogger(ParallelALSFactorizationJobTest.class);
+
+ private File inputFile;
+ private File intermediateDir;
+ private File outputDir;
+ private File tmpDir;
+ private Configuration conf;
+
+ @Before
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ inputFile = getTestTempFile("prefs.txt");
+ intermediateDir = getTestTempDir("intermediate");
+ intermediateDir.delete();
+ outputDir = getTestTempDir("output");
+ outputDir.delete();
+ tmpDir = getTestTempDir("tmp");
+
+ conf = getConfiguration();
+ // reset as we run all tests in the same JVM
+ SharingMapper.reset();
+ }
+
+ @Test
+ public void completeJobToyExample() throws Exception {
+ explicitExample(1);
+ }
+
+ @Test
+ public void completeJobToyExampleMultithreaded() throws Exception {
+ explicitExample(2);
+ }
+
+ /**
+ * small integration test that runs the full job
+ *
+ * <pre>
+ *
+ * user-item-matrix
+ *
+ * burger hotdog berries icecream
+ * dog 5 5 2 -
+ * rabbit 2 - 3 5
+ * cow - 5 - 3
+ * donkey 3 - - 5
+ *
+ * </pre>
+ */
+ private void explicitExample(int numThreads) throws Exception {
+
+ Double na = Double.NaN;
+ Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 5.0, 5.0, 2.0, na }),
+ new DenseVector(new double[] { 2.0, na, 3.0, 5.0 }),
+ new DenseVector(new double[] { na, 5.0, na, 3.0 }),
+ new DenseVector(new double[] { 3.0, na, na, 5.0 }) });
+
+ writeLines(inputFile, preferencesAsText(preferences));
+
+ ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
+ alsFactorization.setConf(conf);
+
+ int numFeatures = 3;
+ int numIterations = 5;
+ double lambda = 0.065;
+
+ alsFactorization.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--tempDir", tmpDir.getAbsolutePath(), "--lambda", String.valueOf(lambda),
+ "--numFeatures", String.valueOf(numFeatures), "--numIterations", String.valueOf(numIterations),
+ "--numThreadsPerSolver", String.valueOf(numThreads) });
+
+ Matrix u = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000"),
+ preferences.numRows(), numFeatures);
+ Matrix m = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000"),
+ preferences.numCols(), numFeatures);
+
+ StringBuilder info = new StringBuilder();
+ info.append("\nA - users x items\n\n");
+ info.append(MathHelper.nice(preferences));
+ info.append("\nU - users x features\n\n");
+ info.append(MathHelper.nice(u));
+ info.append("\nM - items x features\n\n");
+ info.append(MathHelper.nice(m));
+ Matrix Ak = u.times(m.transpose());
+ info.append("\nAk - users x items\n\n");
+ info.append(MathHelper.nice(Ak));
+ info.append('\n');
+
+ log.info(info.toString());
+
+ RunningAverage avg = new FullRunningAverage();
+ for (MatrixSlice slice : preferences) {
+ for (Element e : slice.nonZeroes()) {
+ if (!Double.isNaN(e.get())) {
+ double pref = e.get();
+ double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index()));
+ double err = pref - estimate;
+ avg.addDatum(err * err);
+ log.info("Comparing preference of user [{}] towards item [{}], was [{}] estimate is [{}]",
+ slice.index(), e.index(), pref, estimate);
+ }
+ }
+ }
+ double rmse = Math.sqrt(avg.getAverage());
+ log.info("RMSE: {}", rmse);
+
+ assertTrue(rmse < 0.2);
+ }
+
+ @Test
+ public void completeJobImplicitToyExample() throws Exception {
+ implicitExample(1);
+ }
+
+ @Test
+ public void completeJobImplicitToyExampleMultithreaded() throws Exception {
+ implicitExample(2);
+ }
+
+ public void implicitExample(int numThreads) throws Exception {
+ Matrix observations = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 5.0, 5.0, 2.0, 0 }),
+ new DenseVector(new double[] { 2.0, 0, 3.0, 5.0 }),
+ new DenseVector(new double[] { 0, 5.0, 0, 3.0 }),
+ new DenseVector(new double[] { 3.0, 0, 0, 5.0 }) });
+
+ Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 1.0, 1.0, 1.0, 0 }),
+ new DenseVector(new double[] { 1.0, 0, 1.0, 1.0 }),
+ new DenseVector(new double[] { 0, 1.0, 0, 1.0 }),
+ new DenseVector(new double[] { 1.0, 0, 0, 1.0 }) });
+
+ writeLines(inputFile, preferencesAsText(observations));
+
+ ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
+ alsFactorization.setConf(conf);
+
+ int numFeatures = 3;
+ int numIterations = 5;
+ double lambda = 0.065;
+ double alpha = 20;
+
+ alsFactorization.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--tempDir", tmpDir.getAbsolutePath(), "--lambda", String.valueOf(lambda),
+ "--implicitFeedback", String.valueOf(true), "--alpha", String.valueOf(alpha),
+ "--numFeatures", String.valueOf(numFeatures), "--numIterations", String.valueOf(numIterations),
+ "--numThreadsPerSolver", String.valueOf(numThreads) });
+
+ Matrix u = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000"),
+ observations.numRows(), numFeatures);
+ Matrix m = MathHelper.readMatrix(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000"),
+ observations.numCols(), numFeatures);
+
+ StringBuilder info = new StringBuilder();
+ info.append("\nObservations - users x items\n");
+ info.append(MathHelper.nice(observations));
+ info.append("\nA - users x items\n\n");
+ info.append(MathHelper.nice(preferences));
+ info.append("\nU - users x features\n\n");
+ info.append(MathHelper.nice(u));
+ info.append("\nM - items x features\n\n");
+ info.append(MathHelper.nice(m));
+ Matrix Ak = u.times(m.transpose());
+ info.append("\nAk - users x items\n\n");
+ info.append(MathHelper.nice(Ak));
+ info.append('\n');
+
+ log.info(info.toString());
+
+ RunningAverage avg = new FullRunningAverage();
+ for (MatrixSlice slice : preferences) {
+ for (Element e : slice.nonZeroes()) {
+ if (!Double.isNaN(e.get())) {
+ double pref = e.get();
+ double estimate = u.viewRow(slice.index()).dot(m.viewRow(e.index()));
+ double confidence = 1 + alpha * observations.getQuick(slice.index(), e.index());
+ double err = confidence * (pref - estimate) * (pref - estimate);
+ avg.addDatum(err);
+ log.info("Comparing preference of user [{}] towards item [{}], was [{}] with confidence [{}] "
+ + "estimate is [{}]", slice.index(), e.index(), pref, confidence, estimate);
+ }
+ }
+ }
+ double rmse = Math.sqrt(avg.getAverage());
+ log.info("RMSE: {}", rmse);
+
+ assertTrue(rmse < 0.4);
+ }
+
+ @Test
+ public void exampleWithIDMapping() throws Exception {
+
+ String[] preferencesWithLongIDs = {
+ "5568227754922264005,-4758971626494767444,5.0",
+ "5568227754922264005,3688396615879561990,5.0",
+ "5568227754922264005,4594226737871995304,2.0",
+ "550945997885173934,-4758971626494767444,2.0",
+ "550945997885173934,4594226737871995304,3.0",
+ "550945997885173934,706816485922781596,5.0",
+ "2448095297482319463,3688396615879561990,5.0",
+ "2448095297482319463,706816485922781596,3.0",
+ "6839920411763636962,-4758971626494767444,3.0",
+ "6839920411763636962,706816485922781596,5.0" };
+
+ writeLines(inputFile, preferencesWithLongIDs);
+
+ ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
+ alsFactorization.setConf(conf);
+
+ int numFeatures = 3;
+ int numIterations = 5;
+ double lambda = 0.065;
+
+ alsFactorization.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "--tempDir", tmpDir.getAbsolutePath(), "--lambda", String.valueOf(lambda),
+ "--numFeatures", String.valueOf(numFeatures), "--numIterations", String.valueOf(numIterations),
+ "--numThreadsPerSolver", String.valueOf(1), "--usesLongIDs", String.valueOf(true) });
+
+
+ OpenIntLongHashMap userIDIndex =
+ TasteHadoopUtils.readIDIndexMap(outputDir.getAbsolutePath() + "/userIDIndex/part-r-00000", conf);
+ assertEquals(4, userIDIndex.size());
+
+ OpenIntLongHashMap itemIDIndex =
+ TasteHadoopUtils.readIDIndexMap(outputDir.getAbsolutePath() + "/itemIDIndex/part-r-00000", conf);
+ assertEquals(4, itemIDIndex.size());
+
+ OpenIntObjectHashMap<Vector> u =
+ MathHelper.readMatrixRows(conf, new Path(outputDir.getAbsolutePath(), "U/part-m-00000"));
+ OpenIntObjectHashMap<Vector> m =
+ MathHelper.readMatrixRows(conf, new Path(outputDir.getAbsolutePath(), "M/part-m-00000"));
+
+ assertEquals(4, u.size());
+ assertEquals(4, m.size());
+
+ RunningAverage avg = new FullRunningAverage();
+ for (String line : preferencesWithLongIDs) {
+ String[] tokens = TasteHadoopUtils.splitPrefTokens(line);
+ long userID = Long.parseLong(tokens[TasteHadoopUtils.USER_ID_POS]);
+ long itemID = Long.parseLong(tokens[TasteHadoopUtils.ITEM_ID_POS]);
+ double rating = Double.parseDouble(tokens[2]);
+
+ Vector userFeatures = u.get(TasteHadoopUtils.idToIndex(userID));
+ Vector itemFeatures = m.get(TasteHadoopUtils.idToIndex(itemID));
+
+ double estimate = userFeatures.dot(itemFeatures);
+
+ double err = rating - estimate;
+ avg.addDatum(err * err);
+ }
+
+ double rmse = Math.sqrt(avg.getAverage());
+ log.info("RMSE: {}", rmse);
+
+ assertTrue(rmse < 0.2);
+ }
+
+ protected static String preferencesAsText(Matrix preferences) {
+ StringBuilder prefsAsText = new StringBuilder();
+ String separator = "";
+ for (MatrixSlice slice : preferences) {
+ for (Element e : slice.nonZeroes()) {
+ if (!Double.isNaN(e.get())) {
+ prefsAsText.append(separator)
+ .append(slice.index()).append(',').append(e.index()).append(',').append(e.get());
+ separator = "\n";
+ }
+ }
+ }
+ System.out.println(prefsAsText.toString());
+ return prefsAsText.toString();
+ }
+
+ @Test
+ public void recommenderJobWithIDMapping() throws Exception {
+
+ String[] preferencesWithLongIDs = {
+ "5568227754922264005,-4758971626494767444,5.0",
+ "5568227754922264005,3688396615879561990,5.0",
+ "5568227754922264005,4594226737871995304,2.0",
+ "550945997885173934,-4758971626494767444,2.0",
+ "550945997885173934,4594226737871995304,3.0",
+ "550945997885173934,706816485922781596,5.0",
+ "2448095297482319463,3688396615879561990,5.0",
+ "2448095297482319463,706816485922781596,3.0",
+ "6839920411763636962,-4758971626494767444,3.0",
+ "6839920411763636962,706816485922781596,5.0" };
+
+ writeLines(inputFile, preferencesWithLongIDs);
+
+ ParallelALSFactorizationJob alsFactorization = new ParallelALSFactorizationJob();
+ alsFactorization.setConf(conf);
+
+ int numFeatures = 3;
+ int numIterations = 5;
+ double lambda = 0.065;
+
+ Configuration conf = getConfiguration();
+
+ int success = ToolRunner.run(alsFactorization, new String[] {
+ "-Dhadoop.tmp.dir=" + conf.get("hadoop.tmp.dir"),
+ "--input", inputFile.getAbsolutePath(),
+ "--output", intermediateDir.getAbsolutePath(),
+ "--tempDir", tmpDir.getAbsolutePath(),
+ "--lambda", String.valueOf(lambda),
+ "--numFeatures", String.valueOf(numFeatures),
+ "--numIterations", String.valueOf(numIterations),
+ "--numThreadsPerSolver", String.valueOf(1),
+ "--usesLongIDs", String.valueOf(true) });
+
+ assertEquals(0, success);
+
+ // reset as we run in the same JVM
+ SharingMapper.reset();
+
+ RecommenderJob recommender = new RecommenderJob();
+
+ success = ToolRunner.run(recommender, new String[] {
+ "-Dhadoop.tmp.dir=" + conf.get("hadoop.tmp.dir"),
+ "--input", intermediateDir.getAbsolutePath() + "/userRatings/",
+ "--userFeatures", intermediateDir.getAbsolutePath() + "/U/",
+ "--itemFeatures", intermediateDir.getAbsolutePath() + "/M/",
+ "--numRecommendations", String.valueOf(2),
+ "--maxRating", String.valueOf(5.0),
+ "--numThreads", String.valueOf(2),
+ "--usesLongIDs", String.valueOf(true),
+ "--userIDIndex", intermediateDir.getAbsolutePath() + "/userIDIndex/",
+ "--itemIDIndex", intermediateDir.getAbsolutePath() + "/itemIDIndex/",
+ "--output", outputDir.getAbsolutePath() });
+
+ assertEquals(0, success);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java
new file mode 100644
index 0000000..650ca98
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/hadoop/item/IDReaderTest.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.hadoop.item;
+
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+
+public class IDReaderTest extends TasteTestCase {
+
+ static final String USER_ITEM_FILTER_FIELD = "userItemFilter";
+
+ @Test
+ public void testUserItemFilter() throws Exception {
+ Configuration conf = getConfiguration();
+ IDReader idReader = new IDReader(conf);
+ Map<Long, FastIDSet> userItemFilter = Maps.newHashMap();
+
+ long user1 = 1;
+ long user2 = 2;
+
+ idReader.addUserAndItemIdToUserItemFilter(userItemFilter, user1, 100L);
+ idReader.addUserAndItemIdToUserItemFilter(userItemFilter, user1, 200L);
+ idReader.addUserAndItemIdToUserItemFilter(userItemFilter, user2, 300L);
+
+ FastIDSet userIds = IDReader.extractAllUserIdsFromUserItemFilter(userItemFilter);
+
+ assertEquals(2, userIds.size());
+ assertTrue(userIds.contains(user1));
+ assertTrue(userIds.contains(user1));
+
+ setField(idReader, USER_ITEM_FILTER_FIELD, userItemFilter);
+
+ FastIDSet itemsForUser1 = idReader.getItemsToRecommendForUser(user1);
+ assertEquals(2, itemsForUser1.size());
+ assertTrue(itemsForUser1.contains(100L));
+ assertTrue(itemsForUser1.contains(200L));
+
+ FastIDSet itemsForUser2 = idReader.getItemsToRecommendForUser(user2);
+ assertEquals(1, itemsForUser2.size());
+ assertTrue(itemsForUser2.contains(300L));
+
+ FastIDSet itemsForNonExistingUser = idReader.getItemsToRecommendForUser(3L);
+ assertTrue(itemsForNonExistingUser.isEmpty());
+ }
+
+}
[07/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java
new file mode 100644
index 0000000..76b1d3f
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/Datasets.java
@@ -0,0 +1,866 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.mlp;
+
+public class Datasets {
+
+ public static final String[] IRIS = new String[] {
+ "Sepal.Length,Sepal.Width,Petal.Length,Petal.Width,Species",
+ "5.1,3.5,1.4,0.2,setosa",
+ "4.9,3.0,1.4,0.2,setosa",
+ "4.7,3.2,1.3,0.2,setosa",
+ "4.6,3.1,1.5,0.2,setosa",
+ "5.0,3.6,1.4,0.2,setosa",
+ "5.4,3.9,1.7,0.4,setosa",
+ "4.6,3.4,1.4,0.3,setosa",
+ "5.0,3.4,1.5,0.2,setosa",
+ "4.4,2.9,1.4,0.2,setosa",
+ "4.9,3.1,1.5,0.1,setosa",
+ "5.4,3.7,1.5,0.2,setosa",
+ "4.8,3.4,1.6,0.2,setosa",
+ "4.8,3.0,1.4,0.1,setosa",
+ "4.3,3.0,1.1,0.1,setosa",
+ "5.8,4.0,1.2,0.2,setosa",
+ "5.7,4.4,1.5,0.4,setosa",
+ "5.4,3.9,1.3,0.4,setosa",
+ "5.1,3.5,1.4,0.3,setosa",
+ "5.7,3.8,1.7,0.3,setosa",
+ "5.1,3.8,1.5,0.3,setosa",
+ "5.4,3.4,1.7,0.2,setosa",
+ "5.1,3.7,1.5,0.4,setosa",
+ "4.6,3.6,1.0,0.2,setosa",
+ "5.1,3.3,1.7,0.5,setosa",
+ "4.8,3.4,1.9,0.2,setosa",
+ "5.0,3.0,1.6,0.2,setosa",
+ "5.0,3.4,1.6,0.4,setosa",
+ "5.2,3.5,1.5,0.2,setosa",
+ "5.2,3.4,1.4,0.2,setosa",
+ "4.7,3.2,1.6,0.2,setosa",
+ "4.8,3.1,1.6,0.2,setosa",
+ "5.4,3.4,1.5,0.4,setosa",
+ "5.2,4.1,1.5,0.1,setosa",
+ "5.5,4.2,1.4,0.2,setosa",
+ "4.9,3.1,1.5,0.2,setosa",
+ "5.0,3.2,1.2,0.2,setosa",
+ "5.5,3.5,1.3,0.2,setosa",
+ "4.9,3.6,1.4,0.1,setosa",
+ "4.4,3.0,1.3,0.2,setosa",
+ "5.1,3.4,1.5,0.2,setosa",
+ "5.0,3.5,1.3,0.3,setosa",
+ "4.5,2.3,1.3,0.3,setosa",
+ "4.4,3.2,1.3,0.2,setosa",
+ "5.0,3.5,1.6,0.6,setosa",
+ "5.1,3.8,1.9,0.4,setosa",
+ "4.8,3.0,1.4,0.3,setosa",
+ "5.1,3.8,1.6,0.2,setosa",
+ "4.6,3.2,1.4,0.2,setosa",
+ "5.3,3.7,1.5,0.2,setosa",
+ "5.0,3.3,1.4,0.2,setosa",
+ "7.0,3.2,4.7,1.4,versicolor",
+ "6.4,3.2,4.5,1.5,versicolor",
+ "6.9,3.1,4.9,1.5,versicolor",
+ "5.5,2.3,4.0,1.3,versicolor",
+ "6.5,2.8,4.6,1.5,versicolor",
+ "5.7,2.8,4.5,1.3,versicolor",
+ "6.3,3.3,4.7,1.6,versicolor",
+ "4.9,2.4,3.3,1.0,versicolor",
+ "6.6,2.9,4.6,1.3,versicolor",
+ "5.2,2.7,3.9,1.4,versicolor",
+ "5.0,2.0,3.5,1.0,versicolor",
+ "5.9,3.0,4.2,1.5,versicolor",
+ "6.0,2.2,4.0,1.0,versicolor",
+ "6.1,2.9,4.7,1.4,versicolor",
+ "5.6,2.9,3.6,1.3,versicolor",
+ "6.7,3.1,4.4,1.4,versicolor",
+ "5.6,3.0,4.5,1.5,versicolor",
+ "5.8,2.7,4.1,1.0,versicolor",
+ "6.2,2.2,4.5,1.5,versicolor",
+ "5.6,2.5,3.9,1.1,versicolor",
+ "5.9,3.2,4.8,1.8,versicolor",
+ "6.1,2.8,4.0,1.3,versicolor",
+ "6.3,2.5,4.9,1.5,versicolor",
+ "6.1,2.8,4.7,1.2,versicolor",
+ "6.4,2.9,4.3,1.3,versicolor",
+ "6.6,3.0,4.4,1.4,versicolor",
+ "6.8,2.8,4.8,1.4,versicolor",
+ "6.7,3.0,5.0,1.7,versicolor",
+ "6.0,2.9,4.5,1.5,versicolor",
+ "5.7,2.6,3.5,1.0,versicolor",
+ "5.5,2.4,3.8,1.1,versicolor",
+ "5.5,2.4,3.7,1.0,versicolor",
+ "5.8,2.7,3.9,1.2,versicolor",
+ "6.0,2.7,5.1,1.6,versicolor",
+ "5.4,3.0,4.5,1.5,versicolor",
+ "6.0,3.4,4.5,1.6,versicolor",
+ "6.7,3.1,4.7,1.5,versicolor",
+ "6.3,2.3,4.4,1.3,versicolor",
+ "5.6,3.0,4.1,1.3,versicolor",
+ "5.5,2.5,4.0,1.3,versicolor",
+ "5.5,2.6,4.4,1.2,versicolor",
+ "6.1,3.0,4.6,1.4,versicolor",
+ "5.8,2.6,4.0,1.2,versicolor",
+ "5.0,2.3,3.3,1.0,versicolor",
+ "5.6,2.7,4.2,1.3,versicolor",
+ "5.7,3.0,4.2,1.2,versicolor",
+ "5.7,2.9,4.2,1.3,versicolor",
+ "6.2,2.9,4.3,1.3,versicolor",
+ "5.1,2.5,3.0,1.1,versicolor",
+ "5.7,2.8,4.1,1.3,versicolor",
+ "6.3,3.3,6.0,2.5,virginica",
+ "5.8,2.7,5.1,1.9,virginica",
+ "7.1,3.0,5.9,2.1,virginica",
+ "6.3,2.9,5.6,1.8,virginica",
+ "6.5,3.0,5.8,2.2,virginica",
+ "7.6,3.0,6.6,2.1,virginica",
+ "4.9,2.5,4.5,1.7,virginica",
+ "7.3,2.9,6.3,1.8,virginica",
+ "6.7,2.5,5.8,1.8,virginica",
+ "7.2,3.6,6.1,2.5,virginica",
+ "6.5,3.2,5.1,2.0,virginica",
+ "6.4,2.7,5.3,1.9,virginica",
+ "6.8,3.0,5.5,2.1,virginica",
+ "5.7,2.5,5.0,2.0,virginica",
+ "5.8,2.8,5.1,2.4,virginica",
+ "6.4,3.2,5.3,2.3,virginica",
+ "6.5,3.0,5.5,1.8,virginica",
+ "7.7,3.8,6.7,2.2,virginica",
+ "7.7,2.6,6.9,2.3,virginica",
+ "6.0,2.2,5.0,1.5,virginica",
+ "6.9,3.2,5.7,2.3,virginica",
+ "5.6,2.8,4.9,2.0,virginica",
+ "7.7,2.8,6.7,2.0,virginica",
+ "6.3,2.7,4.9,1.8,virginica",
+ "6.7,3.3,5.7,2.1,virginica",
+ "7.2,3.2,6.0,1.8,virginica",
+ "6.2,2.8,4.8,1.8,virginica",
+ "6.1,3.0,4.9,1.8,virginica",
+ "6.4,2.8,5.6,2.1,virginica",
+ "7.2,3.0,5.8,1.6,virginica",
+ "7.4,2.8,6.1,1.9,virginica",
+ "7.9,3.8,6.4,2.0,virginica",
+ "6.4,2.8,5.6,2.2,virginica",
+ "6.3,2.8,5.1,1.5,virginica",
+ "6.1,2.6,5.6,1.4,virginica",
+ "7.7,3.0,6.1,2.3,virginica",
+ "6.3,3.4,5.6,2.4,virginica",
+ "6.4,3.1,5.5,1.8,virginica",
+ "6.0,3.0,4.8,1.8,virginica",
+ "6.9,3.1,5.4,2.1,virginica",
+ "6.7,3.1,5.6,2.4,virginica",
+ "6.9,3.1,5.1,2.3,virginica",
+ "5.8,2.7,5.1,1.9,virginica",
+ "6.8,3.2,5.9,2.3,virginica",
+ "6.7,3.3,5.7,2.5,virginica",
+ "6.7,3.0,5.2,2.3,virginica",
+ "6.3,2.5,5.0,1.9,virginica",
+ "6.5,3.0,5.2,2.0,virginica",
+ "6.2,3.4,5.4,2.3,virginica",
+ "5.9,3.0,5.1,1.8,virginica"
+ };
+
+ public static final String[] CANCER = new String[] {
+ "\"V1\",\"V2\",\"V3\",\"V4\",\"V5\",\"V6\",\"V7\",\"V8\",\"V9\",\"target\"",
+ "5,1,1,1,2,1,3,1,1,0",
+ "5,4,4,5,7,10,3,2,1,0",
+ "3,1,1,1,2,2,3,1,1,0",
+ "6,8,8,1,3,4,3,7,1,0",
+ "4,1,1,3,2,1,3,1,1,0",
+ "8,10,10,8,7,10,9,7,1,1",
+ "1,1,1,1,2,10,3,1,1,0",
+ "2,1,2,1,2,1,3,1,1,0",
+ "2,1,1,1,2,1,1,1,5,0",
+ "4,2,1,1,2,1,2,1,1,0",
+ "1,1,1,1,1,1,3,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "5,3,3,3,2,3,4,4,1,1",
+ "1,1,1,1,2,3,3,1,1,0",
+ "8,7,5,10,7,9,5,5,4,1",
+ "7,4,6,4,6,1,4,3,1,1",
+ "4,1,1,1,2,1,2,1,1,0",
+ "4,1,1,1,2,1,3,1,1,0",
+ "10,7,7,6,4,10,4,1,2,1",
+ "6,1,1,1,2,1,3,1,1,0",
+ "7,3,2,10,5,10,5,4,4,1",
+ "10,5,5,3,6,7,7,10,1,1",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "5,2,3,4,2,7,3,6,1,1",
+ "3,2,1,1,1,1,2,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "1,1,3,1,2,1,1,1,1,0",
+ "3,1,1,1,1,1,2,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "10,7,7,3,8,5,7,4,3,1",
+ "2,1,1,2,2,1,3,1,1,0",
+ "3,1,2,1,2,1,2,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "10,10,10,8,6,1,8,9,1,1",
+ "6,2,1,1,1,1,7,1,1,0",
+ "5,4,4,9,2,10,5,6,1,1",
+ "2,5,3,3,6,7,7,5,1,1",
+ "10,4,3,1,3,3,6,5,2,1",
+ "6,10,10,2,8,10,7,3,3,1",
+ "5,6,5,6,10,1,3,1,1,1",
+ "10,10,10,4,8,1,8,10,1,1",
+ "1,1,1,1,2,1,2,1,2,0",
+ "3,7,7,4,4,9,4,8,1,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "4,1,1,3,2,1,3,1,1,0",
+ "7,8,7,2,4,8,3,8,2,1",
+ "9,5,8,1,2,3,2,1,5,1",
+ "5,3,3,4,2,4,3,4,1,1",
+ "10,3,6,2,3,5,4,10,2,1",
+ "5,5,5,8,10,8,7,3,7,1",
+ "10,5,5,6,8,8,7,1,1,1",
+ "10,6,6,3,4,5,3,6,1,1",
+ "8,10,10,1,3,6,3,9,1,1",
+ "8,2,4,1,5,1,5,4,4,1",
+ "5,2,3,1,6,10,5,1,1,1",
+ "9,5,5,2,2,2,5,1,1,1",
+ "5,3,5,5,3,3,4,10,1,1",
+ "1,1,1,1,2,2,2,1,1,0",
+ "9,10,10,1,10,8,3,3,1,1",
+ "6,3,4,1,5,2,3,9,1,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "10,4,2,1,3,2,4,3,10,1",
+ "4,1,1,1,2,1,3,1,1,0",
+ "5,3,4,1,8,10,4,9,1,1",
+ "8,3,8,3,4,9,8,9,8,1",
+ "1,1,1,1,2,1,3,2,1,0",
+ "5,1,3,1,2,1,2,1,1,0",
+ "6,10,2,8,10,2,7,8,10,1",
+ "1,3,3,2,2,1,7,2,1,0",
+ "9,4,5,10,6,10,4,8,1,1",
+ "10,6,4,1,3,4,3,2,3,1",
+ "1,1,2,1,2,2,4,2,1,0",
+ "1,1,4,1,2,1,2,1,1,0",
+ "5,3,1,2,2,1,2,1,1,0",
+ "3,1,1,1,2,3,3,1,1,0",
+ "2,1,1,1,3,1,2,1,1,0",
+ "2,2,2,1,1,1,7,1,1,0",
+ "4,1,1,2,2,1,2,1,1,0",
+ "5,2,1,1,2,1,3,1,1,0",
+ "3,1,1,1,2,2,7,1,1,0",
+ "3,5,7,8,8,9,7,10,7,1",
+ "5,10,6,1,10,4,4,10,10,1",
+ "3,3,6,4,5,8,4,4,1,1",
+ "3,6,6,6,5,10,6,8,3,1",
+ "4,1,1,1,2,1,3,1,1,0",
+ "2,1,1,2,3,1,2,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "3,1,1,2,2,1,1,1,1,0",
+ "4,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "2,1,1,2,2,1,1,1,1,0",
+ "5,1,1,1,2,1,3,1,1,0",
+ "9,6,9,2,10,6,2,9,10,1",
+ "7,5,6,10,5,10,7,9,4,1",
+ "10,3,5,1,10,5,3,10,2,1",
+ "2,3,4,4,2,5,2,5,1,1",
+ "4,1,2,1,2,1,3,1,1,0",
+ "8,2,3,1,6,3,7,1,1,1",
+ "10,10,10,10,10,1,8,8,8,1",
+ "7,3,4,4,3,3,3,2,7,1",
+ "10,10,10,8,2,10,4,1,1,1",
+ "1,6,8,10,8,10,5,7,1,1",
+ "1,1,1,1,2,1,2,3,1,0",
+ "6,5,4,4,3,9,7,8,3,1",
+ "1,3,1,2,2,2,5,3,2,0",
+ "8,6,4,3,5,9,3,1,1,1",
+ "10,3,3,10,2,10,7,3,3,1",
+ "10,10,10,3,10,8,8,1,1,1",
+ "3,3,2,1,2,3,3,1,1,0",
+ "1,1,1,1,2,5,1,1,1,0",
+ "8,3,3,1,2,2,3,2,1,0",
+ "4,5,5,10,4,10,7,5,8,1",
+ "1,1,1,1,4,3,1,1,1,0",
+ "3,2,1,1,2,2,3,1,1,0",
+ "1,1,2,2,2,1,3,1,1,0",
+ "4,2,1,1,2,2,3,1,1,0",
+ "10,10,10,2,10,10,5,3,3,1",
+ "5,3,5,1,8,10,5,3,1,1",
+ "5,4,6,7,9,7,8,10,1,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "7,5,3,7,4,10,7,5,5,1",
+ "3,1,1,1,2,1,3,1,1,0",
+ "8,3,5,4,5,10,1,6,2,1",
+ "1,1,1,1,10,1,1,1,1,0",
+ "5,1,3,1,2,1,2,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "5,10,8,10,8,10,3,6,3,1",
+ "3,1,1,1,2,1,2,2,1,0",
+ "3,1,1,1,3,1,2,1,1,0",
+ "5,1,1,1,2,2,3,3,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,1,1,1,0",
+ "4,1,2,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,1,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "9,5,5,4,4,5,4,3,3,1",
+ "1,1,1,1,2,5,1,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "3,4,5,2,6,8,4,1,1,1",
+ "1,1,1,1,3,2,2,1,1,0",
+ "3,1,1,3,8,1,5,8,1,0",
+ "8,8,7,4,10,10,7,8,7,1",
+ "1,1,1,1,1,1,3,1,1,0",
+ "7,2,4,1,6,10,5,4,3,1",
+ "10,10,8,6,4,5,8,10,1,1",
+ "4,1,1,1,2,3,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,5,5,6,3,10,3,1,1,1",
+ "1,2,2,1,2,1,2,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "9,9,10,3,6,10,7,10,6,1",
+ "10,7,7,4,5,10,5,7,2,1",
+ "4,1,1,1,2,1,3,2,1,0",
+ "3,1,1,1,2,1,3,1,1,0",
+ "1,1,1,2,1,3,1,1,7,0",
+ "4,1,1,1,2,2,3,2,1,0",
+ "5,6,7,8,8,10,3,10,3,1",
+ "10,8,10,10,6,1,3,1,10,1",
+ "3,1,1,1,2,1,3,1,1,0",
+ "1,1,1,2,1,1,1,1,1,0",
+ "3,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "6,10,10,10,8,10,10,10,7,1",
+ "8,6,5,4,3,10,6,1,1,1",
+ "5,8,7,7,10,10,5,7,1,1",
+ "2,1,1,1,2,1,3,1,1,0",
+ "5,10,10,3,8,1,5,10,3,1",
+ "4,1,1,1,2,1,3,1,1,0",
+ "5,3,3,3,6,10,3,1,1,1",
+ "1,1,1,1,1,1,3,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "6,1,1,1,2,1,3,1,1,0",
+ "5,8,8,8,5,10,7,8,1,1",
+ "8,7,6,4,4,10,5,1,1,1",
+ "2,1,1,1,1,1,3,1,1,0",
+ "1,5,8,6,5,8,7,10,1,1",
+ "10,5,6,10,6,10,7,7,10,1",
+ "5,8,4,10,5,8,9,10,1,1",
+ "1,2,3,1,2,1,3,1,1,0",
+ "10,10,10,8,6,8,7,10,1,1",
+ "7,5,10,10,10,10,4,10,3,1",
+ "5,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "3,1,1,1,2,1,3,1,1,0",
+ "4,1,1,1,2,1,3,1,1,0",
+ "8,4,4,5,4,7,7,8,2,0",
+ "5,1,1,4,2,1,3,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "9,7,7,5,5,10,7,8,3,1",
+ "10,8,8,4,10,10,8,1,1,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "5,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "5,10,10,9,6,10,7,10,5,1",
+ "10,10,9,3,7,5,3,5,1,1",
+ "1,1,1,1,1,1,3,1,1,0",
+ "1,1,1,1,1,1,3,1,1,0",
+ "5,1,1,1,1,1,3,1,1,0",
+ "8,10,10,10,5,10,8,10,6,1",
+ "8,10,8,8,4,8,7,7,1,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "10,10,10,10,7,10,7,10,4,1",
+ "10,10,10,10,3,10,10,6,1,1",
+ "8,7,8,7,5,5,5,10,2,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "6,10,7,7,6,4,8,10,2,1",
+ "6,1,3,1,2,1,3,1,1,0",
+ "1,1,1,2,2,1,3,1,1,0",
+ "10,6,4,3,10,10,9,10,1,1",
+ "4,1,1,3,1,5,2,1,1,1",
+ "7,5,6,3,3,8,7,4,1,1",
+ "10,5,5,6,3,10,7,9,2,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "10,5,7,4,4,10,8,9,1,1",
+ "8,9,9,5,3,5,7,7,1,1",
+ "1,1,1,1,1,1,3,1,1,0",
+ "10,10,10,3,10,10,9,10,1,1",
+ "7,4,7,4,3,7,7,6,1,1",
+ "6,8,7,5,6,8,8,9,2,1",
+ "8,4,6,3,3,1,4,3,1,0",
+ "10,4,5,5,5,10,4,1,1,1",
+ "3,3,2,1,3,1,3,6,1,0",
+ "10,8,8,2,8,10,4,8,10,1",
+ "9,8,8,5,6,2,4,10,4,1",
+ "8,10,10,8,6,9,3,10,10,1",
+ "10,4,3,2,3,10,5,3,2,1",
+ "5,1,3,3,2,2,2,3,1,0",
+ "3,1,1,3,1,1,3,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,5,5,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "5,1,1,2,2,2,3,1,1,0",
+ "8,10,10,8,5,10,7,8,1,1",
+ "8,4,4,1,2,9,3,3,1,1",
+ "4,1,1,1,2,1,3,6,1,0",
+ "1,2,2,1,2,1,1,1,1,0",
+ "10,4,4,10,2,10,5,3,3,1",
+ "6,3,3,5,3,10,3,5,3,0",
+ "6,10,10,2,8,10,7,3,3,1",
+ "9,10,10,1,10,8,3,3,1,1",
+ "5,6,6,2,4,10,3,6,1,1",
+ "3,1,1,1,2,1,1,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,3,1,1,0",
+ "5,7,7,1,5,8,3,4,1,0",
+ "10,5,8,10,3,10,5,1,3,1",
+ "5,10,10,6,10,10,10,6,5,1",
+ "8,8,9,4,5,10,7,8,1,1",
+ "10,4,4,10,6,10,5,5,1,1",
+ "7,9,4,10,10,3,5,3,3,1",
+ "5,1,4,1,2,1,3,2,1,0",
+ "10,10,6,3,3,10,4,3,2,1",
+ "3,3,5,2,3,10,7,1,1,1",
+ "10,8,8,2,3,4,8,7,8,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "8,4,7,1,3,10,3,9,2,1",
+ "5,1,1,1,2,1,3,1,1,0",
+ "3,3,5,2,3,10,7,1,1,1",
+ "7,2,4,1,3,4,3,3,1,1",
+ "3,1,1,1,2,1,3,2,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "10,5,7,3,3,7,3,3,8,1",
+ "3,1,1,1,2,1,3,1,1,0",
+ "2,1,1,2,2,1,3,1,1,0",
+ "1,4,3,10,4,10,5,6,1,1",
+ "10,4,6,1,2,10,5,3,1,1",
+ "7,4,5,10,2,10,3,8,2,1",
+ "8,10,10,10,8,10,10,7,3,1",
+ "10,10,10,10,10,10,4,10,10,1",
+ "3,1,1,1,3,1,2,1,1,0",
+ "6,1,3,1,4,5,5,10,1,1",
+ "5,6,6,8,6,10,4,10,4,1",
+ "1,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "10,4,4,6,2,10,2,3,1,1",
+ "5,5,7,8,6,10,7,4,1,1",
+ "5,3,4,3,4,5,4,7,1,0",
+ "8,2,1,1,5,1,1,1,1,0",
+ "9,1,2,6,4,10,7,7,2,1",
+ "8,4,10,5,4,4,7,10,1,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "10,10,10,7,9,10,7,10,10,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "8,3,4,9,3,10,3,3,1,1",
+ "10,8,4,4,4,10,3,10,4,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "7,8,7,6,4,3,8,8,4,1",
+ "3,1,1,1,2,5,5,1,1,0",
+ "2,1,1,1,3,1,2,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "8,6,4,10,10,1,3,5,1,1",
+ "1,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,1,1,2,1,1,0",
+ "5,5,5,2,5,10,4,3,1,1",
+ "6,8,7,8,6,8,8,9,1,1",
+ "1,1,1,1,5,1,3,1,1,0",
+ "4,4,4,4,6,5,7,3,1,0",
+ "7,6,3,2,5,10,7,4,6,1",
+ "3,1,1,1,2,1,3,1,1,0",
+ "5,4,6,10,2,10,4,1,1,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "3,2,2,1,2,1,2,3,1,0",
+ "10,1,1,1,2,10,5,4,1,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "8,10,3,2,6,4,3,10,1,1",
+ "10,4,6,4,5,10,7,1,1,1",
+ "10,4,7,2,2,8,6,1,1,1",
+ "5,1,1,1,2,1,3,1,2,0",
+ "5,2,2,2,2,1,2,2,1,0",
+ "5,4,6,6,4,10,4,3,1,1",
+ "8,6,7,3,3,10,3,4,2,1",
+ "1,1,1,1,2,1,1,1,1,0",
+ "6,5,5,8,4,10,3,4,1,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,1,1,2,1,1,0",
+ "8,5,5,5,2,10,4,3,1,1",
+ "10,3,3,1,2,10,7,6,1,1",
+ "1,1,1,1,2,1,3,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "7,6,4,8,10,10,9,5,3,1",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,2,2,2,3,1,1,3,1,0",
+ "1,1,1,1,1,1,1,3,1,0",
+ "3,4,4,10,5,1,3,3,1,1",
+ "4,2,3,5,3,8,7,6,1,1",
+ "5,1,1,3,2,1,1,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "3,4,5,3,7,3,4,6,1,0",
+ "2,7,10,10,7,10,4,9,4,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "4,1,1,1,3,1,2,2,1,0",
+ "5,3,3,1,3,3,3,3,3,1",
+ "8,10,10,7,10,10,7,3,8,1",
+ "8,10,5,3,8,4,4,10,3,1",
+ "10,3,5,4,3,7,3,5,3,1",
+ "6,10,10,10,10,10,8,10,10,1",
+ "3,10,3,10,6,10,5,1,4,1",
+ "3,2,2,1,4,3,2,1,1,0",
+ "4,4,4,2,2,3,2,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "6,10,10,10,8,10,7,10,7,1",
+ "5,8,8,10,5,10,8,10,3,1",
+ "1,1,3,1,2,1,1,1,1,0",
+ "1,1,3,1,1,1,2,1,1,0",
+ "4,3,2,1,3,1,2,1,1,0",
+ "1,1,3,1,2,1,1,1,1,0",
+ "4,1,2,1,2,1,2,1,1,0",
+ "5,1,1,2,2,1,2,1,1,0",
+ "3,1,2,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,1,1,2,1,1,0",
+ "3,1,1,4,3,1,2,2,1,0",
+ "5,3,4,1,4,1,3,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "10,6,3,6,4,10,7,8,4,1",
+ "3,2,2,2,2,1,3,2,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "3,3,2,2,3,1,1,2,3,0",
+ "7,6,6,3,2,10,7,1,1,1",
+ "5,3,3,2,3,1,3,1,1,0",
+ "2,1,1,1,2,1,2,2,1,0",
+ "5,1,1,1,3,2,2,2,1,0",
+ "1,1,1,2,2,1,2,1,1,0",
+ "10,8,7,4,3,10,7,9,1,1",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,1,1,1,1,1,0",
+ "1,2,3,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,3,1,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "3,2,1,1,2,1,2,2,1,0",
+ "1,2,3,1,2,1,1,1,1,0",
+ "3,10,8,7,6,9,9,3,8,1",
+ "3,1,1,1,2,1,1,1,1,0",
+ "5,3,3,1,2,1,2,1,1,0",
+ "3,1,1,1,2,4,1,1,1,0",
+ "1,2,1,3,2,1,1,2,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "4,2,2,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "2,3,2,2,2,2,3,1,1,0",
+ "3,1,2,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "10,10,10,6,8,4,8,5,1,1",
+ "5,1,2,1,2,1,3,1,1,0",
+ "8,5,6,2,3,10,6,6,1,1",
+ "3,3,2,6,3,3,3,5,1,0",
+ "8,7,8,5,10,10,7,2,1,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "5,2,2,2,2,2,3,2,2,0",
+ "2,3,1,1,5,1,1,1,1,0",
+ "3,2,2,3,2,3,3,1,1,0",
+ "10,10,10,7,10,10,8,2,1,1",
+ "4,3,3,1,2,1,3,3,1,0",
+ "5,1,3,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,1,1,1,0",
+ "9,10,10,10,10,10,10,10,1,1",
+ "5,3,6,1,2,1,1,1,1,0",
+ "8,7,8,2,4,2,5,10,1,1",
+ "1,1,1,1,2,1,2,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "1,3,1,1,2,1,2,2,1,0",
+ "5,1,1,3,4,1,3,2,1,0",
+ "5,1,1,1,2,1,2,2,1,0",
+ "3,2,2,3,2,1,1,1,1,0",
+ "6,9,7,5,5,8,4,2,1,0",
+ "10,8,10,1,3,10,5,1,1,1",
+ "10,10,10,1,6,1,2,8,1,1",
+ "4,1,1,1,2,1,1,1,1,0",
+ "4,1,3,3,2,1,1,1,1,0",
+ "5,1,1,1,2,1,1,1,1,0",
+ "10,4,3,10,4,10,10,1,1,1",
+ "5,2,2,4,2,4,1,1,1,0",
+ "1,1,1,3,2,3,1,1,1,0",
+ "1,1,1,1,2,2,1,1,1,0",
+ "5,1,1,6,3,1,2,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,1,1,1,1,1,0",
+ "5,7,9,8,6,10,8,10,1,1",
+ "4,1,1,3,1,1,2,1,1,0",
+ "5,1,1,1,2,1,1,1,1,0",
+ "3,1,1,3,2,1,1,1,1,0",
+ "4,5,5,8,6,10,10,7,1,1",
+ "2,3,1,1,3,1,1,1,1,0",
+ "10,2,2,1,2,6,1,1,2,1",
+ "10,6,5,8,5,10,8,6,1,1",
+ "8,8,9,6,6,3,10,10,1,1",
+ "5,1,2,1,2,1,1,1,1,0",
+ "5,1,3,1,2,1,1,1,1,0",
+ "5,1,1,3,2,1,1,1,1,0",
+ "3,1,1,1,2,5,1,1,1,0",
+ "6,1,1,3,2,1,1,1,1,0",
+ "4,1,1,1,2,1,1,2,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "10,9,8,7,6,4,7,10,3,1",
+ "10,6,6,2,4,10,9,7,1,1",
+ "6,6,6,5,4,10,7,6,2,1",
+ "4,1,1,1,2,1,1,1,1,0",
+ "1,1,2,1,2,1,2,1,1,0",
+ "3,1,1,1,1,1,2,1,1,0",
+ "6,1,1,3,2,1,1,1,1,0",
+ "6,1,1,1,1,1,1,1,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "5,1,1,1,2,1,1,1,1,0",
+ "3,1,1,1,2,1,1,1,1,0",
+ "4,1,2,1,2,1,1,1,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "5,2,1,1,2,1,1,1,1,0",
+ "4,8,7,10,4,10,7,5,1,1",
+ "5,1,1,1,1,1,1,1,1,0",
+ "5,3,2,4,2,1,1,1,1,0",
+ "9,10,10,10,10,5,10,10,10,1",
+ "8,7,8,5,5,10,9,10,1,1",
+ "5,1,2,1,2,1,1,1,1,0",
+ "1,1,1,3,1,3,1,1,1,0",
+ "3,1,1,1,1,1,2,1,1,0",
+ "10,10,10,10,6,10,8,1,5,1",
+ "3,6,4,10,3,3,3,4,1,1",
+ "6,3,2,1,3,4,4,1,1,1",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,8,9,4,3,10,7,1,1,1",
+ "4,1,1,1,1,1,2,1,1,0",
+ "5,10,10,10,6,10,6,5,2,1",
+ "5,1,2,10,4,5,2,1,1,0",
+ "3,1,1,1,1,1,2,1,1,0",
+ "1,1,1,1,1,1,1,1,1,0",
+ "4,2,1,1,2,1,1,1,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "6,1,1,1,2,1,3,1,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "4,1,1,2,2,1,2,1,1,0",
+ "4,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "3,3,1,1,2,1,1,1,1,0",
+ "8,10,10,10,7,5,4,8,7,1",
+ "1,1,1,1,2,4,1,1,1,0",
+ "5,1,1,1,2,1,1,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "5,1,1,1,2,1,1,1,1,0",
+ "3,1,1,1,1,1,2,1,1,0",
+ "6,6,7,10,3,10,8,10,2,1",
+ "4,10,4,7,3,10,9,10,1,1",
+ "1,1,1,1,1,1,1,1,1,0",
+ "1,1,1,1,1,1,2,1,1,0",
+ "3,1,2,2,2,1,1,1,1,0",
+ "4,7,8,3,4,10,9,1,1,1",
+ "1,1,1,1,3,1,1,1,1,0",
+ "4,1,1,1,3,1,1,1,1,0",
+ "10,4,5,4,3,5,7,3,1,1",
+ "7,5,6,10,4,10,5,3,1,1",
+ "3,1,1,1,2,1,2,1,1,0",
+ "3,1,1,2,2,1,1,1,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "4,1,1,1,2,1,3,1,1,0",
+ "6,1,3,2,2,1,1,1,1,0",
+ "4,1,1,1,1,1,2,1,1,0",
+ "7,4,4,3,4,10,6,9,1,1",
+ "4,2,2,1,2,1,2,1,1,0",
+ "1,1,1,1,1,1,3,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "1,1,3,2,2,1,3,1,1,0",
+ "5,1,1,1,2,1,3,1,1,0",
+ "5,1,2,1,2,1,3,1,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "6,1,1,1,2,1,2,1,1,0",
+ "5,1,1,1,2,2,2,1,1,0",
+ "3,1,1,1,2,1,1,1,1,0",
+ "5,3,1,1,2,1,1,1,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "2,1,3,2,2,1,2,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "6,10,10,10,4,10,7,10,1,1",
+ "2,1,1,1,1,1,1,1,1,0",
+ "3,1,1,1,1,1,1,1,1,0",
+ "7,8,3,7,4,5,7,8,2,1",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "3,2,2,2,2,1,4,2,1,0",
+ "4,4,2,1,2,5,2,1,2,0",
+ "3,1,1,1,2,1,1,1,1,0",
+ "4,3,1,1,2,1,4,8,1,0",
+ "5,2,2,2,1,1,2,1,1,0",
+ "5,1,1,3,2,1,1,1,1,0",
+ "2,1,1,1,2,1,2,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "5,1,1,1,2,1,3,1,1,0",
+ "5,1,1,1,2,1,3,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "4,1,1,1,2,1,3,2,1,0",
+ "5,7,10,10,5,10,10,10,1,1",
+ "3,1,2,1,2,1,3,1,1,0",
+ "4,1,1,1,2,3,2,1,1,0",
+ "8,4,4,1,6,10,2,5,2,1",
+ "10,10,8,10,6,5,10,3,1,1",
+ "8,10,4,4,8,10,8,2,1,1",
+ "7,6,10,5,3,10,9,10,2,1",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "10,9,7,3,4,2,7,7,1,1",
+ "5,1,2,1,2,1,3,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,3,1,1,0",
+ "5,1,2,1,2,1,2,1,1,0",
+ "5,7,10,6,5,10,7,5,1,1",
+ "6,10,5,5,4,10,6,10,1,1",
+ "3,1,1,1,2,1,1,1,1,0",
+ "5,1,1,6,3,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "8,10,10,10,6,10,10,10,1,1",
+ "5,1,1,1,2,1,2,2,1,0",
+ "9,8,8,9,6,3,4,1,1,1",
+ "5,1,1,1,2,1,1,1,1,0",
+ "4,10,8,5,4,1,10,1,1,1",
+ "2,5,7,6,4,10,7,6,1,1",
+ "10,3,4,5,3,10,4,1,1,1",
+ "5,1,2,1,2,1,1,1,1,0",
+ "4,8,6,3,4,10,7,1,1,1",
+ "5,1,1,1,2,1,2,1,1,0",
+ "4,1,2,1,2,1,2,1,1,0",
+ "5,1,3,1,2,1,3,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "5,2,4,1,1,1,1,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,1,1,2,1,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "5,4,6,8,4,1,8,10,1,1",
+ "5,3,2,8,5,10,8,1,2,1",
+ "10,5,10,3,5,8,7,8,3,1",
+ "4,1,1,2,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,10,10,10,10,10,10,1,1,1",
+ "5,1,1,1,2,1,1,1,1,0",
+ "10,4,3,10,3,10,7,1,2,1",
+ "5,10,10,10,5,2,8,5,1,1",
+ "8,10,10,10,6,10,10,10,10,1",
+ "2,3,1,1,2,1,2,1,1,0",
+ "2,1,1,1,1,1,2,1,1,0",
+ "4,1,3,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "6,3,3,3,3,2,6,1,1,0",
+ "7,1,2,3,2,1,2,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,1,1,2,1,1,2,1,1,0",
+ "3,1,3,1,3,4,1,1,1,0",
+ "4,6,6,5,7,6,7,7,3,1",
+ "2,1,1,1,2,5,1,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "6,2,3,1,2,1,1,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "8,7,4,4,5,3,5,10,1,1",
+ "3,1,1,1,2,1,1,1,1,0",
+ "3,1,4,1,2,1,1,1,1,0",
+ "10,10,7,8,7,1,10,10,3,1",
+ "4,2,4,3,2,2,2,1,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "5,1,1,3,2,1,1,1,1,0",
+ "4,1,1,3,2,1,1,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "1,2,2,1,2,1,1,1,1,0",
+ "1,1,1,3,2,1,1,1,1,0",
+ "5,10,10,10,10,2,10,10,10,1",
+ "3,1,1,1,2,1,2,1,1,0",
+ "3,1,1,2,3,4,1,1,1,0",
+ "1,2,1,3,2,1,2,1,1,0",
+ "5,1,1,1,2,1,2,2,1,0",
+ "4,1,1,1,2,1,2,1,1,0",
+ "3,1,1,1,2,1,3,1,1,0",
+ "3,1,1,1,2,1,2,1,1,0",
+ "5,1,1,1,2,1,2,1,1,0",
+ "5,4,5,1,8,1,3,6,1,0",
+ "7,8,8,7,3,10,7,2,3,1",
+ "1,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "4,1,1,1,2,1,3,1,1,0",
+ "1,1,3,1,2,1,2,1,1,0",
+ "1,1,3,1,2,1,2,1,1,0",
+ "3,1,1,3,2,1,2,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "5,2,2,2,2,1,1,1,2,0",
+ "3,1,1,1,2,1,3,1,1,0",
+ "5,7,4,1,6,1,7,10,3,1",
+ "5,10,10,8,5,5,7,10,1,1",
+ "3,10,7,8,5,8,7,4,1,1",
+ "3,2,1,2,2,1,3,1,1,0",
+ "2,1,1,1,2,1,3,1,1,0",
+ "5,3,2,1,3,1,1,1,1,0",
+ "1,1,1,1,2,1,2,1,1,0",
+ "4,1,4,1,2,1,1,1,1,0",
+ "1,1,2,1,2,1,2,1,1,0",
+ "5,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "10,10,10,10,5,10,10,10,7,1",
+ "5,10,10,10,4,10,5,6,3,1",
+ "5,1,1,1,2,1,3,2,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,1,0",
+ "3,1,1,1,2,1,2,3,1,0",
+ "4,1,1,1,2,1,1,1,1,0",
+ "1,1,1,1,2,1,1,1,8,0",
+ "1,1,1,3,2,1,1,1,1,0",
+ "5,10,10,5,4,5,4,4,1,1",
+ "3,1,1,1,2,1,1,1,1,0",
+ "3,1,1,1,2,1,2,1,2,0",
+ "3,1,1,1,3,2,1,1,1,0",
+ "2,1,1,1,2,1,1,1,1,0",
+ "5,10,10,3,7,3,8,10,2,1",
+ "4,8,6,4,3,4,10,6,1,1",
+ "4,8,8,5,4,5,10,4,1,1"
+ };
+
+ private Datasets() {}
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java
new file mode 100644
index 0000000..522ac4a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/RunMultilayerPerceptronTest.java
@@ -0,0 +1,66 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public class RunMultilayerPerceptronTest extends MahoutTestCase {
+
+ @Test
+ public void runMultilayerPerceptron() throws Exception {
+
+ // Train a model first
+ String modelFileName = "mlp.model";
+ File modelFile = getTestTempFile(modelFileName);
+
+ File irisDataset = getTestTempFile("iris.csv");
+ writeLines(irisDataset, Datasets.IRIS);
+
+ String[] argsTrain = {
+ "-i", irisDataset.getAbsolutePath(),
+ "-sh",
+ "-labels", "setosa", "versicolor", "virginica",
+ "-mo", modelFile.getAbsolutePath(),
+ "-u",
+ "-ls", "4", "8", "3"
+ };
+
+ TrainMultilayerPerceptron.main(argsTrain);
+
+ assertTrue(modelFile.exists());
+
+ String outputFileName = "labelResult.txt";
+ File outputFile = getTestTempFile(outputFileName);
+
+ String[] argsLabeling = {
+ "-i", irisDataset.getAbsolutePath(),
+ "-sh",
+ "-cr", "0", "3",
+ "-mo", modelFile.getAbsolutePath(),
+ "-o", outputFile.getAbsolutePath()
+ };
+
+ RunMultilayerPerceptron.main(argsLabeling);
+
+ assertTrue(outputFile.exists());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
new file mode 100644
index 0000000..93013b6
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestMultilayerPerceptron.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.Arrays;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Test the functionality of {@link MultilayerPerceptron}
+ */
+public class TestMultilayerPerceptron extends MahoutTestCase {
+
+ @Test
+ public void testMLP() throws IOException {
+ testMLP("testMLPXORLocal", false, false, 8000);
+ testMLP("testMLPXORLocalWithMomentum", true, false, 4000);
+ testMLP("testMLPXORLocalWithRegularization", true, true, 2000);
+ }
+
+ private void testMLP(String modelFilename, boolean useMomentum,
+ boolean useRegularization, int iterations) throws IOException {
+ MultilayerPerceptron mlp = new MultilayerPerceptron();
+ mlp.addLayer(2, false, "Sigmoid");
+ mlp.addLayer(3, false, "Sigmoid");
+ mlp.addLayer(1, true, "Sigmoid");
+ mlp.setCostFunction("Minus_Squared").setLearningRate(0.2);
+ if (useMomentum) {
+ mlp.setMomentumWeight(0.6);
+ }
+
+ if (useRegularization) {
+ mlp.setRegularizationWeight(0.01);
+ }
+
+ double[][] instances = { { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 1 }, { 1, 1, 0 } };
+ for (int i = 0; i < iterations; ++i) {
+ for (double[] instance : instances) {
+ Vector features = new DenseVector(Arrays.copyOf(instance, instance.length - 1));
+ mlp.train((int) instance[2], features);
+ }
+ }
+
+ for (double[] instance : instances) {
+ Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+ // the expected output is the last element in array
+ double actual = instance[2];
+ double expected = mlp.getOutput(input).get(0);
+ assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5);
+ }
+
+ // write model into file and read out
+ File modelFile = this.getTestTempFile(modelFilename);
+ mlp.setModelPath(modelFile.getAbsolutePath());
+ mlp.writeModelToFile();
+ mlp.close();
+
+ MultilayerPerceptron mlpCopy = new MultilayerPerceptron(modelFile.getAbsolutePath());
+ // test on instances
+ for (double[] instance : instances) {
+ Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+ // the expected output is the last element in array
+ double actual = instance[2];
+ double expected = mlpCopy.getOutput(input).get(0);
+ assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5);
+ }
+ mlpCopy.close();
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
new file mode 100644
index 0000000..ebe5424
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/TestNeuralNetwork.java
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.commons.csv.CSVUtils;
+import org.apache.mahout.classifier.mlp.NeuralNetwork.TrainingMethod;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Lists;
+import com.google.common.io.Files;
+
+/** Test the functionality of {@link NeuralNetwork}. */
+public class TestNeuralNetwork extends MahoutTestCase {
+
+
+ @Test
+ public void testReadWrite() throws IOException {
+ NeuralNetwork ann = new MultilayerPerceptron();
+ ann.addLayer(2, false, "Identity");
+ ann.addLayer(5, false, "Identity");
+ ann.addLayer(1, true, "Identity");
+ ann.setCostFunction("Minus_Squared");
+ double learningRate = 0.2;
+ double momentumWeight = 0.5;
+ double regularizationWeight = 0.05;
+ ann.setLearningRate(learningRate)
+ .setMomentumWeight(momentumWeight)
+ .setRegularizationWeight(regularizationWeight);
+
+ // Manually set weights
+ Matrix[] matrices = new DenseMatrix[2];
+ matrices[0] = new DenseMatrix(5, 3);
+ matrices[0].assign(0.2);
+ matrices[1] = new DenseMatrix(1, 6);
+ matrices[1].assign(0.8);
+ ann.setWeightMatrices(matrices);
+
+ // Write to file
+ String modelFilename = "testNeuralNetworkReadWrite";
+ File tmpModelFile = this.getTestTempFile(modelFilename);
+ ann.setModelPath(tmpModelFile.getAbsolutePath());
+ ann.writeModelToFile();
+
+ // Read from file
+ NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+ assertEquals(annCopy.getClass().getSimpleName(), annCopy.getModelType());
+ assertEquals(tmpModelFile.getAbsolutePath(), annCopy.getModelPath());
+ assertEquals(learningRate, annCopy.getLearningRate(), EPSILON);
+ assertEquals(momentumWeight, annCopy.getMomentumWeight(), EPSILON);
+ assertEquals(regularizationWeight, annCopy.getRegularizationWeight(), EPSILON);
+ assertEquals(TrainingMethod.GRADIENT_DESCENT, annCopy.getTrainingMethod());
+
+ // Compare weights
+ Matrix[] weightsMatrices = annCopy.getWeightMatrices();
+ for (int i = 0; i < weightsMatrices.length; ++i) {
+ Matrix expectMat = matrices[i];
+ Matrix actualMat = weightsMatrices[i];
+ for (int j = 0; j < expectMat.rowSize(); ++j) {
+ for (int k = 0; k < expectMat.columnSize(); ++k) {
+ assertEquals(expectMat.get(j, k), actualMat.get(j, k), EPSILON);
+ }
+ }
+ }
+ }
+
+ /** Test the forward functionality. */
+ @Test
+ public void testOutput() {
+ // First network
+ NeuralNetwork ann = new MultilayerPerceptron();
+ ann.addLayer(2, false, "Identity");
+ ann.addLayer(5, false, "Identity");
+ ann.addLayer(1, true, "Identity");
+ ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+ // Intentionally initialize all weights to 0.5
+ Matrix[] matrices = new Matrix[2];
+ matrices[0] = new DenseMatrix(5, 3);
+ matrices[0].assign(0.5);
+ matrices[1] = new DenseMatrix(1, 6);
+ matrices[1].assign(0.5);
+ ann.setWeightMatrices(matrices);
+
+ double[] arr = new double[] { 0, 1 };
+ Vector training = new DenseVector(arr);
+ Vector result = ann.getOutput(training);
+ assertEquals(1, result.size());
+
+ // Second network
+ NeuralNetwork ann2 = new MultilayerPerceptron();
+ ann2.addLayer(2, false, "Sigmoid");
+ ann2.addLayer(3, false, "Sigmoid");
+ ann2.addLayer(1, true, "Sigmoid");
+ ann2.setCostFunction("Minus_Squared");
+ ann2.setLearningRate(0.3);
+
+ // Intentionally initialize all weights to 0.5
+ Matrix[] matrices2 = new Matrix[2];
+ matrices2[0] = new DenseMatrix(3, 3);
+ matrices2[0].assign(0.5);
+ matrices2[1] = new DenseMatrix(1, 4);
+ matrices2[1].assign(0.5);
+ ann2.setWeightMatrices(matrices2);
+
+ double[] test = { 0, 0 };
+ double[] result2 = { 0.807476 };
+
+ Vector vec = ann2.getOutput(new DenseVector(test));
+ double[] arrVec = new double[vec.size()];
+ for (int i = 0; i < arrVec.length; ++i) {
+ arrVec[i] = vec.getQuick(i);
+ }
+ assertArrayEquals(result2, arrVec, EPSILON);
+
+ NeuralNetwork ann3 = new MultilayerPerceptron();
+ ann3.addLayer(2, false, "Sigmoid");
+ ann3.addLayer(3, false, "Sigmoid");
+ ann3.addLayer(1, true, "Sigmoid");
+ ann3.setCostFunction("Minus_Squared").setLearningRate(0.3);
+
+ // Intentionally initialize all weights to 0.5
+ Matrix[] initMatrices = new Matrix[2];
+ initMatrices[0] = new DenseMatrix(3, 3);
+ initMatrices[0].assign(0.5);
+ initMatrices[1] = new DenseMatrix(1, 4);
+ initMatrices[1].assign(0.5);
+ ann3.setWeightMatrices(initMatrices);
+
+ double[] instance = {0, 1};
+ Vector output = ann3.getOutput(new DenseVector(instance));
+ assertEquals(0.8315410, output.get(0), EPSILON);
+ }
+
+ @Test
+ public void testNeuralNetwork() throws IOException {
+ testNeuralNetwork("testNeuralNetworkXORLocal", false, false, 10000);
+ testNeuralNetwork("testNeuralNetworkXORWithMomentum", true, false, 5000);
+ testNeuralNetwork("testNeuralNetworkXORWithRegularization", true, true, 5000);
+ }
+
+ private void testNeuralNetwork(String modelFilename, boolean useMomentum,
+ boolean useRegularization, int iterations) throws IOException {
+ NeuralNetwork ann = new MultilayerPerceptron();
+ ann.addLayer(2, false, "Sigmoid");
+ ann.addLayer(3, false, "Sigmoid");
+ ann.addLayer(1, true, "Sigmoid");
+ ann.setCostFunction("Minus_Squared").setLearningRate(0.1);
+
+ if (useMomentum) {
+ ann.setMomentumWeight(0.6);
+ }
+
+ if (useRegularization) {
+ ann.setRegularizationWeight(0.01);
+ }
+
+ double[][] instances = { { 0, 1, 1 }, { 0, 0, 0 }, { 1, 0, 1 }, { 1, 1, 0 } };
+ for (int i = 0; i < iterations; ++i) {
+ for (double[] instance : instances) {
+ ann.trainOnline(new DenseVector(instance));
+ }
+ }
+
+ for (double[] instance : instances) {
+ Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+ // The expected output is the last element in array
+ double actual = instance[2];
+ double expected = ann.getOutput(input).get(0);
+ assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5);
+ }
+
+ // Write model into file and read out
+ File tmpModelFile = this.getTestTempFile(modelFilename);
+ ann.setModelPath(tmpModelFile.getAbsolutePath());
+ ann.writeModelToFile();
+
+ NeuralNetwork annCopy = new MultilayerPerceptron(tmpModelFile.getAbsolutePath());
+ // Test on instances
+ for (double[] instance : instances) {
+ Vector input = new DenseVector(instance).viewPart(0, instance.length - 1);
+ // The expected output is the last element in array
+ double actual = instance[2];
+ double expected = annCopy.getOutput(input).get(0);
+ assertTrue(actual < 0.5 && expected < 0.5 || actual >= 0.5 && expected >= 0.5);
+ }
+ }
+
+ @Test
+ public void testWithCancerDataSet() throws IOException {
+
+ File cancerDataset = getTestTempFile("cancer.csv");
+ writeLines(cancerDataset, Datasets.CANCER);
+
+ List<Vector> records = Lists.newArrayList();
+ // Returns a mutable list of the data
+ List<String> cancerDataSetList = Files.readLines(cancerDataset, Charsets.UTF_8);
+ // Skip the header line, hence remove the first element in the list
+ cancerDataSetList.remove(0);
+ for (String line : cancerDataSetList) {
+ String[] tokens = CSVUtils.parseLine(line);
+ double[] values = new double[tokens.length];
+ for (int i = 0; i < tokens.length; ++i) {
+ values[i] = Double.parseDouble(tokens[i]);
+ }
+ records.add(new DenseVector(values));
+ }
+
+ int splitPoint = (int) (records.size() * 0.8);
+ List<Vector> trainingSet = records.subList(0, splitPoint);
+ List<Vector> testSet = records.subList(splitPoint, records.size());
+
+ // initialize neural network model
+ NeuralNetwork ann = new MultilayerPerceptron();
+ int featureDimension = records.get(0).size() - 1;
+ ann.addLayer(featureDimension, false, "Sigmoid");
+ ann.addLayer(featureDimension * 2, false, "Sigmoid");
+ ann.addLayer(1, true, "Sigmoid");
+ ann.setLearningRate(0.05).setMomentumWeight(0.5).setRegularizationWeight(0.001);
+
+ int iteration = 2000;
+ for (int i = 0; i < iteration; ++i) {
+ for (Vector trainingInstance : trainingSet) {
+ ann.trainOnline(trainingInstance);
+ }
+ }
+
+ int correctInstances = 0;
+ for (Vector testInstance : testSet) {
+ Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - 1));
+ double actual = res.get(0);
+ double expected = testInstance.get(testInstance.size() - 1);
+ if (Math.abs(actual - expected) <= 0.1) {
+ ++correctInstances;
+ }
+ }
+ double accuracy = (double) correctInstances / testSet.size() * 100;
+ assertTrue("The classifier is even worse than a random guesser!", accuracy > 50);
+ System.out.printf("Cancer DataSet. Classification precision: %d/%d = %f%%\n", correctInstances, testSet.size(), accuracy);
+ }
+
+ @Test
+ public void testWithIrisDataSet() throws IOException {
+
+ File irisDataset = getTestTempFile("iris.csv");
+ writeLines(irisDataset, Datasets.IRIS);
+
+ int numOfClasses = 3;
+ List<Vector> records = Lists.newArrayList();
+ // Returns a mutable list of the data
+ List<String> irisDataSetList = Files.readLines(irisDataset, Charsets.UTF_8);
+ // Skip the header line, hence remove the first element in the list
+ irisDataSetList.remove(0);
+
+ for (String line : irisDataSetList) {
+ String[] tokens = CSVUtils.parseLine(line);
+ // Last three dimensions represent the labels
+ double[] values = new double[tokens.length + numOfClasses - 1];
+ Arrays.fill(values, 0.0);
+ for (int i = 0; i < tokens.length - 1; ++i) {
+ values[i] = Double.parseDouble(tokens[i]);
+ }
+ // Add label values
+ String label = tokens[tokens.length - 1];
+ if (label.equalsIgnoreCase("setosa")) {
+ values[values.length - 3] = 1;
+ } else if (label.equalsIgnoreCase("versicolor")) {
+ values[values.length - 2] = 1;
+ } else { // label 'virginica'
+ values[values.length - 1] = 1;
+ }
+ records.add(new DenseVector(values));
+ }
+
+ Collections.shuffle(records);
+
+ int splitPoint = (int) (records.size() * 0.8);
+ List<Vector> trainingSet = records.subList(0, splitPoint);
+ List<Vector> testSet = records.subList(splitPoint, records.size());
+
+ // Initialize neural network model
+ NeuralNetwork ann = new MultilayerPerceptron();
+ int featureDimension = records.get(0).size() - numOfClasses;
+ ann.addLayer(featureDimension, false, "Sigmoid");
+ ann.addLayer(featureDimension * 2, false, "Sigmoid");
+ ann.addLayer(3, true, "Sigmoid"); // 3-class classification
+ ann.setLearningRate(0.05).setMomentumWeight(0.4).setRegularizationWeight(0.005);
+
+ int iteration = 2000;
+ for (int i = 0; i < iteration; ++i) {
+ for (Vector trainingInstance : trainingSet) {
+ ann.trainOnline(trainingInstance);
+ }
+ }
+
+ int correctInstances = 0;
+ for (Vector testInstance : testSet) {
+ Vector res = ann.getOutput(testInstance.viewPart(0, testInstance.size() - numOfClasses));
+ double[] actualLabels = new double[numOfClasses];
+ for (int i = 0; i < numOfClasses; ++i) {
+ actualLabels[i] = res.get(i);
+ }
+ double[] expectedLabels = new double[numOfClasses];
+ for (int i = 0; i < numOfClasses; ++i) {
+ expectedLabels[i] = testInstance.get(testInstance.size() - numOfClasses + i);
+ }
+
+ boolean allCorrect = true;
+ for (int i = 0; i < numOfClasses; ++i) {
+ if (Math.abs(expectedLabels[i] - actualLabels[i]) >= 0.1) {
+ allCorrect = false;
+ break;
+ }
+ }
+ if (allCorrect) {
+ ++correctInstances;
+ }
+ }
+
+ double accuracy = (double) correctInstances / testSet.size() * 100;
+ assertTrue("The model is even worse than a random guesser.", accuracy > 50);
+
+ System.out.printf("Iris DataSet. Classification precision: %d/%d = %f%%\n",
+ correctInstances, testSet.size(), accuracy);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java b/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java
new file mode 100644
index 0000000..b905509
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/mlp/TrainMultilayerPerceptronTest.java
@@ -0,0 +1,105 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.mlp;
+
+import java.io.File;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public class TrainMultilayerPerceptronTest extends MahoutTestCase {
+
+ @Test
+ public void testIrisDataset() throws Exception {
+ String modelFileName = "mlp.model";
+ File modelFile = getTestTempFile(modelFileName);
+
+ File irisDataset = getTestTempFile("iris.csv");
+ writeLines(irisDataset, Datasets.IRIS);
+
+ String[] args = {
+ "-i", irisDataset.getAbsolutePath(),
+ "-sh",
+ "-labels", "setosa", "versicolor", "virginica",
+ "-mo", modelFile.getAbsolutePath(),
+ "-u",
+ "-ls", "4", "8", "3"
+ };
+
+ TrainMultilayerPerceptron.main(args);
+
+ assertTrue(modelFile.exists());
+ }
+
+ @Test
+ public void initializeModelWithDifferentParameters() throws Exception {
+ String modelFileName = "mlp.model";
+ File modelFile1 = getTestTempFile(modelFileName);
+
+ File irisDataset = getTestTempFile("iris.csv");
+ writeLines(irisDataset, Datasets.IRIS);
+
+ String[] args1 = {
+ "-i", irisDataset.getAbsolutePath(),
+ "-sh",
+ "-labels", "setosa", "versicolor", "virginica",
+ "-mo", modelFile1.getAbsolutePath(),
+ "-u",
+ "-ls", "4", "8", "3",
+ "-l", "0.2", "-m", "0.35", "-r", "0.0001"
+ };
+
+ MultilayerPerceptron mlp1 = trainModel(args1, modelFile1);
+ assertEquals(0.2, mlp1.getLearningRate(), EPSILON);
+ assertEquals(0.35, mlp1.getMomentumWeight(), EPSILON);
+ assertEquals(0.0001, mlp1.getRegularizationWeight(), EPSILON);
+
+ assertEquals(4, mlp1.getLayerSize(0) - 1);
+ assertEquals(8, mlp1.getLayerSize(1) - 1);
+ assertEquals(3, mlp1.getLayerSize(2)); // Final layer has no bias neuron
+
+ // MLP with default learning rate, momemtum weight, and regularization weight
+ File modelFile2 = this.getTestTempFile(modelFileName);
+
+ String[] args2 = {
+ "-i", irisDataset.getAbsolutePath(),
+ "-sh",
+ "-labels", "setosa", "versicolor", "virginica",
+ "-mo", modelFile2.getAbsolutePath(),
+ "-ls", "4", "10", "18", "3"
+ };
+
+ MultilayerPerceptron mlp2 = trainModel(args2, modelFile2);
+ assertEquals(0.5, mlp2.getLearningRate(), EPSILON);
+ assertEquals(0.1, mlp2.getMomentumWeight(), EPSILON);
+ assertEquals(0, mlp2.getRegularizationWeight(), EPSILON);
+
+ assertEquals(4, mlp2.getLayerSize(0) - 1);
+ assertEquals(10, mlp2.getLayerSize(1) - 1);
+ assertEquals(18, mlp2.getLayerSize(2) - 1);
+ assertEquals(3, mlp2.getLayerSize(3)); // Final layer has no bias neuron
+
+ }
+
+ private MultilayerPerceptron trainModel(String[] args, File modelFile) throws Exception {
+ TrainMultilayerPerceptron.main(args);
+ return new MultilayerPerceptron(modelFile.getAbsolutePath());
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
new file mode 100644
index 0000000..f658738
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.math.DenseVector;
+import org.junit.Before;
+import org.junit.Test;
+
+public final class ComplementaryNaiveBayesClassifierTest extends NaiveBayesTestBase {
+
+ private ComplementaryNaiveBayesClassifier classifier;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ NaiveBayesModel model = createComplementaryNaiveBayesModel();
+ classifier = new ComplementaryNaiveBayesClassifier(model);
+ }
+
+ @Test
+ public void testNaiveBayes() throws Exception {
+ assertEquals(4, classifier.numCategories());
+ assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+ assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+ assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+ assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
new file mode 100644
index 0000000..3b83492
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.junit.Test;
+
+public class NaiveBayesModelTest extends NaiveBayesTestBase {
+
+ @Test
+ public void testRandomModelGeneration() {
+ // make sure we generate a valid random model
+ NaiveBayesModel standardModel = getStandardModel();
+ // check whether the model is valid
+ standardModel.validate();
+
+ // same for Complementary
+ NaiveBayesModel complementaryModel = getComplementaryModel();
+ complementaryModel.validate();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
new file mode 100644
index 0000000..974b90c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import java.io.File;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.MathHelper;
+import org.junit.Before;
+import org.junit.Test;
+
+public class NaiveBayesTest extends MahoutTestCase {
+
+ private Configuration conf;
+ private File inputFile;
+ private File outputDir;
+ private File tempDir;
+
+ static final Text LABEL_STOLEN = new Text("/stolen/");
+ static final Text LABEL_NOT_STOLEN = new Text("/not_stolen/");
+
+ static final Vector.Element COLOR_RED = MathHelper.elem(0, 1);
+ static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
+ static final Vector.Element TYPE_SPORTS = MathHelper.elem(2, 1);
+ static final Vector.Element TYPE_SUV = MathHelper.elem(3, 1);
+ static final Vector.Element ORIGIN_DOMESTIC = MathHelper.elem(4, 1);
+ static final Vector.Element ORIGIN_IMPORTED = MathHelper.elem(5, 1);
+
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+
+ conf = getConfiguration();
+
+ inputFile = getTestTempFile("trainingInstances.seq");
+ outputDir = getTestTempDir("output");
+ outputDir.delete();
+ tempDir = getTestTempDir("tmp");
+
+ SequenceFile.Writer writer = new SequenceFile.Writer(FileSystem.get(conf), conf,
+ new Path(inputFile.getAbsolutePath()), Text.class, VectorWritable.class);
+
+ try {
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
+ writer.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
+ writer.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ @Test
+ public void toyData() throws Exception {
+ TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
+ trainNaiveBayes.setConf(conf);
+ trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "-el", "--tempDir", tempDir.getAbsolutePath() });
+
+ NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);
+
+ AbstractVectorClassifier classifier = new StandardNaiveBayesClassifier(naiveBayesModel);
+
+ assertEquals(2, classifier.numCategories());
+
+ Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
+
+ // should be classified as not stolen
+ assertTrue(prediction.get(0) < prediction.get(1));
+ }
+
+ @Test
+ public void toyDataComplementary() throws Exception {
+ TrainNaiveBayesJob trainNaiveBayes = new TrainNaiveBayesJob();
+ trainNaiveBayes.setConf(conf);
+ trainNaiveBayes.run(new String[] { "--input", inputFile.getAbsolutePath(), "--output", outputDir.getAbsolutePath(),
+ "-el", "--trainComplementary",
+ "--tempDir", tempDir.getAbsolutePath() });
+
+ NaiveBayesModel naiveBayesModel = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);
+
+ AbstractVectorClassifier classifier = new ComplementaryNaiveBayesClassifier(naiveBayesModel);
+
+ assertEquals(2, classifier.numCategories());
+
+ Vector prediction = classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
+
+ // should be classified as not stolen
+ assertTrue(prediction.get(0) < prediction.get(1));
+ }
+
+ static VectorWritable trainingInstance(Vector.Element... elems) {
+ DenseVector trainingInstance = new DenseVector(6);
+ for (Vector.Element elem : elems) {
+ trainingInstance.set(elem.index(), elem.get());
+ }
+ return new VectorWritable(trainingInstance);
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
new file mode 100644
index 0000000..a943b7b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java
@@ -0,0 +1,135 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+
+public abstract class NaiveBayesTestBase extends MahoutTestCase {
+
+ private NaiveBayesModel standardModel;
+ private NaiveBayesModel complementaryModel;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ standardModel = createStandardNaiveBayesModel();
+ standardModel.validate();
+ complementaryModel = createComplementaryNaiveBayesModel();
+ complementaryModel.validate();
+ }
+
+ protected NaiveBayesModel getStandardModel() {
+ return standardModel;
+ }
+ protected NaiveBayesModel getComplementaryModel() {
+ return complementaryModel;
+ }
+
+ protected static double complementaryNaiveBayesThetaWeight(int label,
+ Matrix weightMatrix,
+ Vector labelSum,
+ Vector featureSum) {
+ double weight = 0.0;
+ double alpha = 1.0;
+ for (int i = 0; i < featureSum.size(); i++) {
+ double score = weightMatrix.get(i, label);
+ double lSum = labelSum.get(label);
+ double fSum = featureSum.get(i);
+ double totalSum = featureSum.zSum();
+ double numerator = fSum - score + alpha;
+ double denominator = totalSum - lSum + featureSum.size();
+ weight += Math.abs(Math.log(numerator / denominator));
+ }
+ return weight;
+ }
+
+ protected static double naiveBayesThetaWeight(int label,
+ Matrix weightMatrix,
+ Vector labelSum,
+ Vector featureSum) {
+ double weight = 0.0;
+ double alpha = 1.0;
+ for (int feature = 0; feature < featureSum.size(); feature++) {
+ double score = weightMatrix.get(feature, label);
+ double lSum = labelSum.get(label);
+ double numerator = score + alpha;
+ double denominator = lSum + featureSum.size();
+ weight += Math.abs(Math.log(numerator / denominator));
+ }
+ return weight;
+ }
+
+ protected static NaiveBayesModel createStandardNaiveBayesModel() {
+ double[][] matrix = {
+ { 0.7, 0.1, 0.1, 0.3 },
+ { 0.4, 0.4, 0.1, 0.1 },
+ { 0.1, 0.0, 0.8, 0.1 },
+ { 0.1, 0.1, 0.1, 0.7 } };
+
+ double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 };
+ double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 };
+
+ DenseMatrix weightMatrix = new DenseMatrix(matrix);
+ DenseVector labelSum = new DenseVector(labelSumArray);
+ DenseVector featureSum = new DenseVector(featureSumArray);
+
+ // now generate the model
+ return new NaiveBayesModel(weightMatrix, featureSum, labelSum, null, 1.0f, false);
+ }
+
+ protected static NaiveBayesModel createComplementaryNaiveBayesModel() {
+ double[][] matrix = {
+ { 0.7, 0.1, 0.1, 0.3 },
+ { 0.4, 0.4, 0.1, 0.1 },
+ { 0.1, 0.0, 0.8, 0.1 },
+ { 0.1, 0.1, 0.1, 0.7 } };
+
+ double[] labelSumArray = { 1.2, 1.0, 1.0, 1.0 };
+ double[] featureSumArray = { 1.3, 0.6, 1.1, 1.2 };
+
+ DenseMatrix weightMatrix = new DenseMatrix(matrix);
+ DenseVector labelSum = new DenseVector(labelSumArray);
+ DenseVector featureSum = new DenseVector(featureSumArray);
+
+ double[] thetaNormalizerSum = {
+ complementaryNaiveBayesThetaWeight(0, weightMatrix, labelSum, featureSum),
+ complementaryNaiveBayesThetaWeight(1, weightMatrix, labelSum, featureSum),
+ complementaryNaiveBayesThetaWeight(2, weightMatrix, labelSum, featureSum),
+ complementaryNaiveBayesThetaWeight(3, weightMatrix, labelSum, featureSum) };
+
+ // now generate the model
+ return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f, true);
+ }
+
+ protected static int maxIndex(Vector instance) {
+ int maxIndex = -1;
+ double maxScore = Integer.MIN_VALUE;
+ for (Element label : instance.all()) {
+ if (label.get() >= maxScore) {
+ maxIndex = label.index();
+ maxScore = label.get();
+ }
+ }
+ return maxIndex;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
new file mode 100644
index 0000000..a432ac9
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes;
+
+import org.apache.mahout.math.DenseVector;
+import org.junit.Before;
+import org.junit.Test;
+
+
+public final class StandardNaiveBayesClassifierTest extends NaiveBayesTestBase {
+
+ private StandardNaiveBayesClassifier classifier;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ NaiveBayesModel model = createStandardNaiveBayesModel();
+ classifier = new StandardNaiveBayesClassifier(model);
+ }
+
+ @Test
+ public void testNaiveBayes() throws Exception {
+ assertEquals(4, classifier.numCategories());
+ assertEquals(0, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 1.0, 0.0, 0.0, 0.0 }))));
+ assertEquals(1, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 1.0, 0.0, 0.0 }))));
+ assertEquals(2, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 1.0, 0.0 }))));
+ assertEquals(3, maxIndex(classifier.classifyFull(new DenseVector(new double[] { 0.0, 0.0, 0.0, 1.0 }))));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
new file mode 100644
index 0000000..a9541c9
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Counter;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import org.easymock.EasyMock;
+import org.junit.Before;
+import org.junit.Test;
+
+public class IndexInstancesMapperTest extends MahoutTestCase {
+
+ private Mapper.Context ctx;
+ private OpenObjectIntHashMap<String> labelIndex;
+ private VectorWritable instance;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+
+ ctx = EasyMock.createMock(Mapper.Context.class);
+ instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 }));
+
+ labelIndex = new OpenObjectIntHashMap<String>();
+ labelIndex.put("bird", 0);
+ labelIndex.put("cat", 1);
+ }
+
+
+ @Test
+ public void index() throws Exception {
+
+ ctx.write(new IntWritable(0), instance);
+
+ EasyMock.replay(ctx);
+
+ IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+ setField(indexInstances, "labelIndex", labelIndex);
+
+ indexInstances.map(new Text("/bird/"), instance, ctx);
+
+ EasyMock.verify(ctx);
+ }
+
+ @Test
+ public void skip() throws Exception {
+
+ Counter skippedInstances = EasyMock.createMock(Counter.class);
+
+ EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skippedInstances);
+ skippedInstances.increment(1);
+
+ EasyMock.replay(ctx, skippedInstances);
+
+ IndexInstancesMapper indexInstances = new IndexInstancesMapper();
+ setField(indexInstances, "labelIndex", labelIndex);
+
+ indexInstances.map(new Text("/fish/"), instance, ctx);
+
+ EasyMock.verify(ctx, skippedInstances);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
new file mode 100644
index 0000000..746ae0d
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java
@@ -0,0 +1,61 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+public class ThetaMapperTest extends MahoutTestCase {
+
+ @Test
+ public void standard() throws Exception {
+
+ Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class);
+ ComplementaryThetaTrainer trainer = EasyMock.createMock(ComplementaryThetaTrainer.class);
+
+ Vector instance1 = new DenseVector(new double[] { 1, 2, 3 });
+ Vector instance2 = new DenseVector(new double[] { 4, 5, 6 });
+
+ Vector perLabelThetaNormalizer = new DenseVector(new double[] { 7, 8 });
+
+ ThetaMapper thetaMapper = new ThetaMapper();
+ setField(thetaMapper, "trainer", trainer);
+
+ trainer.train(0, instance1);
+ trainer.train(1, instance2);
+ EasyMock.expect(trainer.retrievePerLabelThetaNormalizer()).andReturn(perLabelThetaNormalizer);
+ ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer));
+
+ EasyMock.replay(ctx, trainer);
+
+ thetaMapper.map(new IntWritable(0), new VectorWritable(instance1), ctx);
+ thetaMapper.map(new IntWritable(1), new VectorWritable(instance2), ctx);
+ thetaMapper.cleanup(ctx);
+
+ EasyMock.verify(ctx, trainer);
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java
new file mode 100644
index 0000000..af0b464
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes.training;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+public class WeightsMapperTest extends MahoutTestCase {
+
+ @Test
+ public void scores() throws Exception {
+
+ Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class);
+ Vector instance1 = new DenseVector(new double[] { 1, 0, 0.5, 0.5, 0 });
+ Vector instance2 = new DenseVector(new double[] { 0, 0.5, 0, 0, 0 });
+ Vector instance3 = new DenseVector(new double[] { 1, 0.5, 1, 1.5, 1 });
+
+ Vector weightsPerLabel = new DenseVector(new double[] { 0, 0 });
+
+ ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE),
+ new VectorWritable(new DenseVector(new double[] { 2, 1, 1.5, 2, 1 })));
+ ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL),
+ new VectorWritable(new DenseVector(new double[] { 2.5, 5 })));
+
+ EasyMock.replay(ctx);
+
+ WeightsMapper weights = new WeightsMapper();
+ setField(weights, "weightsPerLabel", weightsPerLabel);
+
+ weights.map(new IntWritable(0), new VectorWritable(instance1), ctx);
+ weights.map(new IntWritable(0), new VectorWritable(instance2), ctx);
+ weights.map(new IntWritable(1), new VectorWritable(instance3), ctx);
+
+ weights.cleanup(ctx);
+
+ EasyMock.verify(ctx);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java
new file mode 100644
index 0000000..ade25b8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java
@@ -0,0 +1,164 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.math.Matrix;
+import org.junit.Test;
+
+public class HMMAlgorithmsTest extends HMMTestBase {
+
+ /**
+ * Test the forward algorithm by comparing the alpha values with the values
+ * obtained from HMM R model. We test the test observation sequence "O1" "O0"
+ * "O2" "O2" "O0" "O0" "O1" by comparing the generated alpha values to the
+ * R-generated "reference".
+ */
+ @Test
+ public void testForwardAlgorithm() {
+ // intialize the expected alpha values
+ double[][] alphaExpectedA = {
+ {0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04,
+ 4.614927e-05},
+ {0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04,
+ 1.721505e-05},
+ {0.32, 0.0262, 0.002542, 0.00038026, 0.0001360234, 3.002345e-05,
+ 9.659608e-05},
+ {0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00,
+ 2.428986e-05},};
+ // fetch the alpha matrix using the forward algorithm
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false);
+ // first do some basic checking
+ assertNotNull(alpha);
+ assertEquals(4, alpha.numCols());
+ assertEquals(7, alpha.numRows());
+ // now compare the resulting matrices
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(alphaExpectedA[i][j], alpha.get(j, i), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void testLogScaledForwardAlgorithm() {
+ // intialize the expected alpha values
+ double[][] alphaExpectedA = {
+ {0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04,
+ 4.614927e-05},
+ {0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04,
+ 1.721505e-05},
+ {0.32, 0.0262, 0.002542, 0.00038026, 0.0001360234, 3.002345e-05,
+ 9.659608e-05},
+ {0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00,
+ 2.428986e-05},};
+ // fetch the alpha matrix using the forward algorithm
+ Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true);
+ // first do some basic checking
+ assertNotNull(alpha);
+ assertEquals(4, alpha.numCols());
+ assertEquals(7, alpha.numRows());
+ // now compare the resulting matrices
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(Math.log(alphaExpectedA[i][j]), alpha.get(j, i), EPSILON);
+ }
+ }
+ }
+
+ /**
+ * Test the backward algorithm by comparing the beta values with the values
+ * obtained from HMM R model. We test the following observation sequence "O1"
+ * "O0" "O2" "O2" "O0" "O0" "O1" by comparing the generated beta values to the
+ * R-generated "reference".
+ */
+ @Test
+ public void testBackwardAlgorithm() {
+ // intialize the expected beta values
+ double[][] betaExpectedA = {
+ {0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1},
+ {0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1},
+ {0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1},
+ {0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}};
+ // fetch the beta matrix using the backward algorithm
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false);
+ // first do some basic checking
+ assertNotNull(beta);
+ assertEquals(4, beta.numCols());
+ assertEquals(7, beta.numRows());
+ // now compare the resulting matrices
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(betaExpectedA[i][j], beta.get(j, i), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void testLogScaledBackwardAlgorithm() {
+ // intialize the expected beta values
+ double[][] betaExpectedA = {
+ {0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1},
+ {0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1},
+ {0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1},
+ {0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}};
+ // fetch the beta matrix using the backward algorithm
+ Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true);
+ // first do some basic checking
+ assertNotNull(beta);
+ assertEquals(4, beta.numCols());
+ assertEquals(7, beta.numRows());
+ // now compare the resulting matrices
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 7; ++j) {
+ assertEquals(Math.log(betaExpectedA[i][j]), beta.get(j, i), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void testViterbiAlgorithm() {
+ // initialize the expected hidden sequence
+ int[] expected = {2, 0, 3, 3, 0, 0, 2};
+ // fetch the viterbi generated sequence
+ int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), false);
+ // first make sure we return the correct size
+ assertNotNull(computed);
+ assertEquals(computed.length, getSequence().length);
+ // now check the contents
+ for (int i = 0; i < getSequence().length; ++i) {
+ assertEquals(expected[i], computed[i]);
+ }
+ }
+
+ @Test
+ public void testLogScaledViterbiAlgorithm() {
+ // initialize the expected hidden sequence
+ int[] expected = {2, 0, 3, 3, 0, 0, 2};
+ // fetch the viterbi generated sequence
+ int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), true);
+ // first make sure we return the correct size
+ assertNotNull(computed);
+ assertEquals(computed.length, getSequence().length);
+ // now check the contents
+ for (int i = 0; i < getSequence().length; ++i) {
+ assertEquals(expected[i], computed[i]);
+ }
+
+ }
+
+}
[39/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarity.java
new file mode 100644
index 0000000..712b96a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemSimilarity.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.file;
+
+import java.io.File;
+import java.util.Collection;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * An {@link ItemSimilarity} backed by a comma-delimited file. This class typically expects a file where each line
+ * contains an item ID, followed by another item ID, followed by a similarity value, separated by commas. You may also
+ * use tabs.
+ * </p>
+ *
+ * <p>
+ * The similarity value is assumed to be parseable as a {@code double} having a value between -1 and 1. The
+ * item IDs are parsed as {@code long}s. Similarities are symmetric so for a pair of items you do not have to
+ * include 2 lines in the file.
+ * </p>
+ *
+ * <p>
+ * This class will reload data from the data file when {@link #refresh(Collection)} is called, unless the file
+ * has been reloaded very recently already.
+ * </p>
+ *
+ * <p>
+ * This class is not intended for use with very large amounts of data. For that, a JDBC-backed {@link ItemSimilarity}
+ * and a database are more appropriate.
+ * </p>
+ */
+public class FileItemSimilarity implements ItemSimilarity {
+
+ public static final long DEFAULT_MIN_RELOAD_INTERVAL_MS = 60 * 1000L; // 1 minute?
+
+ private ItemSimilarity delegate;
+ private final ReentrantLock reloadLock;
+ private final File dataFile;
+ private long lastModified;
+ private final long minReloadIntervalMS;
+
+ private static final Logger log = LoggerFactory.getLogger(FileItemSimilarity.class);
+
+ /**
+ * @param dataFile
+ * file containing the similarity data
+ */
+ public FileItemSimilarity(File dataFile) {
+ this(dataFile, DEFAULT_MIN_RELOAD_INTERVAL_MS);
+ }
+
+ /**
+ * @param minReloadIntervalMS
+ * the minimum interval in milliseconds after which a full reload of the original datafile is done
+ * when refresh() is called
+ * @see #FileItemSimilarity(File)
+ */
+ public FileItemSimilarity(File dataFile, long minReloadIntervalMS) {
+ Preconditions.checkArgument(dataFile != null, "dataFile is null");
+ Preconditions.checkArgument(dataFile.exists() && !dataFile.isDirectory(),
+ "dataFile is missing or a directory: %s", dataFile);
+
+ log.info("Creating FileItemSimilarity for file {}", dataFile);
+
+ this.dataFile = dataFile.getAbsoluteFile();
+ this.lastModified = dataFile.lastModified();
+ this.minReloadIntervalMS = minReloadIntervalMS;
+ this.reloadLock = new ReentrantLock();
+
+ reload();
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+ return delegate.itemSimilarities(itemID1, itemID2s);
+ }
+
+ @Override
+ public long[] allSimilarItemIDs(long itemID) throws TasteException {
+ return delegate.allSimilarItemIDs(itemID);
+ }
+
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+ return delegate.itemSimilarity(itemID1, itemID2);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ if (dataFile.lastModified() > lastModified + minReloadIntervalMS) {
+ log.debug("File has changed; reloading...");
+ reload();
+ }
+ }
+
+ protected void reload() {
+ if (reloadLock.tryLock()) {
+ try {
+ long newLastModified = dataFile.lastModified();
+ delegate = new GenericItemSimilarity(new FileItemItemSimilarityIterable(dataFile));
+ lastModified = newLastModified;
+ } finally {
+ reloadLock.unlock();
+ }
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "FileItemSimilarity[dataFile:" + dataFile + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/FileSimilarItemsWriter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/FileSimilarItemsWriter.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/FileSimilarItemsWriter.java
new file mode 100644
index 0000000..ca0e0b2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/FileSimilarItemsWriter.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.precompute;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItem;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItems;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItemsWriter;
+
+/**
+ * Persist the precomputed item similarities to a file that can later be used
+ * by a {@link org.apache.mahout.cf.taste.impl.similarity.file.FileItemSimilarity}
+ */
+public class FileSimilarItemsWriter implements SimilarItemsWriter {
+
+ private final File file;
+ private BufferedWriter writer;
+
+ public FileSimilarItemsWriter(File file) {
+ this.file = file;
+ }
+
+ @Override
+ public void open() throws IOException {
+ writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), Charsets.UTF_8));
+ }
+
+ @Override
+ public void add(SimilarItems similarItems) throws IOException {
+ String itemID = String.valueOf(similarItems.getItemID());
+ for (SimilarItem similarItem : similarItems.getSimilarItems()) {
+ writer.write(itemID);
+ writer.write(',');
+ writer.write(String.valueOf(similarItem.getItemID()));
+ writer.write(',');
+ writer.write(String.valueOf(similarItem.getSimilarity()));
+ writer.newLine();
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ Closeables.close(writer, false);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilarities.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilarities.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilarities.java
new file mode 100644
index 0000000..09ca57a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/precompute/MultithreadedBatchItemSimilarities.java
@@ -0,0 +1,230 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.precompute;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.similarity.precompute.BatchItemSimilarities;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItems;
+import org.apache.mahout.cf.taste.similarity.precompute.SimilarItemsWriter;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Precompute item similarities in parallel on a single machine. The recommender given to this class must use a
+ * DataModel that holds the interactions in memory (such as
+ * {@link org.apache.mahout.cf.taste.impl.model.GenericDataModel} or
+ * {@link org.apache.mahout.cf.taste.impl.model.file.FileDataModel}) as fast random access to the data is required
+ */
+public class MultithreadedBatchItemSimilarities extends BatchItemSimilarities {
+
+ private int batchSize;
+
+ private static final int DEFAULT_BATCH_SIZE = 100;
+
+ private static final Logger log = LoggerFactory.getLogger(MultithreadedBatchItemSimilarities.class);
+
+ /**
+ * @param recommender recommender to use
+ * @param similarItemsPerItem number of similar items to compute per item
+ */
+ public MultithreadedBatchItemSimilarities(ItemBasedRecommender recommender, int similarItemsPerItem) {
+ this(recommender, similarItemsPerItem, DEFAULT_BATCH_SIZE);
+ }
+
+ /**
+ * @param recommender recommender to use
+ * @param similarItemsPerItem number of similar items to compute per item
+ * @param batchSize size of item batches sent to worker threads
+ */
+ public MultithreadedBatchItemSimilarities(ItemBasedRecommender recommender, int similarItemsPerItem, int batchSize) {
+ super(recommender, similarItemsPerItem);
+ this.batchSize = batchSize;
+ }
+
+ @Override
+ public int computeItemSimilarities(int degreeOfParallelism, int maxDurationInHours, SimilarItemsWriter writer)
+ throws IOException {
+
+ ExecutorService executorService = Executors.newFixedThreadPool(degreeOfParallelism + 1);
+
+ Output output = null;
+ try {
+ writer.open();
+
+ DataModel dataModel = getRecommender().getDataModel();
+
+ BlockingQueue<long[]> itemsIDsInBatches = queueItemIDsInBatches(dataModel, batchSize, degreeOfParallelism);
+ BlockingQueue<List<SimilarItems>> results = new LinkedBlockingQueue<>();
+
+ AtomicInteger numActiveWorkers = new AtomicInteger(degreeOfParallelism);
+ for (int n = 0; n < degreeOfParallelism; n++) {
+ executorService.execute(new SimilarItemsWorker(n, itemsIDsInBatches, results, numActiveWorkers));
+ }
+
+ output = new Output(results, writer, numActiveWorkers);
+ executorService.execute(output);
+
+ } catch (Exception e) {
+ throw new IOException(e);
+ } finally {
+ executorService.shutdown();
+ try {
+ boolean succeeded = executorService.awaitTermination(maxDurationInHours, TimeUnit.HOURS);
+ if (!succeeded) {
+ throw new RuntimeException("Unable to complete the computation in " + maxDurationInHours + " hours!");
+ }
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ Closeables.close(writer, false);
+ }
+
+ return output.getNumSimilaritiesProcessed();
+ }
+
+ private static BlockingQueue<long[]> queueItemIDsInBatches(DataModel dataModel, int batchSize,
+ int degreeOfParallelism) throws TasteException {
+
+ LongPrimitiveIterator itemIDs = dataModel.getItemIDs();
+ int numItems = dataModel.getNumItems();
+
+ BlockingQueue<long[]> itemIDBatches = new LinkedBlockingQueue<>((numItems / batchSize) + 1);
+
+ long[] batch = new long[batchSize];
+ int pos = 0;
+ while (itemIDs.hasNext()) {
+ batch[pos] = itemIDs.nextLong();
+ pos++;
+ if (pos == batchSize) {
+ itemIDBatches.add(batch.clone());
+ pos = 0;
+ }
+ }
+
+ if (pos > 0) {
+ long[] lastBatch = new long[pos];
+ System.arraycopy(batch, 0, lastBatch, 0, pos);
+ itemIDBatches.add(lastBatch);
+ }
+
+ if (itemIDBatches.size() < degreeOfParallelism) {
+ throw new IllegalStateException("Degree of parallelism [" + degreeOfParallelism + "] " +
+ " is larger than number of batches [" + itemIDBatches.size() +"].");
+ }
+
+ log.info("Queued {} items in {} batches", numItems, itemIDBatches.size());
+
+ return itemIDBatches;
+ }
+
+
+ private static class Output implements Runnable {
+
+ private final BlockingQueue<List<SimilarItems>> results;
+ private final SimilarItemsWriter writer;
+ private final AtomicInteger numActiveWorkers;
+ private int numSimilaritiesProcessed = 0;
+
+ Output(BlockingQueue<List<SimilarItems>> results, SimilarItemsWriter writer, AtomicInteger numActiveWorkers) {
+ this.results = results;
+ this.writer = writer;
+ this.numActiveWorkers = numActiveWorkers;
+ }
+
+ private int getNumSimilaritiesProcessed() {
+ return numSimilaritiesProcessed;
+ }
+
+ @Override
+ public void run() {
+ while (numActiveWorkers.get() != 0) {
+ try {
+ List<SimilarItems> similarItemsOfABatch = results.poll(10, TimeUnit.MILLISECONDS);
+ if (similarItemsOfABatch != null) {
+ for (SimilarItems similarItems : similarItemsOfABatch) {
+ writer.add(similarItems);
+ numSimilaritiesProcessed += similarItems.numSimilarItems();
+ }
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+ }
+
+ private class SimilarItemsWorker implements Runnable {
+
+ private final int number;
+ private final BlockingQueue<long[]> itemIDBatches;
+ private final BlockingQueue<List<SimilarItems>> results;
+ private final AtomicInteger numActiveWorkers;
+
+ SimilarItemsWorker(int number, BlockingQueue<long[]> itemIDBatches, BlockingQueue<List<SimilarItems>> results,
+ AtomicInteger numActiveWorkers) {
+ this.number = number;
+ this.itemIDBatches = itemIDBatches;
+ this.results = results;
+ this.numActiveWorkers = numActiveWorkers;
+ }
+
+ @Override
+ public void run() {
+
+ int numBatchesProcessed = 0;
+ while (!itemIDBatches.isEmpty()) {
+ try {
+ long[] itemIDBatch = itemIDBatches.take();
+
+ List<SimilarItems> similarItemsOfBatch = Lists.newArrayListWithCapacity(itemIDBatch.length);
+ for (long itemID : itemIDBatch) {
+ List<RecommendedItem> similarItems = getRecommender().mostSimilarItems(itemID, getSimilarItemsPerItem());
+
+ similarItemsOfBatch.add(new SimilarItems(itemID, similarItems));
+ }
+
+ results.offer(similarItemsOfBatch);
+
+ if (++numBatchesProcessed % 5 == 0) {
+ log.info("worker {} processed {} batches", number, numBatchesProcessed);
+ }
+
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ log.info("worker {} processed {} batches. done.", number, numBatchesProcessed);
+ numActiveWorkers.decrementAndGet();
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/model/DataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/model/DataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/model/DataModel.java
new file mode 100644
index 0000000..022d02d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/model/DataModel.java
@@ -0,0 +1,199 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.model;
+
+import java.io.Serializable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+
+/**
+ * <p>
+ * Implementations represent a repository of information about users and their associated {@link Preference}s
+ * for items.
+ * </p>
+ */
+public interface DataModel extends Refreshable, Serializable {
+
+ /**
+ * @return all user IDs in the model, in order
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ LongPrimitiveIterator getUserIDs() throws TasteException;
+
+ /**
+ * @param userID
+ * ID of user to get prefs for
+ * @return user's preferences, ordered by item ID
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException
+ * if the user does not exist
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ PreferenceArray getPreferencesFromUser(long userID) throws TasteException;
+
+ /**
+ * @param userID
+ * ID of user to get prefs for
+ * @return IDs of items user expresses a preference for
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException
+ * if the user does not exist
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ FastIDSet getItemIDsFromUser(long userID) throws TasteException;
+
+ /**
+ * @return a {@link LongPrimitiveIterator} of all item IDs in the model, in order
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ LongPrimitiveIterator getItemIDs() throws TasteException;
+
+ /**
+ * @param itemID
+ * item ID
+ * @return all existing {@link Preference}s expressed for that item, ordered by user ID, as an array
+ * @throws org.apache.mahout.cf.taste.common.NoSuchItemException
+ * if the item does not exist
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ PreferenceArray getPreferencesForItem(long itemID) throws TasteException;
+
+ /**
+ * Retrieves the preference value for a single user and item.
+ *
+ * @param userID
+ * user ID to get pref value from
+ * @param itemID
+ * item ID to get pref value for
+ * @return preference value from the given user for the given item or null if none exists
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException
+ * if the user does not exist
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ Float getPreferenceValue(long userID, long itemID) throws TasteException;
+
+ /**
+ * Retrieves the time at which a preference value from a user and item was set, if known.
+ * Time is expressed in the usual way, as a number of milliseconds since the epoch.
+ *
+ * @param userID user ID for preference in question
+ * @param itemID item ID for preference in question
+ * @return time at which preference was set or null if no preference exists or its time is not known
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException if the user does not exist
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ Long getPreferenceTime(long userID, long itemID) throws TasteException;
+
+ /**
+ * @return total number of items known to the model. This is generally the union of all items preferred by
+ * at least one user but could include more.
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ int getNumItems() throws TasteException;
+
+ /**
+ * @return total number of users known to the model.
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ int getNumUsers() throws TasteException;
+
+ /**
+ * @param itemID item ID to check for
+ * @return the number of users who have expressed a preference for the item
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ int getNumUsersWithPreferenceFor(long itemID) throws TasteException;
+
+ /**
+ * @param itemID1 first item ID to check for
+ * @param itemID2 second item ID to check for
+ * @return the number of users who have expressed a preference for the items
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException;
+
+ /**
+ * <p>
+ * Sets a particular preference (item plus rating) for a user.
+ * </p>
+ *
+ * @param userID
+ * user to set preference for
+ * @param itemID
+ * item to set preference for
+ * @param value
+ * preference value
+ * @throws org.apache.mahout.cf.taste.common.NoSuchItemException
+ * if the item does not exist
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException
+ * if the user does not exist
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ void setPreference(long userID, long itemID, float value) throws TasteException;
+
+ /**
+ * <p>
+ * Removes a particular preference for a user.
+ * </p>
+ *
+ * @param userID
+ * user from which to remove preference
+ * @param itemID
+ * item to remove preference for
+ * @throws org.apache.mahout.cf.taste.common.NoSuchItemException
+ * if the item does not exist
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException
+ * if the user does not exist
+ * @throws TasteException
+ * if an error occurs while accessing the data
+ */
+ void removePreference(long userID, long itemID) throws TasteException;
+
+ /**
+ * @return true if this implementation actually stores and returns distinct preference values;
+ * that is, if it is not a 'boolean' DataModel
+ */
+ boolean hasPreferenceValues();
+
+ /**
+ * @return the maximum preference value that is possible in the current problem domain being evaluated. For
+ * example, if the domain is movie ratings on a scale of 1 to 5, this should be 5. While a
+ * {@link org.apache.mahout.cf.taste.recommender.Recommender} may estimate a preference value above 5.0, it
+ * isn't "fair" to consider that the system is actually suggesting an impossible rating of, say, 5.4 stars.
+ * In practice the application would cap this estimate to 5.0. Since evaluators evaluate
+ * the difference between estimated and actual value, this at least prevents this effect from unfairly
+ * penalizing a {@link org.apache.mahout.cf.taste.recommender.Recommender}
+ */
+ float getMaxPreference();
+
+ /**
+ * @see #getMaxPreference()
+ */
+ float getMinPreference();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/model/IDMigrator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/model/IDMigrator.java b/mr/src/main/java/org/apache/mahout/cf/taste/model/IDMigrator.java
new file mode 100644
index 0000000..cc477fe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/model/IDMigrator.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.model;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Mahout 0.2 changed the framework to operate only in terms of numeric (long) ID values for users and items.
+ * This is, obviously, not compatible with applications that used other key types -- most commonly
+ * {@link String}. Implementation of this class provide support for mapping String to longs and vice versa in
+ * order to provide a smoother migration path to applications that must still use strings as IDs.
+ * </p>
+ *
+ * <p>
+ * The mapping from strings to 64-bit numeric values is fixed here, to provide a standard implementation that
+ * is 'portable' or reproducible outside the framework easily. See {@link #toLongID(String)}.
+ * </p>
+ *
+ * <p>
+ * Because this mapping is deterministically computable, it does not need to be stored. Indeed, subclasses'
+ * job is to store the reverse mapping. There are an infinite number of strings but only a fixed number of
+ * longs, so, it is possible for two strings to map to the same value. Subclasses do not treat this as an
+ * error but rather retain only the most recent mapping, overwriting a previous mapping. The probability of
+ * collision in a 64-bit space is quite small, but not zero. However, in the context of a collaborative
+ * filtering problem, the consequence of a collision is small, at worst -- perhaps one user receives another
+ * recommendations.
+ * </p>
+ *
+ * @since 0.2
+ */
+public interface IDMigrator extends Refreshable {
+
+ /**
+ * @return the top 8 bytes of the MD5 hash of the bytes of the given {@link String}'s UTF-8 encoding as a
+ * long.
+ */
+ long toLongID(String stringID);
+
+ /**
+ * @return the string ID most recently associated with the given long ID, or null if doesn't exist
+ * @throws TasteException
+ * if an error occurs while retrieving the mapping
+ */
+ String toStringID(long longID) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/model/JDBCDataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/model/JDBCDataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/model/JDBCDataModel.java
new file mode 100644
index 0000000..e91ed48
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/model/JDBCDataModel.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.model;
+
+import javax.sql.DataSource;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+
+public interface JDBCDataModel extends DataModel {
+
+ /**
+ * @return {@link DataSource} underlying this model
+ */
+ DataSource getDataSource();
+
+ /**
+ * Hmm, should this exist elsewhere? seems like most relevant for a DB implementation, which is not in
+ * memory, which might want to export to memory.
+ *
+ * @return all user preference data
+ */
+ FastByIDMap<PreferenceArray> exportWithPrefs() throws TasteException;
+
+ FastByIDMap<FastIDSet> exportWithIDsOnly() throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/model/Preference.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/model/Preference.java b/mr/src/main/java/org/apache/mahout/cf/taste/model/Preference.java
new file mode 100644
index 0000000..fe0150a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/model/Preference.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.model;
+
+/**
+ * <p>
+ * A {@link Preference} encapsulates an item and a preference value, which indicates the strength of the
+ * preference for it. {@link Preference}s are associated to users.
+ * </p>
+ */
+public interface Preference {
+
+ /** @return ID of user who prefers the item */
+ long getUserID();
+
+ /** @return item ID that is preferred */
+ long getItemID();
+
+ /**
+ * @return strength of the preference for that item. Zero should indicate "no preference either way";
+ * positive values indicate preference and negative values indicate dislike
+ */
+ float getValue();
+
+ /**
+ * Sets the strength of the preference for this item
+ *
+ * @param value
+ * new preference
+ */
+ void setValue(float value);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/model/PreferenceArray.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/model/PreferenceArray.java b/mr/src/main/java/org/apache/mahout/cf/taste/model/PreferenceArray.java
new file mode 100644
index 0000000..3886bc6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/model/PreferenceArray.java
@@ -0,0 +1,143 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.model;
+
+import java.io.Serializable;
+
+/**
+ * An alternate representation of an array of {@link Preference}. Implementations, in theory, can produce a
+ * more memory-efficient representation.
+ */
+public interface PreferenceArray extends Cloneable, Serializable, Iterable<Preference> {
+
+ /**
+ * @return size of length of the "array"
+ */
+ int length();
+
+ /**
+ * @param i
+ * index
+ * @return a materialized {@link Preference} representation of the preference at i
+ */
+ Preference get(int i);
+
+ /**
+ * Sets preference at i from information in the given {@link Preference}
+ *
+ * @param i
+ * @param pref
+ */
+ void set(int i, Preference pref);
+
+ /**
+ * @param i
+ * index
+ * @return user ID from preference at i
+ */
+ long getUserID(int i);
+
+ /**
+ * Sets user ID for preference at i.
+ *
+ * @param i
+ * index
+ * @param userID
+ * new user ID
+ */
+ void setUserID(int i, long userID);
+
+ /**
+ * @param i
+ * index
+ * @return item ID from preference at i
+ */
+ long getItemID(int i);
+
+ /**
+ * Sets item ID for preference at i.
+ *
+ * @param i
+ * index
+ * @param itemID
+ * new item ID
+ */
+ void setItemID(int i, long itemID);
+
+ /**
+ * @return all user or item IDs
+ */
+ long[] getIDs();
+
+ /**
+ * @param i
+ * index
+ * @return preference value from preference at i
+ */
+ float getValue(int i);
+
+ /**
+ * Sets preference value for preference at i.
+ *
+ * @param i
+ * index
+ * @param value
+ * new preference value
+ */
+ void setValue(int i, float value);
+
+ /**
+ * @return independent copy of this object
+ */
+ PreferenceArray clone();
+
+ /**
+ * Sorts underlying array by user ID, ascending.
+ */
+ void sortByUser();
+
+ /**
+ * Sorts underlying array by item ID, ascending.
+ */
+ void sortByItem();
+
+ /**
+ * Sorts underlying array by preference value, ascending.
+ */
+ void sortByValue();
+
+ /**
+ * Sorts underlying array by preference value, descending.
+ */
+ void sortByValueReversed();
+
+ /**
+ * @param userID
+ * user ID
+ * @return true if array contains a preference with given user ID
+ */
+ boolean hasPrefWithUserID(long userID);
+
+ /**
+ * @param itemID
+ * item ID
+ * @return true if array contains a preference with given item ID
+ */
+ boolean hasPrefWithItemID(long itemID);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/model/UpdatableIDMigrator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/model/UpdatableIDMigrator.java b/mr/src/main/java/org/apache/mahout/cf/taste/model/UpdatableIDMigrator.java
new file mode 100644
index 0000000..ff29a34
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/model/UpdatableIDMigrator.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.model;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+public interface UpdatableIDMigrator extends IDMigrator {
+
+ /**
+ * Stores the reverse long-to-String mapping in some kind of backing store. Note that this must be called
+ * directly (or indirectly through {@link #initialize(Iterable)}) for every String that might be encountered
+ * in the application, or else the mapping will not be known.
+ *
+ * @param longID
+ * long ID
+ * @param stringID
+ * string ID that maps to/from that long ID
+ * @throws TasteException
+ * if an error occurs while saving the mapping
+ */
+ void storeMapping(long longID, String stringID) throws TasteException;
+
+ /**
+ * Make the mapping aware of the given string IDs. This must be called initially before the implementation
+ * is used, or else it will not be aware of reverse long-to-String mappings.
+ *
+ * @throws TasteException
+ * if an error occurs while storing the mappings
+ */
+ void initialize(Iterable<String> stringIDs) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/neighborhood/UserNeighborhood.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/neighborhood/UserNeighborhood.java b/mr/src/main/java/org/apache/mahout/cf/taste/neighborhood/UserNeighborhood.java
new file mode 100644
index 0000000..2a143e1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/neighborhood/UserNeighborhood.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.neighborhood;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations of this interface compute a "neighborhood" of users like a given user. This neighborhood
+ * can be used to compute recommendations then.
+ * </p>
+ */
+public interface UserNeighborhood extends Refreshable {
+
+ /**
+ * @param userID
+ * ID of user for which a neighborhood will be computed
+ * @return IDs of users in the neighborhood
+ * @throws TasteException
+ * if an error occurs while accessing data
+ */
+ long[] getUserNeighborhood(long userID) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/CandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/CandidateItemsStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/CandidateItemsStrategy.java
new file mode 100644
index 0000000..ada1949
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/CandidateItemsStrategy.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+/**
+ * Used to retrieve all items that could possibly be recommended to the user
+ */
+public interface CandidateItemsStrategy extends Refreshable {
+
+ /**
+ * @return IDs of all items that could be recommended to the user
+ */
+ FastIDSet getCandidateItems(long userID, PreferenceArray preferencesFromUser, DataModel dataModel,
+ boolean includeKnownItems) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/IDRescorer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/IDRescorer.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/IDRescorer.java
new file mode 100644
index 0000000..d9a9cf7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/IDRescorer.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+/**
+ * <p>
+ * A {@link Rescorer} which operates on {@code long} primitive IDs, rather than arbitrary {@link Object}s.
+ * This is provided since most uses of this interface in the framework take IDs (as {@code long}) as an
+ * argument, and so this can be used to avoid unnecessary boxing/unboxing.
+ * </p>
+ */
+public interface IDRescorer {
+
+ /**
+ * @param id
+ * ID of thing (user, item, etc.) to rescore
+ * @param originalScore
+ * original score
+ * @return modified score, or {@link Double#NaN} to indicate that this should be excluded entirely
+ */
+ double rescore(long id, double originalScore);
+
+ /**
+ * Returns {@code true} to exclude the given thing.
+ *
+ * @param id
+ * ID of thing (user, item, etc.) to rescore
+ * @return {@code true} to exclude, {@code false} otherwise
+ */
+ boolean isFiltered(long id);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/ItemBasedRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/ItemBasedRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/ItemBasedRecommender.java
new file mode 100644
index 0000000..570f851
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/ItemBasedRecommender.java
@@ -0,0 +1,145 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+import java.util.List;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.common.LongPair;
+
+/**
+ * <p>
+ * Interface implemented by "item-based" recommenders.
+ * </p>
+ */
+public interface ItemBasedRecommender extends Recommender {
+
+ /**
+ * @param itemID
+ * ID of item for which to find most similar other items
+ * @param howMany
+ * desired number of most similar items to find
+ * @return items most similar to the given item, ordered from most similar to least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ List<RecommendedItem> mostSimilarItems(long itemID, int howMany) throws TasteException;
+
+ /**
+ * @param itemID
+ * ID of item for which to find most similar other items
+ * @param howMany
+ * desired number of most similar items to find
+ * @param rescorer
+ * {@link Rescorer} which can adjust item-item similarity estimates used to determine most similar
+ * items
+ * @return itemss most similar to the given item, ordered from most similar to least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ List<RecommendedItem> mostSimilarItems(long itemID, int howMany, Rescorer<LongPair> rescorer) throws TasteException;
+
+ /**
+ * @param itemIDs
+ * IDs of item for which to find most similar other items
+ * @param howMany
+ * desired number of most similar items to find estimates used to determine most similar items
+ * @return items most similar to the given items, ordered from most similar to least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ List<RecommendedItem> mostSimilarItems(long[] itemIDs, int howMany) throws TasteException;
+
+ /**
+ * @param itemIDs
+ * IDs of item for which to find most similar other items
+ * @param howMany
+ * desired number of most similar items to find
+ * @param rescorer
+ * {@link Rescorer} which can adjust item-item similarity estimates used to determine most similar
+ * items
+ * @return items most similar to the given items, ordered from most similar to least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ List<RecommendedItem> mostSimilarItems(long[] itemIDs,
+ int howMany,
+ Rescorer<LongPair> rescorer) throws TasteException;
+
+ /**
+ * @param itemIDs
+ * IDs of item for which to find most similar other items
+ * @param howMany
+ * desired number of most similar items to find
+ * @param excludeItemIfNotSimilarToAll
+ * exclude an item if it is not similar to each of the input items
+ * @return items most similar to the given items, ordered from most similar to least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ List<RecommendedItem> mostSimilarItems(long[] itemIDs,
+ int howMany,
+ boolean excludeItemIfNotSimilarToAll) throws TasteException;
+
+ /**
+ * @param itemIDs
+ * IDs of item for which to find most similar other items
+ * @param howMany
+ * desired number of most similar items to find
+ * @param rescorer
+ * {@link Rescorer} which can adjust item-item similarity estimates used to determine most similar
+ * items
+ * @param excludeItemIfNotSimilarToAll
+ * exclude an item if it is not similar to each of the input items
+ * @return items most similar to the given items, ordered from most similar to least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ List<RecommendedItem> mostSimilarItems(long[] itemIDs,
+ int howMany,
+ Rescorer<LongPair> rescorer,
+ boolean excludeItemIfNotSimilarToAll) throws TasteException;
+
+ /**
+ * <p>
+ * Lists the items that were most influential in recommending a given item to a given user. Exactly how this
+ * is determined is left to the implementation, but, generally this will return items that the user prefers
+ * and that are similar to the given item.
+ * </p>
+ *
+ * <p>
+ * This returns a {@link List} of {@link RecommendedItem} which is a little misleading since it's returning
+ * recommend<strong>ing</strong> items, but, I thought it more natural to just reuse this class since it
+ * encapsulates an item and value. The value here does not necessarily have a consistent interpretation or
+ * expected range; it will be higher the more influential the item was in the recommendation.
+ * </p>
+ *
+ * @param userID
+ * ID of user who was recommended the item
+ * @param itemID
+ * ID of item that was recommended
+ * @param howMany
+ * maximum number of items to return
+ * @return {@link List} of {@link RecommendedItem}, ordered from most influential in recommended the given
+ * item to least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ List<RecommendedItem> recommendedBecause(long userID, long itemID, int howMany) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/MostSimilarItemsCandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/MostSimilarItemsCandidateItemsStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/MostSimilarItemsCandidateItemsStrategy.java
new file mode 100644
index 0000000..282ceff
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/MostSimilarItemsCandidateItemsStrategy.java
@@ -0,0 +1,31 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * Used to retrieve all items that could possibly be similar
+ */
+public interface MostSimilarItemsCandidateItemsStrategy extends Refreshable {
+
+ FastIDSet getCandidateItems(long[] itemIDs, DataModel dataModel) throws TasteException;
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/RecommendedItem.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/RecommendedItem.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/RecommendedItem.java
new file mode 100644
index 0000000..1fcece8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/RecommendedItem.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+/**
+ * <p>
+ * Implementations encapsulate items that are recommended, and include the item recommended and a value
+ * expressing the strength of the preference.
+ * </p>
+ */
+public interface RecommendedItem {
+
+ /** @return the recommended item ID */
+ long getItemID();
+
+ /**
+ * <p>
+ * A value expressing the strength of the preference for the recommended item. The range of the values
+ * depends on the implementation. Implementations must use larger values to express stronger preference.
+ * </p>
+ *
+ * @return strength of the preference
+ */
+ float getValue();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Recommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Recommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Recommender.java
new file mode 100644
index 0000000..4135aff
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Recommender.java
@@ -0,0 +1,132 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+import java.util.List;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * <p>
+ * Implementations of this interface can recommend items for a user. Implementations will likely take
+ * advantage of several classes in other packages here to compute this.
+ * </p>
+ */
+public interface Recommender extends Refreshable {
+
+ /**
+ * @param userID
+ * user for which recommendations are to be computed
+ * @param howMany
+ * desired number of recommendations
+ * @return {@link List} of recommended {@link RecommendedItem}s, ordered from most strongly recommend to
+ * least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ List<RecommendedItem> recommend(long userID, int howMany) throws TasteException;
+
+ /**
+ * @param userID
+ * user for which recommendations are to be computed
+ * @param howMany
+ * desired number of recommendations
+ * @return {@link List} of recommended {@link RecommendedItem}s, ordered from most strongly recommend to
+ * least
+ * @param includeKnownItems
+ * whether to include items already known by the user in recommendations
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException;
+
+ /**
+ * @param userID
+ * user for which recommendations are to be computed
+ * @param howMany
+ * desired number of recommendations
+ * @param rescorer
+ * rescoring function to apply before final list of recommendations is determined
+ * @return {@link List} of recommended {@link RecommendedItem}s, ordered from most strongly recommend to
+ * least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException;
+
+ /**
+ * @param userID
+ * user for which recommendations are to be computed
+ * @param howMany
+ * desired number of recommendations
+ * @param rescorer
+ * rescoring function to apply before final list of recommendations is determined
+ * @param includeKnownItems
+ * whether to include items already known by the user in recommendations
+ * @return {@link List} of recommended {@link RecommendedItem}s, ordered from most strongly recommend to
+ * least
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+
+ List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException;
+
+ /**
+ * @param userID
+ * user ID whose preference is to be estimated
+ * @param itemID
+ * item ID to estimate preference for
+ * @return an estimated preference if the user has not expressed a preference for the item, or else the
+ * user's actual preference for the item. If a preference cannot be estimated, returns
+ * {@link Double#NaN}
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ float estimatePreference(long userID, long itemID) throws TasteException;
+
+ /**
+ * @param userID
+ * user to set preference for
+ * @param itemID
+ * item to set preference for
+ * @param value
+ * preference value
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ void setPreference(long userID, long itemID, float value) throws TasteException;
+
+ /**
+ * @param userID
+ * user from which to remove preference
+ * @param itemID
+ * item for which to remove preference
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel}
+ */
+ void removePreference(long userID, long itemID) throws TasteException;
+
+ /**
+ * @return underlying {@link DataModel} used by this {@link Recommender} implementation
+ */
+ DataModel getDataModel();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java
new file mode 100644
index 0000000..1490761
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/Rescorer.java
@@ -0,0 +1,52 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.recommender;
+
+/**
+ * <p>
+ * A {@link Rescorer} simply assigns a new "score" to a thing like an ID of an item or user which a
+ * {@link Recommender} is considering returning as a top recommendation. It may be used to arbitrarily re-rank
+ * the results according to application-specific logic before returning recommendations. For example, an
+ * application may want to boost the score of items in a certain category just for one request.
+ * </p>
+ *
+ * <p>
+ * A {@link Rescorer} can also exclude a thing from consideration entirely by returning {@code true} from
+ * {@link #isFiltered(Object)}.
+ * </p>
+ */
+public interface Rescorer<T> {
+
+ /**
+ * @param thing
+ * thing to rescore
+ * @param originalScore
+ * original score
+ * @return modified score, or {@link Double#NaN} to indicate that this should be excluded entirely
+ */
+ double rescore(T thing, double originalScore);
+
+ /**
+ * Returns {@code true} to exclude the given thing.
+ *
+ * @param thing
+ * the thing to filter
+ * @return {@code true} to exclude, {@code false} otherwise
+ */
+ boolean isFiltered(T thing);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java
new file mode 100644
index 0000000..b48593a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/recommender/UserBasedRecommender.java
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.common.LongPair;
+
+/**
+ * <p>
+ * Interface implemented by "user-based" recommenders.
+ * </p>
+ */
+public interface UserBasedRecommender extends Recommender {
+
+ /**
+ * @param userID
+ * ID of user for which to find most similar other users
+ * @param howMany
+ * desired number of most similar users to find
+ * @return users most similar to the given user
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ long[] mostSimilarUserIDs(long userID, int howMany) throws TasteException;
+
+ /**
+ * @param userID
+ * ID of user for which to find most similar other users
+ * @param howMany
+ * desired number of most similar users to find
+ * @param rescorer
+ * {@link Rescorer} which can adjust user-user similarity estimates used to determine most similar
+ * users
+ * @return IDs of users most similar to the given user
+ * @throws TasteException
+ * if an error occurs while accessing the {@link org.apache.mahout.cf.taste.model.DataModel}
+ */
+ long[] mostSimilarUserIDs(long userID, int howMany, Rescorer<LongPair> rescorer) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java
new file mode 100644
index 0000000..814610b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/ItemSimilarity.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations of this interface define a notion of similarity between two items. Implementations should
+ * return values in the range -1.0 to 1.0, with 1.0 representing perfect similarity.
+ * </p>
+ *
+ * @see UserSimilarity
+ */
+public interface ItemSimilarity extends Refreshable {
+
+ /**
+ * <p>
+ * Returns the degree of similarity, of two items, based on the preferences that users have expressed for
+ * the items.
+ * </p>
+ *
+ * @param itemID1 first item ID
+ * @param itemID2 second item ID
+ * @return similarity between the items, in [-1,1] or {@link Double#NaN} similarity is unknown
+ * @throws org.apache.mahout.cf.taste.common.NoSuchItemException
+ * if either item is known to be non-existent in the data
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ double itemSimilarity(long itemID1, long itemID2) throws TasteException;
+
+ /**
+ * <p>A bulk-get version of {@link #itemSimilarity(long, long)}.</p>
+ *
+ * @param itemID1 first item ID
+ * @param itemID2s second item IDs to compute similarity with
+ * @return similarity between itemID1 and other items
+ * @throws org.apache.mahout.cf.taste.common.NoSuchItemException
+ * if any item is known to be non-existent in the data
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException;
+
+ /**
+ * @return all IDs of similar items, in no particular order
+ */
+ long[] allSimilarItemIDs(long itemID) throws TasteException;
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java
new file mode 100644
index 0000000..76bb328
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/PreferenceInferrer.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations of this interface compute an inferred preference for a user and an item that the user has
+ * not expressed any preference for. This might be an average of other preferences scores from that user, for
+ * example. This technique is sometimes called "default voting".
+ * </p>
+ */
+public interface PreferenceInferrer extends Refreshable {
+
+ /**
+ * <p>
+ * Infers the given user's preference value for an item.
+ * </p>
+ *
+ * @param userID
+ * ID of user to infer preference for
+ * @param itemID
+ * item ID to infer preference for
+ * @return inferred preference
+ * @throws TasteException
+ * if an error occurs while inferring
+ */
+ float inferPreference(long userID, long itemID) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java
new file mode 100644
index 0000000..bd53c51
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/UserSimilarity.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations of this interface define a notion of similarity between two users. Implementations should
+ * return values in the range -1.0 to 1.0, with 1.0 representing perfect similarity.
+ * </p>
+ *
+ * @see ItemSimilarity
+ */
+public interface UserSimilarity extends Refreshable {
+
+ /**
+ * <p>
+ * Returns the degree of similarity, of two users, based on the their preferences.
+ * </p>
+ *
+ * @param userID1 first user ID
+ * @param userID2 second user ID
+ * @return similarity between the users, in [-1,1] or {@link Double#NaN} similarity is unknown
+ * @throws org.apache.mahout.cf.taste.common.NoSuchUserException
+ * if either user is known to be non-existent in the data
+ * @throws TasteException if an error occurs while accessing the data
+ */
+ double userSimilarity(long userID1, long userID2) throws TasteException;
+
+ // Should we implement userSimilarities() like ItemSimilarity.itemSimilarities()?
+
+ /**
+ * <p>
+ * Attaches a {@link PreferenceInferrer} to the {@link UserSimilarity} implementation.
+ * </p>
+ *
+ * @param inferrer {@link PreferenceInferrer}
+ */
+ void setPreferenceInferrer(PreferenceInferrer inferrer);
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java
new file mode 100644
index 0000000..b934d0c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/BatchItemSimilarities.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+
+import java.io.IOException;
+
+public abstract class BatchItemSimilarities {
+
+ private final ItemBasedRecommender recommender;
+ private final int similarItemsPerItem;
+
+ /**
+ * @param recommender recommender to use
+ * @param similarItemsPerItem number of similar items to compute per item
+ */
+ protected BatchItemSimilarities(ItemBasedRecommender recommender, int similarItemsPerItem) {
+ this.recommender = recommender;
+ this.similarItemsPerItem = similarItemsPerItem;
+ }
+
+ protected ItemBasedRecommender getRecommender() {
+ return recommender;
+ }
+
+ protected int getSimilarItemsPerItem() {
+ return similarItemsPerItem;
+ }
+
+ /**
+ * @param degreeOfParallelism number of threads to use for the computation
+ * @param maxDurationInHours maximum duration of the computation
+ * @param writer {@link SimilarItemsWriter} used to persist the results
+ * @return the number of similarities precomputed
+ * @throws IOException
+ * @throws RuntimeException if the computation takes longer than maxDurationInHours
+ */
+ public abstract int computeItemSimilarities(int degreeOfParallelism, int maxDurationInHours,
+ SimilarItemsWriter writer) throws IOException;
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java
new file mode 100644
index 0000000..5d40051
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItem.java
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import com.google.common.primitives.Doubles;
+
+import java.util.Comparator;
+
+/**
+ * Modeling similarity towards another item
+ */
+public class SimilarItem {
+
+ public static final Comparator<SimilarItem> COMPARE_BY_SIMILARITY = new Comparator<SimilarItem>() {
+ @Override
+ public int compare(SimilarItem s1, SimilarItem s2) {
+ return Doubles.compare(s1.similarity, s2.similarity);
+ }
+ };
+
+ private long itemID;
+ private double similarity;
+
+ public SimilarItem(long itemID, double similarity) {
+ set(itemID, similarity);
+ }
+
+ public void set(long itemID, double similarity) {
+ this.itemID = itemID;
+ this.similarity = similarity;
+ }
+
+ public long getItemID() {
+ return itemID;
+ }
+
+ public double getSimilarity() {
+ return similarity;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java
new file mode 100644
index 0000000..18ee42c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItems.java
@@ -0,0 +1,84 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import com.google.common.collect.UnmodifiableIterator;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.NoSuchElementException;
+
+/**
+ * Compact representation of all similar items for an item
+ */
+public class SimilarItems {
+
+ private final long itemID;
+ private final long[] similarItemIDs;
+ private final double[] similarities;
+
+ public SimilarItems(long itemID, List<RecommendedItem> similarItems) {
+ this.itemID = itemID;
+
+ int numSimilarItems = similarItems.size();
+ similarItemIDs = new long[numSimilarItems];
+ similarities = new double[numSimilarItems];
+
+ for (int n = 0; n < numSimilarItems; n++) {
+ similarItemIDs[n] = similarItems.get(n).getItemID();
+ similarities[n] = similarItems.get(n).getValue();
+ }
+ }
+
+ public long getItemID() {
+ return itemID;
+ }
+
+ public int numSimilarItems() {
+ return similarItemIDs.length;
+ }
+
+ public Iterable<SimilarItem> getSimilarItems() {
+ return new Iterable<SimilarItem>() {
+ @Override
+ public Iterator<SimilarItem> iterator() {
+ return new SimilarItemsIterator();
+ }
+ };
+ }
+
+ private class SimilarItemsIterator extends UnmodifiableIterator<SimilarItem> {
+
+ private int index = 0;
+
+ @Override
+ public boolean hasNext() {
+ return index < (similarItemIDs.length - 1);
+ }
+
+ @Override
+ public SimilarItem next() {
+ if (!hasNext()) {
+ throw new NoSuchElementException();
+ }
+ index++;
+ return new SimilarItem(similarItemIDs[index], similarities[index]);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java
new file mode 100644
index 0000000..35d6bfe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/similarity/precompute/SimilarItemsWriter.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.similarity.precompute;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+/**
+ * Used to persist the results of a batch item similarity computation
+ * conducted with a {@link BatchItemSimilarities} implementation
+ */
+public interface SimilarItemsWriter extends Closeable {
+
+ void open() throws IOException;
+
+ void add(SimilarItems similarItems) throws IOException;
+
+}
[11/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
new file mode 100644
index 0000000..16c8dff
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDevTest.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.Random;
+
+public final class RunningAverageAndStdDevTest extends TasteTestCase {
+
+ private static final double SMALL_EPSILON = 1.0;
+
+ @Test
+ public void testFull() {
+ RunningAverageAndStdDev average = new FullRunningAverageAndStdDev();
+
+ assertEquals(0, average.getCount());
+ assertTrue(Double.isNaN(average.getAverage()));
+ assertTrue(Double.isNaN(average.getStandardDeviation()));
+
+ average.addDatum(6.0);
+ assertEquals(1, average.getCount());
+ assertEquals(6.0, average.getAverage(), EPSILON);
+ assertTrue(Double.isNaN(average.getStandardDeviation()));
+
+ average.addDatum(6.0);
+ assertEquals(2, average.getCount());
+ assertEquals(6.0, average.getAverage(), EPSILON);
+ assertEquals(0.0, average.getStandardDeviation(), EPSILON);
+
+ average.removeDatum(6.0);
+ assertEquals(1, average.getCount());
+ assertEquals(6.0, average.getAverage(), EPSILON);
+ assertTrue(Double.isNaN(average.getStandardDeviation()));
+
+ average.addDatum(-4.0);
+ assertEquals(2, average.getCount());
+ assertEquals(1.0, average.getAverage(), EPSILON);
+ assertEquals(5.0 * 1.4142135623730951, average.getStandardDeviation(), EPSILON);
+
+ average.removeDatum(4.0);
+ assertEquals(1, average.getCount());
+ assertEquals(-2.0, average.getAverage(), EPSILON);
+ assertTrue(Double.isNaN(average.getStandardDeviation()));
+
+ }
+
+ @Test
+ public void testFullBig() {
+ RunningAverageAndStdDev average = new FullRunningAverageAndStdDev();
+
+ Random r = RandomUtils.getRandom();
+ for (int i = 0; i < 100000; i++) {
+ average.addDatum(r.nextDouble() * 1000.0);
+ }
+ assertEquals(500.0, average.getAverage(), SMALL_EPSILON);
+ assertEquals(1000.0 / Math.sqrt(12.0), average.getStandardDeviation(), SMALL_EPSILON);
+
+ }
+
+ @Test
+ public void testStddev() {
+
+ RunningAverageAndStdDev runningAverage = new FullRunningAverageAndStdDev();
+
+ assertEquals(0, runningAverage.getCount());
+ assertTrue(Double.isNaN(runningAverage.getAverage()));
+ runningAverage.addDatum(1.0);
+ assertEquals(1, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+ runningAverage.addDatum(1.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(0.0, runningAverage.getStandardDeviation(), EPSILON);
+
+ runningAverage.addDatum(7.0);
+ assertEquals(3, runningAverage.getCount());
+ assertEquals(3.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(3.464101552963257, runningAverage.getStandardDeviation(), EPSILON);
+
+ runningAverage.addDatum(5.0);
+ assertEquals(4, runningAverage.getCount());
+ assertEquals(3.5, runningAverage.getAverage(), EPSILON);
+ assertEquals(3.0, runningAverage.getStandardDeviation(), EPSILON);
+
+ }
+
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
new file mode 100644
index 0000000..6b891c5
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+/** <p>Tests {@link FullRunningAverage}.</p> */
+public final class RunningAverageTest extends TasteTestCase {
+
+ @Test
+ public void testFull() {
+ RunningAverage runningAverage = new FullRunningAverage();
+
+ assertEquals(0, runningAverage.getCount());
+ assertTrue(Double.isNaN(runningAverage.getAverage()));
+ runningAverage.addDatum(1.0);
+ assertEquals(1, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.addDatum(1.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.addDatum(4.0);
+ assertEquals(3, runningAverage.getCount());
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.addDatum(-4.0);
+ assertEquals(4, runningAverage.getCount());
+ assertEquals(0.5, runningAverage.getAverage(), EPSILON);
+
+ runningAverage.removeDatum(-4.0);
+ assertEquals(3, runningAverage.getCount());
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.removeDatum(4.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+
+ runningAverage.changeDatum(0.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.changeDatum(2.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ }
+
+ @Test
+ public void testCopyConstructor() {
+ RunningAverage runningAverage = new FullRunningAverage();
+
+ runningAverage.addDatum(1.0);
+ runningAverage.addDatum(1.0);
+ assertEquals(2, runningAverage.getCount());
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+
+ RunningAverage copy = new FullRunningAverage(runningAverage.getCount(), runningAverage.getAverage());
+ assertEquals(2, copy.getCount());
+ assertEquals(1.0, copy.getAverage(), EPSILON);
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIteratorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIteratorTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIteratorTest.java
new file mode 100644
index 0000000..a5b91ff
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIteratorTest.java
@@ -0,0 +1,91 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+public final class SamplingLongPrimitiveIteratorTest extends TasteTestCase {
+
+ @Test
+ public void testEmptyCase() {
+ assertFalse(new SamplingLongPrimitiveIterator(
+ countingIterator(0), 0.9999).hasNext());
+ assertFalse(new SamplingLongPrimitiveIterator(
+ countingIterator(0), 1).hasNext());
+ }
+
+ @Test
+ public void testSmallInput() {
+ SamplingLongPrimitiveIterator t = new SamplingLongPrimitiveIterator(
+ countingIterator(1), 0.9999);
+ assertTrue(t.hasNext());
+ assertEquals(0L, t.nextLong());
+ assertFalse(t.hasNext());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testBadRate1() {
+ new SamplingLongPrimitiveIterator(countingIterator(1), 0.0);
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testBadRate2() {
+ new SamplingLongPrimitiveIterator(countingIterator(1), 1.1);
+ }
+
+ @Test
+ public void testExactSizeMatch() {
+ SamplingLongPrimitiveIterator t = new SamplingLongPrimitiveIterator(
+ countingIterator(10), 1);
+ for (int i = 0; i < 10; i++) {
+ assertTrue(t.hasNext());
+ assertEquals(i, t.next().intValue());
+ }
+ assertFalse(t.hasNext());
+ }
+
+ @Test
+ public void testSample() {
+ double p = 0.1;
+ int n = 1000;
+ double sd = Math.sqrt(n * p * (1.0 - p));
+ for (int i = 0; i < 1000; i++) {
+ SamplingLongPrimitiveIterator t = new SamplingLongPrimitiveIterator(countingIterator(n), p);
+ int k = 0;
+ while (t.hasNext()) {
+ long v = t.nextLong();
+ k++;
+ assertTrue(v >= 0L);
+ assertTrue(v < 1000L);
+ }
+ // Should be +/- 5 standard deviations except in about 1 out of 1.7M cases
+ assertTrue(k >= 100 - 5 * sd);
+ assertTrue(k <= 100 + 5 * sd);
+ }
+ }
+
+ private static LongPrimitiveArrayIterator countingIterator(int to) {
+ long[] data = new long[to];
+ for (int i = 0; i < to; i++) {
+ data[i] = i;
+ }
+ return new LongPrimitiveArrayIterator(data);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java
new file mode 100644
index 0000000..daa56e8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+/**
+ * <p>Tests {@link WeightedRunningAverage} and {@link WeightedRunningAverageAndStdDev}.</p>
+ */
+public final class WeightedRunningAverageTest extends TasteTestCase {
+
+ @Test
+ public void testWeighted() {
+
+ WeightedRunningAverage runningAverage = new WeightedRunningAverage();
+
+ assertEquals(0, runningAverage.getCount());
+ assertTrue(Double.isNaN(runningAverage.getAverage()));
+ runningAverage.addDatum(1.0, 2.0);
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.addDatum(1.0);
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.addDatum(8.0, 0.5);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.addDatum(-4.0);
+ assertEquals(2.0/3.0, runningAverage.getAverage(), EPSILON);
+
+ runningAverage.removeDatum(-4.0);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.removeDatum(2.0, 2.0);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+
+ runningAverage.changeDatum(0.0);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ runningAverage.changeDatum(4.0, 0.5);
+ assertEquals(5.0/1.5, runningAverage.getAverage(), EPSILON);
+ }
+
+ @Test
+ public void testWeightedAndStdDev() {
+
+ WeightedRunningAverageAndStdDev runningAverage = new WeightedRunningAverageAndStdDev();
+
+ assertEquals(0, runningAverage.getCount());
+ assertTrue(Double.isNaN(runningAverage.getAverage()));
+ assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+
+ runningAverage.addDatum(1.0);
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+ runningAverage.addDatum(1.0, 2.0);
+ assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(0.0, runningAverage.getStandardDeviation(), EPSILON);
+ runningAverage.addDatum(8.0, 0.5);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(10.5), runningAverage.getStandardDeviation(), EPSILON);
+ runningAverage.addDatum(-4.0);
+ assertEquals(2.0/3.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(15.75), runningAverage.getStandardDeviation(), EPSILON);
+
+ runningAverage.removeDatum(-4.0);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(10.5), runningAverage.getStandardDeviation(), EPSILON);
+ runningAverage.removeDatum(2.0, 2.0);
+ assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+ assertEquals(Math.sqrt(31.5), runningAverage.getStandardDeviation(), EPSILON);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluatorImplTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluatorImplTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluatorImplTest.java
new file mode 100644
index 0000000..a47327d
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/GenericRecommenderIRStatsEvaluatorImplTest.java
@@ -0,0 +1,73 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.eval.DataModelBuilder;
+import org.apache.mahout.cf.taste.eval.IRStatistics;
+import org.apache.mahout.cf.taste.eval.RecommenderBuilder;
+import org.apache.mahout.cf.taste.eval.RecommenderIRStatsEvaluator;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.model.GenericBooleanPrefDataModel;
+import org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefItemBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.LogLikelihoodSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.junit.Test;
+
+public final class GenericRecommenderIRStatsEvaluatorImplTest extends TasteTestCase {
+
+ @Test
+ public void testBoolean() throws Exception {
+ DataModel model = getBooleanDataModel();
+ RecommenderBuilder builder = new RecommenderBuilder() {
+ @Override
+ public Recommender buildRecommender(DataModel dataModel) {
+ return new GenericBooleanPrefItemBasedRecommender(dataModel, new LogLikelihoodSimilarity(dataModel));
+ }
+ };
+ DataModelBuilder dataModelBuilder = new DataModelBuilder() {
+ @Override
+ public DataModel buildDataModel(FastByIDMap<PreferenceArray> trainingData) {
+ return new GenericBooleanPrefDataModel(GenericBooleanPrefDataModel.toDataMap(trainingData));
+ }
+ };
+ RecommenderIRStatsEvaluator evaluator = new GenericRecommenderIRStatsEvaluator();
+ IRStatistics stats = evaluator.evaluate(
+ builder, dataModelBuilder, model, null, 1, GenericRecommenderIRStatsEvaluator.CHOOSE_THRESHOLD, 1.0);
+
+ assertNotNull(stats);
+ assertEquals(0.666666666, stats.getPrecision(), EPSILON);
+ assertEquals(0.666666666, stats.getRecall(), EPSILON);
+ assertEquals(0.666666666, stats.getF1Measure(), EPSILON);
+ assertEquals(0.666666666, stats.getFNMeasure(2.0), EPSILON);
+ assertEquals(0.666666666, stats.getNormalizedDiscountedCumulativeGain(), EPSILON);
+ }
+
+ @Test
+ public void testIRStats() {
+ IRStatistics stats = new IRStatisticsImpl(0.3, 0.1, 0.2, 0.05, 0.15);
+ assertEquals(0.3, stats.getPrecision(), EPSILON);
+ assertEquals(0.1, stats.getRecall(), EPSILON);
+ assertEquals(0.15, stats.getF1Measure(), EPSILON);
+ assertEquals(0.11538461538462, stats.getFNMeasure(2.0), EPSILON);
+ assertEquals(0.05, stats.getNormalizedDiscountedCumulativeGain(), EPSILON);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluationRunner.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluationRunner.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluationRunner.java
new file mode 100644
index 0000000..852b3e0
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/eval/LoadEvaluationRunner.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.eval;
+
+import org.apache.mahout.cf.taste.impl.model.file.FileDataModel;
+import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
+import org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender;
+import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.EuclideanDistanceSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+import java.io.File;
+
+public final class LoadEvaluationRunner {
+
+ private static final int LOOPS = 10;
+
+ private LoadEvaluationRunner() {
+ }
+
+ public static void main(String[] args) throws Exception {
+
+ DataModel model = new FileDataModel(new File(args[0]));
+
+ int howMany = 10;
+ if (args.length > 1) {
+ howMany = Integer.parseInt(args[1]);
+ }
+
+ System.out.println("Run Items");
+ ItemSimilarity similarity = new EuclideanDistanceSimilarity(model);
+ Recommender recommender = new GenericItemBasedRecommender(model, similarity); // Use an item-item recommender
+ for (int i = 0; i < LOOPS; i++) {
+ LoadStatistics loadStats = LoadEvaluator.runLoad(recommender, howMany);
+ System.out.println(loadStats);
+ }
+
+ System.out.println("Run Users");
+ UserSimilarity userSim = new EuclideanDistanceSimilarity(model);
+ UserNeighborhood neighborhood = new NearestNUserNeighborhood(10, userSim, model);
+ recommender = new GenericUserBasedRecommender(model, neighborhood, userSim);
+ for (int i = 0; i < LOOPS; i++) {
+ LoadStatistics loadStats = LoadEvaluator.runLoad(recommender, howMany);
+ System.out.println(loadStats);
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArrayTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArrayTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArrayTest.java
new file mode 100644
index 0000000..384b120
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArrayTest.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.junit.Test;
+
+public final class BooleanItemPreferenceArrayTest extends TasteTestCase {
+
+ @Test
+ public void testUserID() {
+ PreferenceArray prefs = new BooleanItemPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setItemID(0, 1L);
+ assertEquals(1L, prefs.getItemID(0));
+ assertEquals(1L, prefs.getItemID(1));
+ assertEquals(1L, prefs.getItemID(2));
+ }
+
+ @Test
+ public void testItemID() {
+ PreferenceArray prefs = new BooleanItemPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setUserID(0, 1L);
+ prefs.setUserID(1, 2L);
+ prefs.setUserID(2, 3L);
+ assertEquals(1L, prefs.getUserID(0));
+ assertEquals(2L, prefs.getUserID(1));
+ assertEquals(3L, prefs.getUserID(2));
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testSetValue() {
+ PreferenceArray prefs = new BooleanItemPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ assertEquals(1.0f, prefs.getValue(2), EPSILON);
+ prefs.setValue(0, 1.0f);
+ }
+
+ @Test
+ public void testHasPref() {
+ PreferenceArray prefs = new BooleanItemPreferenceArray(3);
+ prefs.set(0, new GenericPreference(1L, 3L, 5.0f));
+ assertTrue(prefs.hasPrefWithItemID(3L));
+ assertTrue(prefs.hasPrefWithUserID(1L));
+ assertFalse(prefs.hasPrefWithItemID(2L));
+ assertFalse(prefs.hasPrefWithUserID(2L));
+ }
+
+ @Test
+ public void testSort() {
+ PreferenceArray prefs = new BooleanItemPreferenceArray(3);
+ prefs.set(0, new GenericPreference(3L, 1L, 5.0f));
+ prefs.set(1, new GenericPreference(1L, 1L, 5.0f));
+ prefs.set(2, new GenericPreference(2L, 1L, 5.0f));
+ prefs.sortByUser();
+ assertEquals(1L, prefs.getUserID(0));
+ assertEquals(2L, prefs.getUserID(1));
+ assertEquals(3L, prefs.getUserID(2));
+ }
+
+ @Test
+ public void testClone() {
+ BooleanItemPreferenceArray prefs = new BooleanItemPreferenceArray(3);
+ prefs.set(0, new BooleanPreference(3L, 1L));
+ prefs.set(1, new BooleanPreference(1L, 1L));
+ prefs.set(2, new BooleanPreference(2L, 1L));
+ prefs = prefs.clone();
+ assertEquals(3L, prefs.getUserID(0));
+ assertEquals(1L, prefs.getItemID(1));
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArrayTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArrayTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArrayTest.java
new file mode 100644
index 0000000..fa1f88c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArrayTest.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.junit.Test;
+
+public final class BooleanUserPreferenceArrayTest extends TasteTestCase {
+
+ @Test
+ public void testUserID() {
+ PreferenceArray prefs = new BooleanUserPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setUserID(0, 1L);
+ assertEquals(1L, prefs.getUserID(0));
+ assertEquals(1L, prefs.getUserID(1));
+ assertEquals(1L, prefs.getUserID(2));
+ }
+
+ @Test
+ public void testItemID() {
+ PreferenceArray prefs = new BooleanUserPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setItemID(0, 1L);
+ prefs.setItemID(1, 2L);
+ prefs.setItemID(2, 3L);
+ assertEquals(1L, prefs.getItemID(0));
+ assertEquals(2L, prefs.getItemID(1));
+ assertEquals(3L, prefs.getItemID(2));
+ }
+
+ @Test(expected = UnsupportedOperationException.class)
+ public void testSetValue() {
+ PreferenceArray prefs = new BooleanUserPreferenceArray(3);
+ assertEquals(1.0f, prefs.getValue(2), EPSILON);
+ assertEquals(3, prefs.length());
+ prefs.setValue(0, 1.0f);
+ }
+
+ @Test
+ public void testHasPref() {
+ PreferenceArray prefs = new BooleanUserPreferenceArray(3);
+ prefs.set(0, new GenericPreference(1L, 3L, 5.0f));
+ assertTrue(prefs.hasPrefWithItemID(3L));
+ assertTrue(prefs.hasPrefWithUserID(1L));
+ assertFalse(prefs.hasPrefWithItemID(2L));
+ assertFalse(prefs.hasPrefWithUserID(2L));
+ }
+
+ @Test
+ public void testSort() {
+ PreferenceArray prefs = new BooleanUserPreferenceArray(3);
+ prefs.set(0, new BooleanPreference(1L, 3L));
+ prefs.set(1, new BooleanPreference(1L, 1L));
+ prefs.set(2, new BooleanPreference(1L, 2L));
+ prefs.sortByItem();
+ assertEquals(1L, prefs.getItemID(0));
+ assertEquals(2L, prefs.getItemID(1));
+ assertEquals(3L, prefs.getItemID(2));
+ }
+
+ @Test
+ public void testClone() {
+ BooleanUserPreferenceArray prefs = new BooleanUserPreferenceArray(3);
+ prefs.set(0, new BooleanPreference(1L, 3L));
+ prefs.set(1, new BooleanPreference(1L, 1L));
+ prefs.set(2, new BooleanPreference(1L, 2L));
+ prefs = prefs.clone();
+ assertEquals(3L, prefs.getItemID(0));
+ assertEquals(1L, prefs.getUserID(1));
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericDataModelTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericDataModelTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericDataModelTest.java
new file mode 100644
index 0000000..75bf070
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericDataModelTest.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.junit.Test;
+
+/**
+ * Tests {@link GenericDataModel}.
+ */
+public final class GenericDataModelTest extends TasteTestCase {
+
+ @Test
+ public void testSerialization() throws Exception {
+ GenericDataModel model = (GenericDataModel) getDataModel();
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ ObjectOutputStream out = new ObjectOutputStream(baos);
+ out.writeObject(model);
+ ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
+ ObjectInputStream in = new ObjectInputStream(bais);
+ GenericDataModel newModel = (GenericDataModel) in.readObject();
+ assertEquals(model.getNumItems(), newModel.getNumItems());
+ assertEquals(model.getNumUsers(), newModel.getNumUsers());
+ assertEquals(model.getPreferencesFromUser(1L), newModel.getPreferencesFromUser(1L));
+ assertEquals(model.getPreferencesForItem(1L), newModel.getPreferencesForItem(1L));
+ assertEquals(model.getRawUserData(), newModel.getRawUserData());
+ }
+
+ // Lots of other stuff should be tested but is kind of covered by FileDataModelTest
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArrayTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArrayTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArrayTest.java
new file mode 100644
index 0000000..5d77d0e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArrayTest.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.junit.Test;
+
+public final class GenericItemPreferenceArrayTest extends TasteTestCase {
+
+ @Test
+ public void testUserID() {
+ PreferenceArray prefs = new GenericItemPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setItemID(0, 1L);
+ assertEquals(1L, prefs.getItemID(0));
+ assertEquals(1L, prefs.getItemID(1));
+ assertEquals(1L, prefs.getItemID(2));
+ }
+
+ @Test
+ public void testItemID() {
+ PreferenceArray prefs = new GenericItemPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setUserID(0, 1L);
+ prefs.setUserID(1, 2L);
+ prefs.setUserID(2, 3L);
+ assertEquals(1L, prefs.getUserID(0));
+ assertEquals(2L, prefs.getUserID(1));
+ assertEquals(3L, prefs.getUserID(2));
+ }
+
+ @Test
+ public void testSetValue() {
+ PreferenceArray prefs = new GenericItemPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setValue(0, 1.0f);
+ prefs.setValue(1, 2.0f);
+ prefs.setValue(2, 3.0f);
+ assertEquals(1.0f, prefs.getValue(0), EPSILON);
+ assertEquals(2.0f, prefs.getValue(1), EPSILON);
+ assertEquals(3.0f, prefs.getValue(2), EPSILON);
+ }
+
+ @Test
+ public void testHasPref() {
+ PreferenceArray prefs = new GenericItemPreferenceArray(3);
+ prefs.set(0, new GenericPreference(1L, 3L, 5.0f));
+ assertTrue(prefs.hasPrefWithItemID(3L));
+ assertTrue(prefs.hasPrefWithUserID(1L));
+ assertFalse(prefs.hasPrefWithItemID(2L));
+ assertFalse(prefs.hasPrefWithUserID(2L));
+ }
+
+ @Test
+ public void testSort() {
+ PreferenceArray prefs = new GenericItemPreferenceArray(3);
+ prefs.set(0, new GenericPreference(3L, 1L, 5.0f));
+ prefs.set(1, new GenericPreference(1L, 1L, 5.0f));
+ prefs.set(2, new GenericPreference(2L, 1L, 5.0f));
+ prefs.sortByUser();
+ assertEquals(1L, prefs.getUserID(0));
+ assertEquals(2L, prefs.getUserID(1));
+ assertEquals(3L, prefs.getUserID(2));
+ }
+
+ @Test
+ public void testSortValue() {
+ PreferenceArray prefs = new GenericItemPreferenceArray(3);
+ prefs.set(0, new GenericPreference(3L, 1L, 5.0f));
+ prefs.set(1, new GenericPreference(1L, 1L, 4.0f));
+ prefs.set(2, new GenericPreference(2L, 1L, 3.0f));
+ prefs.sortByValue();
+ assertEquals(2L, prefs.getUserID(0));
+ assertEquals(1L, prefs.getUserID(1));
+ assertEquals(3L, prefs.getUserID(2));
+ prefs.sortByValueReversed();
+ assertEquals(3L, prefs.getUserID(0));
+ assertEquals(1L, prefs.getUserID(1));
+ assertEquals(2L, prefs.getUserID(2));
+ }
+
+ @Test
+ public void testClone() {
+ GenericItemPreferenceArray prefs = new GenericItemPreferenceArray(3);
+ prefs.set(0, new GenericPreference(3L, 1L, 5.0f));
+ prefs.set(1, new GenericPreference(1L, 1L, 4.0f));
+ prefs.set(2, new GenericPreference(2L, 1L, 3.0f));
+ prefs = prefs.clone();
+ assertEquals(3L, prefs.getUserID(0));
+ assertEquals(1L, prefs.getItemID(1));
+ assertEquals(3.0f, prefs.getValue(2), EPSILON);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArrayTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArrayTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArrayTest.java
new file mode 100644
index 0000000..2bde8cc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArrayTest.java
@@ -0,0 +1,110 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.junit.Test;
+
+public final class GenericUserPreferenceArrayTest extends TasteTestCase {
+
+ @Test
+ public void testUserID() {
+ PreferenceArray prefs = new GenericUserPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setUserID(0, 1L);
+ assertEquals(1L, prefs.getUserID(0));
+ assertEquals(1L, prefs.getUserID(1));
+ assertEquals(1L, prefs.getUserID(2));
+ }
+
+ @Test
+ public void testItemID() {
+ PreferenceArray prefs = new GenericUserPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setItemID(0, 1L);
+ prefs.setItemID(1, 2L);
+ prefs.setItemID(2, 3L);
+ assertEquals(1L, prefs.getItemID(0));
+ assertEquals(2L, prefs.getItemID(1));
+ assertEquals(3L, prefs.getItemID(2));
+ }
+
+ @Test
+ public void testSetValue() {
+ PreferenceArray prefs = new GenericUserPreferenceArray(3);
+ assertEquals(3, prefs.length());
+ prefs.setValue(0, 1.0f);
+ prefs.setValue(1, 2.0f);
+ prefs.setValue(2, 3.0f);
+ assertEquals(1.0f, prefs.getValue(0), EPSILON);
+ assertEquals(2.0f, prefs.getValue(1), EPSILON);
+ assertEquals(3.0f, prefs.getValue(2), EPSILON);
+ }
+
+ @Test
+ public void testHasPref() {
+ PreferenceArray prefs = new GenericUserPreferenceArray(3);
+ prefs.set(0, new GenericPreference(1L, 3L, 5.0f));
+ assertTrue(prefs.hasPrefWithItemID(3L));
+ assertTrue(prefs.hasPrefWithUserID(1L));
+ assertFalse(prefs.hasPrefWithItemID(2L));
+ assertFalse(prefs.hasPrefWithUserID(2L));
+ }
+
+ @Test
+ public void testSort() {
+ PreferenceArray prefs = new GenericUserPreferenceArray(3);
+ prefs.set(0, new GenericPreference(1L, 3L, 5.0f));
+ prefs.set(1, new GenericPreference(1L, 1L, 5.0f));
+ prefs.set(2, new GenericPreference(1L, 2L, 5.0f));
+ prefs.sortByItem();
+ assertEquals(1L, prefs.getItemID(0));
+ assertEquals(2L, prefs.getItemID(1));
+ assertEquals(3L, prefs.getItemID(2));
+ }
+
+ @Test
+ public void testSortValue() {
+ PreferenceArray prefs = new GenericUserPreferenceArray(3);
+ prefs.set(0, new GenericPreference(1L, 3L, 5.0f));
+ prefs.set(1, new GenericPreference(1L, 1L, 4.0f));
+ prefs.set(2, new GenericPreference(1L, 2L, 3.0f));
+ prefs.sortByValue();
+ assertEquals(2L, prefs.getItemID(0));
+ assertEquals(1L, prefs.getItemID(1));
+ assertEquals(3L, prefs.getItemID(2));
+ prefs.sortByValueReversed();
+ assertEquals(3L, prefs.getItemID(0));
+ assertEquals(1L, prefs.getItemID(1));
+ assertEquals(2L, prefs.getItemID(2));
+ }
+
+ @Test
+ public void testClone() {
+ GenericUserPreferenceArray prefs = new GenericUserPreferenceArray(3);
+ prefs.set(0, new GenericPreference(1L, 3L, 5.0f));
+ prefs.set(1, new GenericPreference(1L, 1L, 4.0f));
+ prefs.set(2, new GenericPreference(1L, 2L, 3.0f));
+ prefs = prefs.clone();
+ assertEquals(3L, prefs.getItemID(0));
+ assertEquals(1L, prefs.getUserID(1));
+ assertEquals(3.0f, prefs.getValue(2), EPSILON);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigratorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigratorTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigratorTest.java
new file mode 100644
index 0000000..c8c673b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigratorTest.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.IDMigrator;
+
+import java.util.Collections;
+import org.apache.mahout.cf.taste.model.UpdatableIDMigrator;
+import org.junit.Test;
+
+public final class MemoryIDMigratorTest extends TasteTestCase {
+
+ private static final String DUMMY_STRING = "Mahout";
+ private static final long DUMMY_ID = -6311185995763544451L;
+
+ @Test
+ public void testToLong() {
+ IDMigrator migrator = new MemoryIDMigrator();
+ long id = migrator.toLongID(DUMMY_STRING);
+ assertEquals(DUMMY_ID, id);
+ }
+
+ @Test
+ public void testStore() throws Exception {
+ UpdatableIDMigrator migrator = new MemoryIDMigrator();
+ long id = migrator.toLongID(DUMMY_STRING);
+ assertNull(migrator.toStringID(id));
+ migrator.storeMapping(id, DUMMY_STRING);
+ assertEquals(DUMMY_STRING, migrator.toStringID(id));
+ }
+
+ @Test
+ public void testInitialize() throws Exception {
+ UpdatableIDMigrator migrator = new MemoryIDMigrator();
+ long id = migrator.toLongID(DUMMY_STRING);
+ assertNull(migrator.toStringID(id));
+ migrator.initialize(Collections.singleton(DUMMY_STRING));
+ assertEquals(DUMMY_STRING, migrator.toStringID(id));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModelTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModelTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModelTest.java
new file mode 100644
index 0000000..984ef6c
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModelTest.java
@@ -0,0 +1,313 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Iterator;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public final class PlusAnonymousConcurrentUserDataModelTest extends MahoutTestCase {
+
+ /**
+ * Prepares a testable object without delegate data
+ */
+ private static PlusAnonymousConcurrentUserDataModel getTestableWithoutDelegateData(int maxConcurrentUsers) {
+ FastByIDMap<PreferenceArray> delegatePreferences = new FastByIDMap<PreferenceArray>();
+ return new PlusAnonymousConcurrentUserDataModel(new GenericDataModel(delegatePreferences), maxConcurrentUsers);
+ }
+
+ /**
+ * Prepares a testable object with delegate data
+ */
+ private static PlusAnonymousConcurrentUserDataModel getTestableWithDelegateData(
+ int maxConcurrentUsers, FastByIDMap<PreferenceArray> delegatePreferences) {
+ return new PlusAnonymousConcurrentUserDataModel(new GenericDataModel(delegatePreferences), maxConcurrentUsers);
+ }
+
+ /**
+ * Test taking the first available user
+ */
+ @Test
+ public void testTakeFirstAvailableUser() {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+ Long expResult = PlusAnonymousUserDataModel.TEMP_USER_ID;
+ Long result = instance.takeAvailableUser();
+ assertEquals(expResult, result);
+ }
+
+ /**
+ * Test taking the next available user
+ */
+ @Test
+ public void testTakeNextAvailableUser() {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+ // Skip first user
+ instance.takeAvailableUser();
+ Long result = instance.takeAvailableUser();
+ Long expResult = PlusAnonymousUserDataModel.TEMP_USER_ID + 1;
+ assertEquals(expResult, result);
+ }
+
+ /**
+ * Test taking an unavailable user
+ */
+ @Test
+ public void testTakeUnavailableUser() {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(1);
+ // Take the only available user
+ instance.takeAvailableUser();
+ // There are no more users available
+ assertNull(instance.takeAvailableUser());
+ }
+
+ /**
+ * Test releasing a valid previously taken user
+ */
+ @Test
+ public void testReleaseValidUser() {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+ Long takenUserID = instance.takeAvailableUser();
+ assertTrue(instance.releaseUser(takenUserID));
+ }
+
+ /**
+ * Test releasing an invalid user
+ */
+ @Test
+ public void testReleaseInvalidUser() {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+ assertFalse(instance.releaseUser(Long.MAX_VALUE));
+ }
+
+ /**
+ * Test releasing a user which had been released earlier
+ */
+ @Test
+ public void testReleasePreviouslyReleasedUser() {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+ Long takenUserID = instance.takeAvailableUser();
+ assertTrue(instance.releaseUser(takenUserID));
+ assertFalse(instance.releaseUser(takenUserID));
+ }
+
+ /**
+ * Test setting anonymous user preferences
+ */
+ @Test
+ public void testSetAndGetTempPreferences() throws TasteException {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+ Long anonymousUserID = instance.takeAvailableUser();
+ PreferenceArray tempPrefs = new GenericUserPreferenceArray(1);
+ tempPrefs.setUserID(0, anonymousUserID);
+ tempPrefs.setItemID(0, 1);
+ instance.setTempPrefs(tempPrefs, anonymousUserID);
+ assertEquals(tempPrefs, instance.getPreferencesFromUser(anonymousUserID));
+ instance.releaseUser(anonymousUserID);
+ }
+
+ /**
+ * Test setting and getting preferences from several concurrent anonymous users
+ */
+ @Test
+ public void testSetMultipleTempPreferences() throws TasteException {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+
+ Long anonymousUserID1 = instance.takeAvailableUser();
+ Long anonymousUserID2 = instance.takeAvailableUser();
+
+ PreferenceArray tempPrefs1 = new GenericUserPreferenceArray(1);
+ tempPrefs1.setUserID(0, anonymousUserID1);
+ tempPrefs1.setItemID(0, 1);
+
+ PreferenceArray tempPrefs2 = new GenericUserPreferenceArray(2);
+ tempPrefs2.setUserID(0, anonymousUserID2);
+ tempPrefs2.setItemID(0, 2);
+ tempPrefs2.setUserID(1, anonymousUserID2);
+ tempPrefs2.setItemID(1, 3);
+
+ instance.setTempPrefs(tempPrefs1, anonymousUserID1);
+ instance.setTempPrefs(tempPrefs2, anonymousUserID2);
+
+ assertEquals(tempPrefs1, instance.getPreferencesFromUser(anonymousUserID1));
+ assertEquals(tempPrefs2, instance.getPreferencesFromUser(anonymousUserID2));
+ }
+
+ /**
+ * Test counting the number of delegate users
+ */
+ @Test
+ public void testGetNumUsersWithDelegateUsersOnly() throws TasteException {
+ PreferenceArray prefs = new GenericUserPreferenceArray(1);
+ long sampleUserID = 1;
+ prefs.setUserID(0, sampleUserID);
+ long sampleItemID = 11;
+ prefs.setItemID(0, sampleItemID);
+
+ FastByIDMap<PreferenceArray> delegatePreferences = new FastByIDMap<PreferenceArray>();
+ delegatePreferences.put(sampleUserID, prefs);
+
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithDelegateData(10, delegatePreferences);
+
+ assertEquals(1, instance.getNumUsers());
+ }
+
+ /**
+ * Test counting the number of anonymous users
+ */
+ @Test
+ public void testGetNumAnonymousUsers() throws TasteException {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+
+ Long anonymousUserID1 = instance.takeAvailableUser();
+
+ PreferenceArray tempPrefs1 = new GenericUserPreferenceArray(1);
+ tempPrefs1.setUserID(0, anonymousUserID1);
+ tempPrefs1.setItemID(0, 1);
+
+ instance.setTempPrefs(tempPrefs1, anonymousUserID1);
+
+ // Anonymous users should not be included into the universe.
+ assertEquals(0, instance.getNumUsers());
+ }
+
+ /**
+ * Test retrieve a single preference value of an anonymous user
+ */
+ @Test
+ public void testGetPreferenceValue() throws TasteException {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+
+ Long anonymousUserID = instance.takeAvailableUser();
+
+ PreferenceArray tempPrefs = new GenericUserPreferenceArray(1);
+ tempPrefs.setUserID(0, anonymousUserID);
+ long sampleItemID = 1;
+ tempPrefs.setItemID(0, sampleItemID);
+ tempPrefs.setValue(0, Float.MAX_VALUE);
+
+ instance.setTempPrefs(tempPrefs, anonymousUserID);
+
+ assertEquals(Float.MAX_VALUE, instance.getPreferenceValue(anonymousUserID, sampleItemID), EPSILON);
+ }
+
+ /**
+ * Test retrieve preferences for existing non-anonymous user
+ */
+ @Test
+ public void testGetPreferencesForNonAnonymousUser() throws TasteException {
+ PreferenceArray prefs = new GenericUserPreferenceArray(1);
+ long sampleUserID = 1;
+ prefs.setUserID(0, sampleUserID);
+ long sampleItemID = 11;
+ prefs.setItemID(0, sampleItemID);
+
+ FastByIDMap<PreferenceArray> delegatePreferences = new FastByIDMap<PreferenceArray>();
+ delegatePreferences.put(sampleUserID, prefs);
+
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithDelegateData(10, delegatePreferences);
+
+ assertEquals(prefs, instance.getPreferencesFromUser(sampleUserID));
+ }
+
+ /**
+ * Test retrieve preferences for non-anonymous and non-existing user
+ */
+ @Test(expected=NoSuchUserException.class)
+ public void testGetPreferencesForNonExistingUser() throws TasteException {
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithoutDelegateData(10);
+ // Exception is expected since such user does not exist
+ instance.getPreferencesFromUser(1);
+ }
+
+ /**
+ * Test retrieving the user IDs and verifying that anonymous ones are not included
+ */
+ @Test
+ public void testGetUserIDs() throws TasteException {
+ PreferenceArray prefs = new GenericUserPreferenceArray(1);
+ long sampleUserID = 1;
+ prefs.setUserID(0, sampleUserID);
+ long sampleItemID = 11;
+ prefs.setItemID(0, sampleItemID);
+
+ FastByIDMap<PreferenceArray> delegatePreferences = new FastByIDMap<PreferenceArray>();
+ delegatePreferences.put(sampleUserID, prefs);
+
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithDelegateData(10, delegatePreferences);
+
+ Long anonymousUserID = instance.takeAvailableUser();
+
+ PreferenceArray tempPrefs = new GenericUserPreferenceArray(1);
+ tempPrefs.setUserID(0, anonymousUserID);
+ tempPrefs.setItemID(0, 22);
+
+ instance.setTempPrefs(tempPrefs, anonymousUserID);
+
+ Iterator<Long> userIDs = instance.getUserIDs();
+
+ assertSame(sampleUserID, userIDs.next());
+ assertFalse(userIDs.hasNext());
+ }
+
+ /**
+ * Test getting preferences for an item.
+ *
+ * @throws TasteException
+ */
+ @Test
+ public void testGetPreferencesForItem() throws TasteException {
+ PreferenceArray prefs = new GenericUserPreferenceArray(2);
+ long sampleUserID = 4;
+ prefs.setUserID(0, sampleUserID);
+ long sampleItemID = 11;
+ prefs.setItemID(0, sampleItemID);
+ prefs.setUserID(1, sampleUserID);
+ long sampleItemID2 = 22;
+ prefs.setItemID(1, sampleItemID2);
+
+ FastByIDMap<PreferenceArray> delegatePreferences = new FastByIDMap<PreferenceArray>();
+ delegatePreferences.put(sampleUserID, prefs);
+
+ PlusAnonymousConcurrentUserDataModel instance = getTestableWithDelegateData(10, delegatePreferences);
+
+ Long anonymousUserID = instance.takeAvailableUser();
+
+ PreferenceArray tempPrefs = new GenericUserPreferenceArray(2);
+ tempPrefs.setUserID(0, anonymousUserID);
+ tempPrefs.setItemID(0, sampleItemID);
+ tempPrefs.setUserID(1, anonymousUserID);
+ long sampleItemID3 = 33;
+ tempPrefs.setItemID(1, sampleItemID3);
+
+ instance.setTempPrefs(tempPrefs, anonymousUserID);
+
+ assertEquals(sampleUserID, instance.getPreferencesForItem(sampleItemID).get(0).getUserID());
+ assertEquals(2, instance.getPreferencesForItem(sampleItemID).length());
+ assertEquals(1, instance.getPreferencesForItem(sampleItemID2).length());
+ assertEquals(1, instance.getPreferencesForItem(sampleItemID3).length());
+
+ assertEquals(2, instance.getNumUsersWithPreferenceFor(sampleItemID));
+ assertEquals(1, instance.getNumUsersWithPreferenceFor(sampleItemID, sampleItemID2));
+ assertEquals(1, instance.getNumUsersWithPreferenceFor(sampleItemID, sampleItemID3));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModelTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModelTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModelTest.java
new file mode 100644
index 0000000..be59eee
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModelTest.java
@@ -0,0 +1,216 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model.file;
+
+import java.io.File;
+import java.util.NoSuchElementException;
+
+import org.apache.commons.lang3.mutable.MutableBoolean;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
+import org.apache.mahout.cf.taste.impl.recommender.GenericUserBasedRecommender;
+import org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.junit.Before;
+import org.junit.Test;
+
+/** <p>Tests {@link FileDataModel}.</p> */
+public final class FileDataModelTest extends TasteTestCase {
+
+ private static final String[] DATA = {
+ "123,456,0.1",
+ "123,789,0.6",
+ "123,654,0.7",
+ "234,123,0.5",
+ "234,234,1.0",
+ "234,999,0.9",
+ "345,789,0.6",
+ "345,654,0.7",
+ "345,123,1.0",
+ "345,234,0.5",
+ "345,999,0.5",
+ "456,456,0.1",
+ "456,789,0.5",
+ "456,654,0.0",
+ "456,999,0.2",};
+
+ private static final String[] DATA_SPLITTED_WITH_TWO_SPACES = {
+ "123 456 0.1",
+ "123 789 0.6",
+ "123 654 0.7",
+ "234 123 0.5",
+ "234 234 1.0",
+ "234 999 0.9",
+ "345 789 0.6",
+ "345 654 0.7",
+ "345 123 1.0",
+ "345 234 0.5",
+ "345 999 0.5",
+ "456 456 0.1",
+ "456 789 0.5",
+ "456 654 0.0",
+ "456 999 0.2",};
+
+ private DataModel model;
+ private File testFile;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ testFile = getTestTempFile("test.txt");
+ writeLines(testFile, DATA);
+ model = new FileDataModel(testFile);
+ }
+
+ @Test
+ public void testReadRegexSplittedFile() throws Exception {
+ File testFile = getTestTempFile("testRegex.txt");
+ writeLines(testFile, DATA_SPLITTED_WITH_TWO_SPACES);
+ FileDataModel model = new FileDataModel(testFile,"\\s+");
+ assertEquals(model.getItemIDsFromUser(123).size(), 3);
+ assertEquals(model.getItemIDsFromUser(456).size(), 4);
+ }
+
+ @Test
+ public void testFile() throws Exception {
+ UserSimilarity userSimilarity = new PearsonCorrelationSimilarity(model);
+ UserNeighborhood neighborhood = new NearestNUserNeighborhood(3, userSimilarity, model);
+ Recommender recommender = new GenericUserBasedRecommender(model, neighborhood, userSimilarity);
+ assertEquals(1, recommender.recommend(123, 3).size());
+ assertEquals(0, recommender.recommend(234, 3).size());
+ assertEquals(1, recommender.recommend(345, 3).size());
+
+ // Make sure this doesn't throw an exception
+ model.refresh(null);
+ }
+
+ @Test
+ public void testTranspose() throws Exception {
+ FileDataModel tModel = new FileDataModel(testFile, true, FileDataModel.DEFAULT_MIN_RELOAD_INTERVAL_MS);
+ PreferenceArray userPrefs = tModel.getPreferencesFromUser(456);
+ assertNotNull("user prefs are null and it shouldn't be", userPrefs);
+ PreferenceArray pref = tModel.getPreferencesForItem(123);
+ assertNotNull("pref is null and it shouldn't be", pref);
+ assertEquals("pref Size: " + pref.length() + " is not: " + 3, 3, pref.length());
+ }
+
+ @Test(expected = NoSuchElementException.class)
+ public void testGetItems() throws Exception {
+ LongPrimitiveIterator it = model.getItemIDs();
+ assertNotNull(it);
+ assertTrue(it.hasNext());
+ assertEquals(123, it.nextLong());
+ assertTrue(it.hasNext());
+ assertEquals(234, it.nextLong());
+ assertTrue(it.hasNext());
+ assertEquals(456, it.nextLong());
+ assertTrue(it.hasNext());
+ assertEquals(654, it.nextLong());
+ assertTrue(it.hasNext());
+ assertEquals(789, it.nextLong());
+ assertTrue(it.hasNext());
+ assertEquals(999, it.nextLong());
+ assertFalse(it.hasNext());
+ it.next();
+ }
+
+ @Test
+ public void testPreferencesForItem() throws Exception {
+ PreferenceArray prefs = model.getPreferencesForItem(456);
+ assertNotNull(prefs);
+ Preference pref1 = prefs.get(0);
+ assertEquals(123, pref1.getUserID());
+ assertEquals(456, pref1.getItemID());
+ Preference pref2 = prefs.get(1);
+ assertEquals(456, pref2.getUserID());
+ assertEquals(456, pref2.getItemID());
+ assertEquals(2, prefs.length());
+ }
+
+ @Test
+ public void testGetNumUsers() throws Exception {
+ assertEquals(4, model.getNumUsers());
+ }
+
+ @Test
+ public void testNumUsersPreferring() throws Exception {
+ assertEquals(2, model.getNumUsersWithPreferenceFor(456));
+ assertEquals(0, model.getNumUsersWithPreferenceFor(111));
+ assertEquals(0, model.getNumUsersWithPreferenceFor(111, 456));
+ assertEquals(2, model.getNumUsersWithPreferenceFor(123, 234));
+ }
+
+ @Test
+ public void testRefresh() throws Exception {
+ final MutableBoolean initialized = new MutableBoolean(false);
+ Runnable initializer = new Runnable() {
+ @Override
+ public void run() {
+ try {
+ model.getNumUsers();
+ initialized.setValue(true);
+ } catch (TasteException te) {
+ // oops
+ }
+ }
+ };
+ new Thread(initializer).start();
+ Thread.sleep(1000L); // wait a second for thread to start and call getNumUsers()
+ model.getNumUsers(); // should block
+ assertTrue(initialized.booleanValue());
+ assertEquals(4, model.getNumUsers());
+ }
+
+ @Test
+ public void testExplicitRefreshAfterCompleteFileUpdate() throws Exception {
+ File file = getTestTempFile("refresh");
+ writeLines(file, "123,456,3.0");
+
+ /* create a FileDataModel that always reloads when the underlying file has changed */
+ FileDataModel dataModel = new FileDataModel(file, false, 0L);
+ assertEquals(3.0f, dataModel.getPreferenceValue(123L, 456L), EPSILON);
+
+ /* change the underlying file,
+ * we have to wait at least a second to see the change in the file's lastModified timestamp */
+ Thread.sleep(2000L);
+ writeLines(file, "123,456,5.0");
+ dataModel.refresh(null);
+
+ assertEquals(5.0f, dataModel.getPreferenceValue(123L, 456L), EPSILON);
+ }
+
+ @Test
+ public void testToString() {
+ assertFalse(model.toString().isEmpty());
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void testEmptyFile() throws Exception {
+ File file = getTestTempFile("empty");
+ writeLines(file); //required to create file.
+ new FileDataModel(file);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigratorTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigratorTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigratorTest.java
new file mode 100644
index 0000000..0a73315
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigratorTest.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model.file;
+
+import java.io.File;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.IDMigrator;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Tests {@link FileIDMigrator}
+ */
+public final class FileIDMigratorTest extends TasteTestCase {
+
+ private static final String[] STRING_IDS = {
+ "dog",
+ "cow" };
+
+ private static final String[] UPDATED_STRING_IDS = {
+ "dog",
+ "cow",
+ "donkey" };
+
+ private File testFile;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ testFile = getTestTempFile("test.txt");
+ writeLines(testFile, STRING_IDS);
+ }
+
+ @Test
+ public void testLoadFromFile() throws Exception {
+ IDMigrator migrator = new FileIDMigrator(testFile);
+ long dogAsLong = migrator.toLongID("dog");
+ long cowAsLong = migrator.toLongID("cow");
+ long donkeyAsLong = migrator.toLongID("donkey");
+ assertEquals("dog", migrator.toStringID(dogAsLong));
+ assertEquals("cow", migrator.toStringID(cowAsLong));
+ assertNull(migrator.toStringID(donkeyAsLong));
+ }
+
+ @Test
+ public void testNoRefreshAfterFileUpdate() throws Exception {
+ IDMigrator migrator = new FileIDMigrator(testFile, 0L);
+
+ /* call a method to make sure the original file is loaded */
+ long dogAsLong = migrator.toLongID("dog");
+ migrator.toStringID(dogAsLong);
+
+ /* change the underlying file,
+ * we have to wait at least a second to see the change in the file's lastModified timestamp */
+ Thread.sleep(2000L);
+ writeLines(testFile, UPDATED_STRING_IDS);
+
+ /* we shouldn't see any changes in the data as we have not yet refreshed */
+ long cowAsLong = migrator.toLongID("cow");
+ long donkeyAsLong = migrator.toLongID("donkey");
+ assertEquals("dog", migrator.toStringID(dogAsLong));
+ assertEquals("cow", migrator.toStringID(cowAsLong));
+ assertNull(migrator.toStringID(donkeyAsLong));
+ }
+
+ @Test
+ public void testRefreshAfterFileUpdate() throws Exception {
+ IDMigrator migrator = new FileIDMigrator(testFile, 0L);
+
+ /* call a method to make sure the original file is loaded */
+ long dogAsLong = migrator.toLongID("dog");
+ migrator.toStringID(dogAsLong);
+
+ /* change the underlying file,
+ * we have to wait at least a second to see the change in the file's lastModified timestamp */
+ Thread.sleep(2000L);
+ writeLines(testFile, UPDATED_STRING_IDS);
+
+ migrator.refresh(null);
+
+ long cowAsLong = migrator.toLongID("cow");
+ long donkeyAsLong = migrator.toLongID("donkey");
+ assertEquals("dog", migrator.toStringID(dogAsLong));
+ assertEquals("cow", migrator.toStringID(cowAsLong));
+ assertEquals("donkey", migrator.toStringID(donkeyAsLong));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/DummySimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/DummySimilarity.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/DummySimilarity.java
new file mode 100644
index 0000000..b057e5b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/DummySimilarity.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.similarity.AbstractItemSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+import java.util.Collection;
+
+final class DummySimilarity extends AbstractItemSimilarity implements UserSimilarity {
+
+ DummySimilarity(DataModel dataModel) {
+ super(dataModel);
+ }
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) throws TasteException {
+ DataModel dataModel = getDataModel();
+ return 1.0 / (1.0 + Math.abs(dataModel.getPreferencesFromUser(userID1).get(0).getValue()
+ - dataModel.getPreferencesFromUser(userID2).get(0).getValue()));
+ }
+
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) {
+ // Make up something wacky
+ return 1.0 / (1.0 + Math.abs(itemID1 - itemID2));
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) {
+ int length = itemID2s.length;
+ double[] result = new double[length];
+ for (int i = 0; i < length; i++) {
+ result[i] = itemSimilarity(itemID1, itemID2s[i]);
+ }
+ return result;
+ }
+
+ @Override
+ public void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // do nothing
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNNeighborhoodTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNNeighborhoodTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNNeighborhoodTest.java
new file mode 100644
index 0000000..729dc9a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNNeighborhoodTest.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** <p>Tests {@link NearestNUserNeighborhood}.</p> */
+public final class NearestNNeighborhoodTest extends TasteTestCase {
+
+ @Test
+ public void testNeighborhood() throws Exception {
+ DataModel dataModel = getDataModel();
+
+ long[] neighborhood =
+ new NearestNUserNeighborhood(1, new DummySimilarity(dataModel), dataModel).getUserNeighborhood(1);
+ assertNotNull(neighborhood);
+ assertEquals(1, neighborhood.length);
+ assertTrue(arrayContains(neighborhood, 2));
+
+ long[] neighborhood2 =
+ new NearestNUserNeighborhood(2, new DummySimilarity(dataModel), dataModel).getUserNeighborhood(2);
+ assertNotNull(neighborhood2);
+ assertEquals(2, neighborhood2.length);
+ assertTrue(arrayContains(neighborhood2, 1));
+ assertTrue(arrayContains(neighborhood2, 3));
+
+ long[] neighborhood3 =
+ new NearestNUserNeighborhood(4, new DummySimilarity(dataModel), dataModel).getUserNeighborhood(4);
+ assertNotNull(neighborhood3);
+ assertEquals(3, neighborhood3.length);
+ assertTrue(arrayContains(neighborhood3, 1));
+ assertTrue(arrayContains(neighborhood3, 2));
+ assertTrue(arrayContains(neighborhood3, 3));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdNeighborhoodTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdNeighborhoodTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdNeighborhoodTest.java
new file mode 100644
index 0000000..c3005a9
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdNeighborhoodTest.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.junit.Test;
+
+/** <p>Tests {@link ThresholdUserNeighborhood}.</p> */
+public final class ThresholdNeighborhoodTest extends TasteTestCase {
+
+ @Test
+ public void testNeighborhood() throws Exception {
+ DataModel dataModel = getDataModel();
+
+ long[] neighborhood =
+ new ThresholdUserNeighborhood(1.0, new DummySimilarity(dataModel), dataModel).getUserNeighborhood(1);
+ assertNotNull(neighborhood);
+ assertEquals(0, neighborhood.length);
+
+ long[] neighborhood2 =
+ new ThresholdUserNeighborhood(0.8, new DummySimilarity(dataModel), dataModel).getUserNeighborhood(1);
+ assertNotNull(neighborhood2);
+ assertEquals(1, neighborhood2.length);
+ assertTrue(arrayContains(neighborhood2, 2));
+
+ long[] neighborhood3 =
+ new ThresholdUserNeighborhood(0.6, new DummySimilarity(dataModel), dataModel).getUserNeighborhood(2);
+ assertNotNull(neighborhood3);
+ assertEquals(3, neighborhood3.length);
+ assertTrue(arrayContains(neighborhood3, 1));
+ assertTrue(arrayContains(neighborhood3, 3));
+ assertTrue(arrayContains(neighborhood3, 4));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategyTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategyTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategyTest.java
new file mode 100644
index 0000000..5d9f8cc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategyTest.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+import java.util.Collections;
+
+/**
+ * Tests {@link AllUnknownItemsCandidateItemsStrategyTest}
+ */
+public final class AllUnknownItemsCandidateItemsStrategyTest extends TasteTestCase {
+
+ @Test
+ public void testStrategy() throws TasteException {
+ FastIDSet allItemIDs = new FastIDSet();
+ allItemIDs.addAll(new long[] { 1L, 2L, 3L });
+
+ FastIDSet preferredItemIDs = new FastIDSet(1);
+ preferredItemIDs.add(2L);
+
+ DataModel dataModel = EasyMock.createMock(DataModel.class);
+ EasyMock.expect(dataModel.getNumItems()).andReturn(3);
+ EasyMock.expect(dataModel.getItemIDs()).andReturn(allItemIDs.iterator());
+
+ PreferenceArray prefArrayOfUser123 = new GenericUserPreferenceArray(Collections.singletonList(
+ new GenericPreference(123L, 2L, 1.0f)));
+
+ CandidateItemsStrategy strategy = new AllUnknownItemsCandidateItemsStrategy();
+
+ EasyMock.replay(dataModel);
+
+ FastIDSet candidateItems = strategy.getCandidateItems(123L, prefArrayOfUser123, dataModel, false);
+ assertEquals(2, candidateItems.size());
+ assertTrue(candidateItems.contains(1L));
+ assertTrue(candidateItems.contains(3L));
+
+ EasyMock.verify(dataModel);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommenderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommenderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommenderTest.java
new file mode 100644
index 0000000..3ae35b0
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommenderTest.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.commons.lang3.mutable.MutableInt;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.junit.Test;
+
+/** <p>Tests {@link CachingRecommender}.</p> */
+public final class CachingRecommenderTest extends TasteTestCase {
+
+ @Test
+ public void testRecommender() throws Exception {
+ MutableInt recommendCount = new MutableInt();
+ Recommender mockRecommender = new MockRecommender(recommendCount);
+
+ Recommender cachingRecommender = new CachingRecommender(mockRecommender);
+ cachingRecommender.recommend(1, 1);
+ assertEquals(1, recommendCount.intValue());
+ cachingRecommender.recommend(2, 1);
+ assertEquals(2, recommendCount.intValue());
+ cachingRecommender.recommend(1, 1);
+ assertEquals(2, recommendCount.intValue());
+ cachingRecommender.recommend(2, 1);
+ assertEquals(2, recommendCount.intValue());
+ cachingRecommender.refresh(null);
+ cachingRecommender.recommend(1, 1);
+ assertEquals(3, recommendCount.intValue());
+ cachingRecommender.recommend(2, 1);
+ assertEquals(4, recommendCount.intValue());
+ cachingRecommender.recommend(3, 1);
+ assertEquals(5, recommendCount.intValue());
+
+ // Results from this recommend() method can be cached...
+ IDRescorer rescorer = NullRescorer.getItemInstance();
+ cachingRecommender.refresh(null);
+ cachingRecommender.recommend(1, 1, rescorer);
+ assertEquals(6, recommendCount.intValue());
+ cachingRecommender.recommend(2, 1, rescorer);
+ assertEquals(7, recommendCount.intValue());
+ cachingRecommender.recommend(1, 1, rescorer);
+ assertEquals(7, recommendCount.intValue());
+ cachingRecommender.recommend(2, 1, rescorer);
+ assertEquals(7, recommendCount.intValue());
+
+ // until you switch Rescorers
+ cachingRecommender.recommend(1, 1, null);
+ assertEquals(8, recommendCount.intValue());
+ cachingRecommender.recommend(2, 1, null);
+ assertEquals(9, recommendCount.intValue());
+
+ cachingRecommender.refresh(null);
+ cachingRecommender.estimatePreference(1, 1);
+ assertEquals(10, recommendCount.intValue());
+ cachingRecommender.estimatePreference(1, 2);
+ assertEquals(11, recommendCount.intValue());
+ cachingRecommender.estimatePreference(1, 2);
+ assertEquals(11, recommendCount.intValue());
+ }
+
+}
[40/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java
new file mode 100644
index 0000000..59c30d9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AbstractSimilarity.java
@@ -0,0 +1,343 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+import com.google.common.base.Preconditions;
+
+/** Abstract superclass encapsulating functionality that is common to most implementations in this package. */
+abstract class AbstractSimilarity extends AbstractItemSimilarity implements UserSimilarity {
+
+ private PreferenceInferrer inferrer;
+ private final boolean weighted;
+ private final boolean centerData;
+ private int cachedNumItems;
+ private int cachedNumUsers;
+ private final RefreshHelper refreshHelper;
+
+ /**
+ * <p>
+ * Creates a possibly weighted {@link AbstractSimilarity}.
+ * </p>
+ */
+ AbstractSimilarity(final DataModel dataModel, Weighting weighting, boolean centerData) throws TasteException {
+ super(dataModel);
+ this.weighted = weighting == Weighting.WEIGHTED;
+ this.centerData = centerData;
+ this.cachedNumItems = dataModel.getNumItems();
+ this.cachedNumUsers = dataModel.getNumUsers();
+ this.refreshHelper = new RefreshHelper(new Callable<Object>() {
+ @Override
+ public Object call() throws TasteException {
+ cachedNumItems = dataModel.getNumItems();
+ cachedNumUsers = dataModel.getNumUsers();
+ return null;
+ }
+ });
+ }
+
+ final PreferenceInferrer getPreferenceInferrer() {
+ return inferrer;
+ }
+
+ @Override
+ public final void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ Preconditions.checkArgument(inferrer != null, "inferrer is null");
+ refreshHelper.addDependency(inferrer);
+ refreshHelper.removeDependency(this.inferrer);
+ this.inferrer = inferrer;
+ }
+
+ final boolean isWeighted() {
+ return weighted;
+ }
+
+ /**
+ * <p>
+ * Several subclasses in this package implement this method to actually compute the similarity from figures
+ * computed over users or items. Note that the computations in this class "center" the data, such that X and
+ * Y's mean are 0.
+ * </p>
+ *
+ * <p>
+ * Note that the sum of all X and Y values must then be 0. This value isn't passed down into the standard
+ * similarity computations as a result.
+ * </p>
+ *
+ * @param n
+ * total number of users or items
+ * @param sumXY
+ * sum of product of user/item preference values, over all items/users preferred by both
+ * users/items
+ * @param sumX2
+ * sum of the square of user/item preference values, over the first item/user
+ * @param sumY2
+ * sum of the square of the user/item preference values, over the second item/user
+ * @param sumXYdiff2
+ * sum of squares of differences in X and Y values
+ * @return similarity value between -1.0 and 1.0, inclusive, or {@link Double#NaN} if no similarity can be
+ * computed (e.g. when no items have been rated by both users
+ */
+ abstract double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2);
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) throws TasteException {
+ DataModel dataModel = getDataModel();
+ PreferenceArray xPrefs = dataModel.getPreferencesFromUser(userID1);
+ PreferenceArray yPrefs = dataModel.getPreferencesFromUser(userID2);
+ int xLength = xPrefs.length();
+ int yLength = yPrefs.length();
+
+ if (xLength == 0 || yLength == 0) {
+ return Double.NaN;
+ }
+
+ long xIndex = xPrefs.getItemID(0);
+ long yIndex = yPrefs.getItemID(0);
+ int xPrefIndex = 0;
+ int yPrefIndex = 0;
+
+ double sumX = 0.0;
+ double sumX2 = 0.0;
+ double sumY = 0.0;
+ double sumY2 = 0.0;
+ double sumXY = 0.0;
+ double sumXYdiff2 = 0.0;
+ int count = 0;
+
+ boolean hasInferrer = inferrer != null;
+
+ while (true) {
+ int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
+ if (hasInferrer || compare == 0) {
+ double x;
+ double y;
+ if (xIndex == yIndex) {
+ // Both users expressed a preference for the item
+ x = xPrefs.getValue(xPrefIndex);
+ y = yPrefs.getValue(yPrefIndex);
+ } else {
+ // Only one user expressed a preference, but infer the other one's preference and tally
+ // as if the other user expressed that preference
+ if (compare < 0) {
+ // X has a value; infer Y's
+ x = xPrefs.getValue(xPrefIndex);
+ y = inferrer.inferPreference(userID2, xIndex);
+ } else {
+ // compare > 0
+ // Y has a value; infer X's
+ x = inferrer.inferPreference(userID1, yIndex);
+ y = yPrefs.getValue(yPrefIndex);
+ }
+ }
+ sumXY += x * y;
+ sumX += x;
+ sumX2 += x * x;
+ sumY += y;
+ sumY2 += y * y;
+ double diff = x - y;
+ sumXYdiff2 += diff * diff;
+ count++;
+ }
+ if (compare <= 0) {
+ if (++xPrefIndex >= xLength) {
+ if (hasInferrer) {
+ // Must count other Ys; pretend next X is far away
+ if (yIndex == Long.MAX_VALUE) {
+ // ... but stop if both are done!
+ break;
+ }
+ xIndex = Long.MAX_VALUE;
+ } else {
+ break;
+ }
+ } else {
+ xIndex = xPrefs.getItemID(xPrefIndex);
+ }
+ }
+ if (compare >= 0) {
+ if (++yPrefIndex >= yLength) {
+ if (hasInferrer) {
+ // Must count other Xs; pretend next Y is far away
+ if (xIndex == Long.MAX_VALUE) {
+ // ... but stop if both are done!
+ break;
+ }
+ yIndex = Long.MAX_VALUE;
+ } else {
+ break;
+ }
+ } else {
+ yIndex = yPrefs.getItemID(yPrefIndex);
+ }
+ }
+ }
+
+ // "Center" the data. If my math is correct, this'll do it.
+ double result;
+ if (centerData) {
+ double meanX = sumX / count;
+ double meanY = sumY / count;
+ // double centeredSumXY = sumXY - meanY * sumX - meanX * sumY + n * meanX * meanY;
+ double centeredSumXY = sumXY - meanY * sumX;
+ // double centeredSumX2 = sumX2 - 2.0 * meanX * sumX + n * meanX * meanX;
+ double centeredSumX2 = sumX2 - meanX * sumX;
+ // double centeredSumY2 = sumY2 - 2.0 * meanY * sumY + n * meanY * meanY;
+ double centeredSumY2 = sumY2 - meanY * sumY;
+ result = computeResult(count, centeredSumXY, centeredSumX2, centeredSumY2, sumXYdiff2);
+ } else {
+ result = computeResult(count, sumXY, sumX2, sumY2, sumXYdiff2);
+ }
+
+ if (!Double.isNaN(result)) {
+ result = normalizeWeightResult(result, count, cachedNumItems);
+ }
+ return result;
+ }
+
+ @Override
+ public final double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+ DataModel dataModel = getDataModel();
+ PreferenceArray xPrefs = dataModel.getPreferencesForItem(itemID1);
+ PreferenceArray yPrefs = dataModel.getPreferencesForItem(itemID2);
+ int xLength = xPrefs.length();
+ int yLength = yPrefs.length();
+
+ if (xLength == 0 || yLength == 0) {
+ return Double.NaN;
+ }
+
+ long xIndex = xPrefs.getUserID(0);
+ long yIndex = yPrefs.getUserID(0);
+ int xPrefIndex = 0;
+ int yPrefIndex = 0;
+
+ double sumX = 0.0;
+ double sumX2 = 0.0;
+ double sumY = 0.0;
+ double sumY2 = 0.0;
+ double sumXY = 0.0;
+ double sumXYdiff2 = 0.0;
+ int count = 0;
+
+ // No, pref inferrers and transforms don't apply here. I think.
+
+ while (true) {
+ int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
+ if (compare == 0) {
+ // Both users expressed a preference for the item
+ double x = xPrefs.getValue(xPrefIndex);
+ double y = yPrefs.getValue(yPrefIndex);
+ sumXY += x * y;
+ sumX += x;
+ sumX2 += x * x;
+ sumY += y;
+ sumY2 += y * y;
+ double diff = x - y;
+ sumXYdiff2 += diff * diff;
+ count++;
+ }
+ if (compare <= 0) {
+ if (++xPrefIndex == xLength) {
+ break;
+ }
+ xIndex = xPrefs.getUserID(xPrefIndex);
+ }
+ if (compare >= 0) {
+ if (++yPrefIndex == yLength) {
+ break;
+ }
+ yIndex = yPrefs.getUserID(yPrefIndex);
+ }
+ }
+
+ double result;
+ if (centerData) {
+ // See comments above on these computations
+ double n = (double) count;
+ double meanX = sumX / n;
+ double meanY = sumY / n;
+ // double centeredSumXY = sumXY - meanY * sumX - meanX * sumY + n * meanX * meanY;
+ double centeredSumXY = sumXY - meanY * sumX;
+ // double centeredSumX2 = sumX2 - 2.0 * meanX * sumX + n * meanX * meanX;
+ double centeredSumX2 = sumX2 - meanX * sumX;
+ // double centeredSumY2 = sumY2 - 2.0 * meanY * sumY + n * meanY * meanY;
+ double centeredSumY2 = sumY2 - meanY * sumY;
+ result = computeResult(count, centeredSumXY, centeredSumX2, centeredSumY2, sumXYdiff2);
+ } else {
+ result = computeResult(count, sumXY, sumX2, sumY2, sumXYdiff2);
+ }
+
+ if (!Double.isNaN(result)) {
+ result = normalizeWeightResult(result, count, cachedNumUsers);
+ }
+ return result;
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+ int length = itemID2s.length;
+ double[] result = new double[length];
+ for (int i = 0; i < length; i++) {
+ result[i] = itemSimilarity(itemID1, itemID2s[i]);
+ }
+ return result;
+ }
+
+ final double normalizeWeightResult(double result, int count, int num) {
+ double normalizedResult = result;
+ if (weighted) {
+ double scaleFactor = 1.0 - (double) count / (double) (num + 1);
+ if (normalizedResult < 0.0) {
+ normalizedResult = -1.0 + scaleFactor * (1.0 + normalizedResult);
+ } else {
+ normalizedResult = 1.0 - scaleFactor * (1.0 - normalizedResult);
+ }
+ }
+ // Make sure the result is not accidentally a little outside [-1.0, 1.0] due to rounding:
+ if (normalizedResult < -1.0) {
+ normalizedResult = -1.0;
+ } else if (normalizedResult > 1.0) {
+ normalizedResult = 1.0;
+ }
+ return normalizedResult;
+ }
+
+ @Override
+ public final void refresh(Collection<Refreshable> alreadyRefreshed) {
+ super.refresh(alreadyRefreshed);
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ @Override
+ public final String toString() {
+ return this.getClass().getSimpleName() + "[dataModel:" + getDataModel() + ",inferrer:" + inferrer + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java
new file mode 100644
index 0000000..7c655fe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrer.java
@@ -0,0 +1,85 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.Cache;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.Retriever;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+
+/**
+ * <p>
+ * Implementations of this interface compute an inferred preference for a user and an item that the user has
+ * not expressed any preference for. This might be an average of other preferences scores from that user, for
+ * example. This technique is sometimes called "default voting".
+ * </p>
+ */
+public final class AveragingPreferenceInferrer implements PreferenceInferrer {
+
+ private static final Float ZERO = 0.0f;
+
+ private final DataModel dataModel;
+ private final Cache<Long,Float> averagePreferenceValue;
+
+ public AveragingPreferenceInferrer(DataModel dataModel) throws TasteException {
+ this.dataModel = dataModel;
+ Retriever<Long,Float> retriever = new PrefRetriever();
+ averagePreferenceValue = new Cache<>(retriever, dataModel.getNumUsers());
+ refresh(null);
+ }
+
+ @Override
+ public float inferPreference(long userID, long itemID) throws TasteException {
+ return averagePreferenceValue.get(userID);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ averagePreferenceValue.clear();
+ }
+
+ private final class PrefRetriever implements Retriever<Long,Float> {
+
+ @Override
+ public Float get(Long key) throws TasteException {
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(key);
+ int size = prefs.length();
+ if (size == 0) {
+ return ZERO;
+ }
+ RunningAverage average = new FullRunningAverage();
+ for (int i = 0; i < size; i++) {
+ average.addDatum(prefs.getValue(i));
+ }
+ return (float) average.getAverage();
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "AveragingPreferenceInferrer";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java
new file mode 100644
index 0000000..87aeae9
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingItemSimilarity.java
@@ -0,0 +1,111 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.Cache;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.common.Retriever;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.common.LongPair;
+import com.google.common.base.Preconditions;
+
+/**
+ * Caches the results from an underlying {@link ItemSimilarity} implementation.
+ */
+public final class CachingItemSimilarity implements ItemSimilarity {
+
+ private final ItemSimilarity similarity;
+ private final Cache<LongPair,Double> similarityCache;
+ private final RefreshHelper refreshHelper;
+
+ /**
+ * Creates this on top of the given {@link ItemSimilarity}.
+ * The cache is sized according to properties of the given {@link DataModel}.
+ */
+ public CachingItemSimilarity(ItemSimilarity similarity, DataModel dataModel) throws TasteException {
+ this(similarity, dataModel.getNumItems());
+ }
+
+ /**
+ * Creates this on top of the given {@link ItemSimilarity}.
+ * The cache size is capped by the given size.
+ */
+ public CachingItemSimilarity(ItemSimilarity similarity, int maxCacheSize) {
+ Preconditions.checkArgument(similarity != null, "similarity is null");
+ this.similarity = similarity;
+ this.similarityCache = new Cache<>(new SimilarityRetriever(similarity), maxCacheSize);
+ this.refreshHelper = new RefreshHelper(new Callable<Void>() {
+ @Override
+ public Void call() {
+ similarityCache.clear();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(similarity);
+ }
+
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+ LongPair key = itemID1 < itemID2 ? new LongPair(itemID1, itemID2) : new LongPair(itemID2, itemID1);
+ return similarityCache.get(key);
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+ int length = itemID2s.length;
+ double[] result = new double[length];
+ for (int i = 0; i < length; i++) {
+ result[i] = itemSimilarity(itemID1, itemID2s[i]);
+ }
+ return result;
+ }
+
+ @Override
+ public long[] allSimilarItemIDs(long itemID) throws TasteException {
+ return similarity.allSimilarItemIDs(itemID);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ public void clearCacheForItem(long itemID) {
+ similarityCache.removeKeysMatching(new LongPairMatchPredicate(itemID));
+ }
+
+ private static final class SimilarityRetriever implements Retriever<LongPair,Double> {
+ private final ItemSimilarity similarity;
+
+ private SimilarityRetriever(ItemSimilarity similarity) {
+ this.similarity = similarity;
+ }
+
+ @Override
+ public Double get(LongPair key) throws TasteException {
+ return similarity.itemSimilarity(key.getFirst(), key.getSecond());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java
new file mode 100644
index 0000000..873568a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CachingUserSimilarity.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.Cache;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.common.Retriever;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.apache.mahout.common.LongPair;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Caches the results from an underlying {@link UserSimilarity} implementation.
+ */
+public final class CachingUserSimilarity implements UserSimilarity {
+
+ private final UserSimilarity similarity;
+ private final Cache<LongPair,Double> similarityCache;
+ private final RefreshHelper refreshHelper;
+
+ /**
+ * Creates this on top of the given {@link UserSimilarity}.
+ * The cache is sized according to properties of the given {@link DataModel}.
+ */
+ public CachingUserSimilarity(UserSimilarity similarity, DataModel dataModel) throws TasteException {
+ this(similarity, dataModel.getNumUsers());
+ }
+
+ /**
+ * Creates this on top of the given {@link UserSimilarity}.
+ * The cache size is capped by the given size.
+ */
+ public CachingUserSimilarity(UserSimilarity similarity, int maxCacheSize) {
+ Preconditions.checkArgument(similarity != null, "similarity is null");
+ this.similarity = similarity;
+ this.similarityCache = new Cache<>(new SimilarityRetriever(similarity), maxCacheSize);
+ this.refreshHelper = new RefreshHelper(new Callable<Void>() {
+ @Override
+ public Void call() {
+ similarityCache.clear();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(similarity);
+ }
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) throws TasteException {
+ LongPair key = userID1 < userID2 ? new LongPair(userID1, userID2) : new LongPair(userID2, userID1);
+ return similarityCache.get(key);
+ }
+
+ @Override
+ public void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ similarityCache.clear();
+ similarity.setPreferenceInferrer(inferrer);
+ }
+
+ public void clearCacheForUser(long userID) {
+ similarityCache.removeKeysMatching(new LongPairMatchPredicate(userID));
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ private static final class SimilarityRetriever implements Retriever<LongPair,Double> {
+ private final UserSimilarity similarity;
+
+ private SimilarityRetriever(UserSimilarity similarity) {
+ this.similarity = similarity;
+ }
+
+ @Override
+ public Double get(LongPair key) throws TasteException {
+ return similarity.userSimilarity(key.getFirst(), key.getSecond());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java
new file mode 100644
index 0000000..88fbe58
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/CityBlockSimilarity.java
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+/**
+ * Implementation of City Block distance (also known as Manhattan distance) - the absolute value of the difference of
+ * each direction is summed. The resulting unbounded distance is then mapped between 0 and 1.
+ */
+public final class CityBlockSimilarity extends AbstractItemSimilarity implements UserSimilarity {
+
+ public CityBlockSimilarity(DataModel dataModel) {
+ super(dataModel);
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ Collection<Refreshable> refreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
+ RefreshHelper.maybeRefresh(refreshed, getDataModel());
+ }
+
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+ DataModel dataModel = getDataModel();
+ int preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1);
+ int preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2);
+ int intersection = dataModel.getNumUsersWithPreferenceFor(itemID1, itemID2);
+ return doSimilarity(preferring1, preferring2, intersection);
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+ DataModel dataModel = getDataModel();
+ int preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1);
+ double[] distance = new double[itemID2s.length];
+ for (int i = 0; i < itemID2s.length; ++i) {
+ int preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2s[i]);
+ int intersection = dataModel.getNumUsersWithPreferenceFor(itemID1, itemID2s[i]);
+ distance[i] = doSimilarity(preferring1, preferring2, intersection);
+ }
+ return distance;
+ }
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) throws TasteException {
+ DataModel dataModel = getDataModel();
+ FastIDSet prefs1 = dataModel.getItemIDsFromUser(userID1);
+ FastIDSet prefs2 = dataModel.getItemIDsFromUser(userID2);
+ int prefs1Size = prefs1.size();
+ int prefs2Size = prefs2.size();
+ int intersectionSize = prefs1Size < prefs2Size ? prefs2.intersectionSize(prefs1) : prefs1.intersectionSize(prefs2);
+ return doSimilarity(prefs1Size, prefs2Size, intersectionSize);
+ }
+
+ /**
+ * Calculate City Block Distance from total non-zero values and intersections and map to a similarity value.
+ *
+ * @param pref1 number of non-zero values in left vector
+ * @param pref2 number of non-zero values in right vector
+ * @param intersection number of overlapping non-zero values
+ */
+ private static double doSimilarity(int pref1, int pref2, int intersection) {
+ int distance = pref1 + pref2 - 2 * intersection;
+ return 1.0 / (1.0 + distance);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java
new file mode 100644
index 0000000..990e9ea
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/EuclideanDistanceSimilarity.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * An implementation of a "similarity" based on the Euclidean "distance" between two users X and Y. Thinking
+ * of items as dimensions and preferences as points along those dimensions, a distance is computed using all
+ * items (dimensions) where both users have expressed a preference for that item. This is simply the square
+ * root of the sum of the squares of differences in position (preference) along each dimension.</p>
+ *
+ * <p>The similarity could be computed as 1 / (1 + distance / sqrt(n)), so the resulting values are in the range (0,1].
+ * This would weight against pairs that overlap in more dimensions, which should indicate more similarity,
+ * since more dimensions offer more opportunities to be farther apart. Actually, it is computed as
+ * sqrt(n) / (1 + distance), where n is the number of dimensions, in order to help correct for this.
+ * sqrt(n) is chosen since randomly-chosen points have a distance that grows as sqrt(n).</p>
+ *
+ * <p>Note that this could cause a similarity to exceed 1; such values are capped at 1.</p>
+ *
+ * <p>Note that the distance isn't normalized in any way; it's not valid to compare similarities computed from
+ * different domains (different rating scales, for example). Within one domain, normalizing doesn't matter much as
+ * it doesn't change ordering.</p>
+ */
+public final class EuclideanDistanceSimilarity extends AbstractSimilarity {
+
+ /**
+ * @throws IllegalArgumentException if {@link DataModel} does not have preference values
+ */
+ public EuclideanDistanceSimilarity(DataModel dataModel) throws TasteException {
+ this(dataModel, Weighting.UNWEIGHTED);
+ }
+
+ /**
+ * @throws IllegalArgumentException if {@link DataModel} does not have preference values
+ */
+ public EuclideanDistanceSimilarity(DataModel dataModel, Weighting weighting) throws TasteException {
+ super(dataModel, weighting, false);
+ Preconditions.checkArgument(dataModel.hasPreferenceValues(), "DataModel doesn't have preference values");
+ }
+
+ @Override
+ double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2) {
+ return 1.0 / (1.0 + Math.sqrt(sumXYdiff2) / Math.sqrt(n));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java
new file mode 100644
index 0000000..d0c9b8c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericItemSimilarity.java
@@ -0,0 +1,358 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+import java.util.Iterator;
+
+import com.google.common.collect.AbstractIterator;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.recommender.TopItems;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.apache.mahout.common.RandomUtils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A "generic" {@link ItemSimilarity} which takes a static list of precomputed item similarities and bases its
+ * responses on that alone. The values may have been precomputed offline by another process, stored in a file,
+ * and then read and fed into an instance of this class.
+ * </p>
+ *
+ * <p>
+ * This is perhaps the best {@link ItemSimilarity} to use with
+ * {@link org.apache.mahout.cf.taste.impl.recommender.GenericItemBasedRecommender}, for now, since the point
+ * of item-based recommenders is that they can take advantage of the fact that item similarity is relatively
+ * static, can be precomputed, and then used in computation to gain a significant performance advantage.
+ * </p>
+ */
+public final class GenericItemSimilarity implements ItemSimilarity {
+
+ private static final long[] NO_IDS = new long[0];
+
+ private final FastByIDMap<FastByIDMap<Double>> similarityMaps = new FastByIDMap<>();
+ private final FastByIDMap<FastIDSet> similarItemIDsIndex = new FastByIDMap<>();
+
+ /**
+ * <p>
+ * Creates a {@link GenericItemSimilarity} from a precomputed list of {@link ItemItemSimilarity}s. Each
+ * represents the similarity between two distinct items. Since similarity is assumed to be symmetric, it is
+ * not necessary to specify similarity between item1 and item2, and item2 and item1. Both are the same. It
+ * is also not necessary to specify a similarity between any item and itself; these are assumed to be 1.0.
+ * </p>
+ *
+ * <p>
+ * Note that specifying a similarity between two items twice is not an error, but, the later value will win.
+ * </p>
+ *
+ * @param similarities
+ * set of {@link ItemItemSimilarity}s on which to base this instance
+ */
+ public GenericItemSimilarity(Iterable<ItemItemSimilarity> similarities) {
+ initSimilarityMaps(similarities.iterator());
+ }
+
+ /**
+ * <p>
+ * Like {@link #GenericItemSimilarity(Iterable)}, but will only keep the specified number of similarities
+ * from the given {@link Iterable} of similarities. It will keep those with the highest similarity -- those
+ * that are therefore most important.
+ * </p>
+ *
+ * <p>
+ * Thanks to tsmorton for suggesting this and providing part of the implementation.
+ * </p>
+ *
+ * @param similarities
+ * set of {@link ItemItemSimilarity}s on which to base this instance
+ * @param maxToKeep
+ * maximum number of similarities to keep
+ */
+ public GenericItemSimilarity(Iterable<ItemItemSimilarity> similarities, int maxToKeep) {
+ Iterable<ItemItemSimilarity> keptSimilarities =
+ TopItems.getTopItemItemSimilarities(maxToKeep, similarities.iterator());
+ initSimilarityMaps(keptSimilarities.iterator());
+ }
+
+ /**
+ * <p>
+ * Builds a list of item-item similarities given an {@link ItemSimilarity} implementation and a
+ * {@link DataModel}, rather than a list of {@link ItemItemSimilarity}s.
+ * </p>
+ *
+ * <p>
+ * It's valid to build a {@link GenericItemSimilarity} this way, but perhaps missing some of the point of an
+ * item-based recommender. Item-based recommenders use the assumption that item-item similarities are
+ * relatively fixed, and might be known already independent of user preferences. Hence it is useful to
+ * inject that information, using {@link #GenericItemSimilarity(Iterable)}.
+ * </p>
+ *
+ * @param otherSimilarity
+ * other {@link ItemSimilarity} to get similarities from
+ * @param dataModel
+ * data model to get items from
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel} items
+ */
+ public GenericItemSimilarity(ItemSimilarity otherSimilarity, DataModel dataModel) throws TasteException {
+ long[] itemIDs = GenericUserSimilarity.longIteratorToList(dataModel.getItemIDs());
+ initSimilarityMaps(new DataModelSimilaritiesIterator(otherSimilarity, itemIDs));
+ }
+
+ /**
+ * <p>
+ * Like {@link #GenericItemSimilarity(ItemSimilarity, DataModel)} )}, but will only keep the specified
+ * number of similarities from the given {@link DataModel}. It will keep those with the highest similarity
+ * -- those that are therefore most important.
+ * </p>
+ *
+ * <p>
+ * Thanks to tsmorton for suggesting this and providing part of the implementation.
+ * </p>
+ *
+ * @param otherSimilarity
+ * other {@link ItemSimilarity} to get similarities from
+ * @param dataModel
+ * data model to get items from
+ * @param maxToKeep
+ * maximum number of similarities to keep
+ * @throws TasteException
+ * if an error occurs while accessing the {@link DataModel} items
+ */
+ public GenericItemSimilarity(ItemSimilarity otherSimilarity,
+ DataModel dataModel,
+ int maxToKeep) throws TasteException {
+ long[] itemIDs = GenericUserSimilarity.longIteratorToList(dataModel.getItemIDs());
+ Iterator<ItemItemSimilarity> it = new DataModelSimilaritiesIterator(otherSimilarity, itemIDs);
+ Iterable<ItemItemSimilarity> keptSimilarities = TopItems.getTopItemItemSimilarities(maxToKeep, it);
+ initSimilarityMaps(keptSimilarities.iterator());
+ }
+
+ private void initSimilarityMaps(Iterator<ItemItemSimilarity> similarities) {
+ while (similarities.hasNext()) {
+ ItemItemSimilarity iic = similarities.next();
+ long similarityItemID1 = iic.getItemID1();
+ long similarityItemID2 = iic.getItemID2();
+ if (similarityItemID1 != similarityItemID2) {
+ // Order them -- first key should be the "smaller" one
+ long itemID1;
+ long itemID2;
+ if (similarityItemID1 < similarityItemID2) {
+ itemID1 = similarityItemID1;
+ itemID2 = similarityItemID2;
+ } else {
+ itemID1 = similarityItemID2;
+ itemID2 = similarityItemID1;
+ }
+ FastByIDMap<Double> map = similarityMaps.get(itemID1);
+ if (map == null) {
+ map = new FastByIDMap<>();
+ similarityMaps.put(itemID1, map);
+ }
+ map.put(itemID2, iic.getValue());
+
+ doIndex(itemID1, itemID2);
+ doIndex(itemID2, itemID1);
+ }
+ // else similarity between item and itself already assumed to be 1.0
+ }
+ }
+
+ private void doIndex(long fromItemID, long toItemID) {
+ FastIDSet similarItemIDs = similarItemIDsIndex.get(fromItemID);
+ if (similarItemIDs == null) {
+ similarItemIDs = new FastIDSet();
+ similarItemIDsIndex.put(fromItemID, similarItemIDs);
+ }
+ similarItemIDs.add(toItemID);
+ }
+
+ /**
+ * <p>
+ * Returns the similarity between two items. Note that similarity is assumed to be symmetric, that
+ * {@code itemSimilarity(item1, item2) == itemSimilarity(item2, item1)}, and that
+ * {@code itemSimilarity(item1,item1) == 1.0} for all items.
+ * </p>
+ *
+ * @param itemID1
+ * first item
+ * @param itemID2
+ * second item
+ * @return similarity between the two
+ */
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) {
+ if (itemID1 == itemID2) {
+ return 1.0;
+ }
+ long firstID;
+ long secondID;
+ if (itemID1 < itemID2) {
+ firstID = itemID1;
+ secondID = itemID2;
+ } else {
+ firstID = itemID2;
+ secondID = itemID1;
+ }
+ FastByIDMap<Double> nextMap = similarityMaps.get(firstID);
+ if (nextMap == null) {
+ return Double.NaN;
+ }
+ Double similarity = nextMap.get(secondID);
+ return similarity == null ? Double.NaN : similarity;
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) {
+ int length = itemID2s.length;
+ double[] result = new double[length];
+ for (int i = 0; i < length; i++) {
+ result[i] = itemSimilarity(itemID1, itemID2s[i]);
+ }
+ return result;
+ }
+
+ @Override
+ public long[] allSimilarItemIDs(long itemID) {
+ FastIDSet similarItemIDs = similarItemIDsIndex.get(itemID);
+ return similarItemIDs != null ? similarItemIDs.toArray() : NO_IDS;
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // Do nothing
+ }
+
+ /** Encapsulates a similarity between two items. Similarity must be in the range [-1.0,1.0]. */
+ public static final class ItemItemSimilarity implements Comparable<ItemItemSimilarity> {
+
+ private final long itemID1;
+ private final long itemID2;
+ private final double value;
+
+ /**
+ * @param itemID1
+ * first item
+ * @param itemID2
+ * second item
+ * @param value
+ * similarity between the two
+ * @throws IllegalArgumentException
+ * if value is NaN, less than -1.0 or greater than 1.0
+ */
+ public ItemItemSimilarity(long itemID1, long itemID2, double value) {
+ Preconditions.checkArgument(value >= -1.0 && value <= 1.0, "Illegal value: " + value + ". Must be: -1.0 <= value <= 1.0");
+ this.itemID1 = itemID1;
+ this.itemID2 = itemID2;
+ this.value = value;
+ }
+
+ public long getItemID1() {
+ return itemID1;
+ }
+
+ public long getItemID2() {
+ return itemID2;
+ }
+
+ public double getValue() {
+ return value;
+ }
+
+ @Override
+ public String toString() {
+ return "ItemItemSimilarity[" + itemID1 + ',' + itemID2 + ':' + value + ']';
+ }
+
+ /** Defines an ordering from highest similarity to lowest. */
+ @Override
+ public int compareTo(ItemItemSimilarity other) {
+ double otherValue = other.getValue();
+ return value > otherValue ? -1 : value < otherValue ? 1 : 0;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof ItemItemSimilarity)) {
+ return false;
+ }
+ ItemItemSimilarity otherSimilarity = (ItemItemSimilarity) other;
+ return otherSimilarity.getItemID1() == itemID1
+ && otherSimilarity.getItemID2() == itemID2
+ && otherSimilarity.getValue() == value;
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) itemID1 ^ (int) itemID2 ^ RandomUtils.hashDouble(value);
+ }
+
+ }
+
+ private static final class DataModelSimilaritiesIterator extends AbstractIterator<ItemItemSimilarity> {
+
+ private final ItemSimilarity otherSimilarity;
+ private final long[] itemIDs;
+ private int i;
+ private long itemID1;
+ private int j;
+
+ private DataModelSimilaritiesIterator(ItemSimilarity otherSimilarity, long[] itemIDs) {
+ this.otherSimilarity = otherSimilarity;
+ this.itemIDs = itemIDs;
+ i = 0;
+ itemID1 = itemIDs[0];
+ j = 1;
+ }
+
+ @Override
+ protected ItemItemSimilarity computeNext() {
+ int size = itemIDs.length;
+ ItemItemSimilarity result = null;
+ while (result == null && i < size - 1) {
+ long itemID2 = itemIDs[j];
+ double similarity;
+ try {
+ similarity = otherSimilarity.itemSimilarity(itemID1, itemID2);
+ } catch (TasteException te) {
+ // ugly:
+ throw new IllegalStateException(te);
+ }
+ if (!Double.isNaN(similarity)) {
+ result = new ItemItemSimilarity(itemID1, itemID2, similarity);
+ }
+ if (++j == size) {
+ itemID1 = itemIDs[++i];
+ j = i + 1;
+ }
+ }
+ if (result == null) {
+ return endOfData();
+ } else {
+ return result;
+ }
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java
new file mode 100644
index 0000000..1c221c2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/GenericUserSimilarity.java
@@ -0,0 +1,238 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+import java.util.Iterator;
+
+import com.google.common.collect.AbstractIterator;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.recommender.TopItems;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.apache.mahout.common.RandomUtils;
+
+import com.google.common.base.Preconditions;
+
+public final class GenericUserSimilarity implements UserSimilarity {
+
+ private final FastByIDMap<FastByIDMap<Double>> similarityMaps = new FastByIDMap<>();
+
+ public GenericUserSimilarity(Iterable<UserUserSimilarity> similarities) {
+ initSimilarityMaps(similarities.iterator());
+ }
+
+ public GenericUserSimilarity(Iterable<UserUserSimilarity> similarities, int maxToKeep) {
+ Iterable<UserUserSimilarity> keptSimilarities =
+ TopItems.getTopUserUserSimilarities(maxToKeep, similarities.iterator());
+ initSimilarityMaps(keptSimilarities.iterator());
+ }
+
+ public GenericUserSimilarity(UserSimilarity otherSimilarity, DataModel dataModel) throws TasteException {
+ long[] userIDs = longIteratorToList(dataModel.getUserIDs());
+ initSimilarityMaps(new DataModelSimilaritiesIterator(otherSimilarity, userIDs));
+ }
+
+ public GenericUserSimilarity(UserSimilarity otherSimilarity,
+ DataModel dataModel,
+ int maxToKeep) throws TasteException {
+ long[] userIDs = longIteratorToList(dataModel.getUserIDs());
+ Iterator<UserUserSimilarity> it = new DataModelSimilaritiesIterator(otherSimilarity, userIDs);
+ Iterable<UserUserSimilarity> keptSimilarities = TopItems.getTopUserUserSimilarities(maxToKeep, it);
+ initSimilarityMaps(keptSimilarities.iterator());
+ }
+
+ static long[] longIteratorToList(LongPrimitiveIterator iterator) {
+ long[] result = new long[5];
+ int size = 0;
+ while (iterator.hasNext()) {
+ if (size == result.length) {
+ long[] newResult = new long[result.length << 1];
+ System.arraycopy(result, 0, newResult, 0, result.length);
+ result = newResult;
+ }
+ result[size++] = iterator.next();
+ }
+ if (size != result.length) {
+ long[] newResult = new long[size];
+ System.arraycopy(result, 0, newResult, 0, size);
+ result = newResult;
+ }
+ return result;
+ }
+
+ private void initSimilarityMaps(Iterator<UserUserSimilarity> similarities) {
+ while (similarities.hasNext()) {
+ UserUserSimilarity uuc = similarities.next();
+ long similarityUser1 = uuc.getUserID1();
+ long similarityUser2 = uuc.getUserID2();
+ if (similarityUser1 != similarityUser2) {
+ // Order them -- first key should be the "smaller" one
+ long user1;
+ long user2;
+ if (similarityUser1 < similarityUser2) {
+ user1 = similarityUser1;
+ user2 = similarityUser2;
+ } else {
+ user1 = similarityUser2;
+ user2 = similarityUser1;
+ }
+ FastByIDMap<Double> map = similarityMaps.get(user1);
+ if (map == null) {
+ map = new FastByIDMap<>();
+ similarityMaps.put(user1, map);
+ }
+ map.put(user2, uuc.getValue());
+ }
+ // else similarity between user and itself already assumed to be 1.0
+ }
+ }
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) {
+ if (userID1 == userID2) {
+ return 1.0;
+ }
+ long first;
+ long second;
+ if (userID1 < userID2) {
+ first = userID1;
+ second = userID2;
+ } else {
+ first = userID2;
+ second = userID1;
+ }
+ FastByIDMap<Double> nextMap = similarityMaps.get(first);
+ if (nextMap == null) {
+ return Double.NaN;
+ }
+ Double similarity = nextMap.get(second);
+ return similarity == null ? Double.NaN : similarity;
+ }
+
+ @Override
+ public void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // Do nothing
+ }
+
+ public static final class UserUserSimilarity implements Comparable<UserUserSimilarity> {
+
+ private final long userID1;
+ private final long userID2;
+ private final double value;
+
+ public UserUserSimilarity(long userID1, long userID2, double value) {
+ Preconditions.checkArgument(value >= -1.0 && value <= 1.0, "Illegal value: " + value + ". Must be: -1.0 <= value <= 1.0");
+ this.userID1 = userID1;
+ this.userID2 = userID2;
+ this.value = value;
+ }
+
+ public long getUserID1() {
+ return userID1;
+ }
+
+ public long getUserID2() {
+ return userID2;
+ }
+
+ public double getValue() {
+ return value;
+ }
+
+ @Override
+ public String toString() {
+ return "UserUserSimilarity[" + userID1 + ',' + userID2 + ':' + value + ']';
+ }
+
+ /** Defines an ordering from highest similarity to lowest. */
+ @Override
+ public int compareTo(UserUserSimilarity other) {
+ double otherValue = other.getValue();
+ return value > otherValue ? -1 : value < otherValue ? 1 : 0;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof UserUserSimilarity)) {
+ return false;
+ }
+ UserUserSimilarity otherSimilarity = (UserUserSimilarity) other;
+ return otherSimilarity.getUserID1() == userID1
+ && otherSimilarity.getUserID2() == userID2
+ && otherSimilarity.getValue() == value;
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) userID1 ^ (int) userID2 ^ RandomUtils.hashDouble(value);
+ }
+
+ }
+
+ private static final class DataModelSimilaritiesIterator extends AbstractIterator<UserUserSimilarity> {
+
+ private final UserSimilarity otherSimilarity;
+ private final long[] itemIDs;
+ private int i;
+ private long itemID1;
+ private int j;
+
+ private DataModelSimilaritiesIterator(UserSimilarity otherSimilarity, long[] itemIDs) {
+ this.otherSimilarity = otherSimilarity;
+ this.itemIDs = itemIDs;
+ i = 0;
+ itemID1 = itemIDs[0];
+ j = 1;
+ }
+
+ @Override
+ protected UserUserSimilarity computeNext() {
+ int size = itemIDs.length;
+ while (i < size - 1) {
+ long itemID2 = itemIDs[j];
+ double similarity;
+ try {
+ similarity = otherSimilarity.userSimilarity(itemID1, itemID2);
+ } catch (TasteException te) {
+ // ugly:
+ throw new IllegalStateException(te);
+ }
+ if (!Double.isNaN(similarity)) {
+ return new UserUserSimilarity(itemID1, itemID2, similarity);
+ }
+ if (++j == size) {
+ itemID1 = itemIDs[++i];
+ j = i + 1;
+ }
+ }
+ return endOfData();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java
new file mode 100644
index 0000000..3084c8f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LogLikelihoodSimilarity.java
@@ -0,0 +1,121 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.apache.mahout.math.stats.LogLikelihood;
+
+/**
+ * See <a href="http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.14.5962">
+ * http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.14.5962</a> and
+ * <a href="http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html">
+ * http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html</a>.
+ */
+public final class LogLikelihoodSimilarity extends AbstractItemSimilarity implements UserSimilarity {
+
+ public LogLikelihoodSimilarity(DataModel dataModel) {
+ super(dataModel);
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) throws TasteException {
+
+ DataModel dataModel = getDataModel();
+ FastIDSet prefs1 = dataModel.getItemIDsFromUser(userID1);
+ FastIDSet prefs2 = dataModel.getItemIDsFromUser(userID2);
+
+ long prefs1Size = prefs1.size();
+ long prefs2Size = prefs2.size();
+ long intersectionSize =
+ prefs1Size < prefs2Size ? prefs2.intersectionSize(prefs1) : prefs1.intersectionSize(prefs2);
+ if (intersectionSize == 0) {
+ return Double.NaN;
+ }
+ long numItems = dataModel.getNumItems();
+ double logLikelihood =
+ LogLikelihood.logLikelihoodRatio(intersectionSize,
+ prefs2Size - intersectionSize,
+ prefs1Size - intersectionSize,
+ numItems - prefs1Size - prefs2Size + intersectionSize);
+ return 1.0 - 1.0 / (1.0 + logLikelihood);
+ }
+
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+ DataModel dataModel = getDataModel();
+ long preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1);
+ long numUsers = dataModel.getNumUsers();
+ return doItemSimilarity(itemID1, itemID2, preferring1, numUsers);
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+ DataModel dataModel = getDataModel();
+ long preferring1 = dataModel.getNumUsersWithPreferenceFor(itemID1);
+ long numUsers = dataModel.getNumUsers();
+ int length = itemID2s.length;
+ double[] result = new double[length];
+ for (int i = 0; i < length; i++) {
+ result[i] = doItemSimilarity(itemID1, itemID2s[i], preferring1, numUsers);
+ }
+ return result;
+ }
+
+ private double doItemSimilarity(long itemID1, long itemID2, long preferring1, long numUsers) throws TasteException {
+ DataModel dataModel = getDataModel();
+ long preferring1and2 = dataModel.getNumUsersWithPreferenceFor(itemID1, itemID2);
+ if (preferring1and2 == 0) {
+ return Double.NaN;
+ }
+ long preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2);
+ double logLikelihood =
+ LogLikelihood.logLikelihoodRatio(preferring1and2,
+ preferring2 - preferring1and2,
+ preferring1 - preferring1and2,
+ numUsers - preferring1 - preferring2 + preferring1and2);
+ return 1.0 - 1.0 / (1.0 + logLikelihood);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
+ RefreshHelper.maybeRefresh(alreadyRefreshed, getDataModel());
+ }
+
+ @Override
+ public String toString() {
+ return "LogLikelihoodSimilarity[dataModel:" + getDataModel() + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java
new file mode 100644
index 0000000..48dc4e0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/LongPairMatchPredicate.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.impl.common.Cache;
+import org.apache.mahout.common.LongPair;
+
+/**
+ * A {@link Cache.MatchPredicate} which will match an ID against either element of a
+ * {@link LongPair}.
+ */
+final class LongPairMatchPredicate implements Cache.MatchPredicate<LongPair> {
+
+ private final long id;
+
+ LongPairMatchPredicate(long id) {
+ this.id = id;
+ }
+
+ @Override
+ public boolean matches(LongPair pair) {
+ return pair.getFirst() == id || pair.getSecond() == id;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarity.java
new file mode 100644
index 0000000..8ea1660
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/PearsonCorrelationSimilarity.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * An implementation of the Pearson correlation. For users X and Y, the following values are calculated:
+ * </p>
+ *
+ * <ul>
+ * <li>sumX2: sum of the square of all X's preference values</li>
+ * <li>sumY2: sum of the square of all Y's preference values</li>
+ * <li>sumXY: sum of the product of X and Y's preference value for all items for which both X and Y express a
+ * preference</li>
+ * </ul>
+ *
+ * <p>
+ * The correlation is then:
+ *
+ * <p>
+ * {@code sumXY / sqrt(sumX2 * sumY2)}
+ * </p>
+ *
+ * <p>
+ * Note that this correlation "centers" its data, shifts the user's preference values so that each of their
+ * means is 0. This is necessary to achieve expected behavior on all data sets.
+ * </p>
+ *
+ * <p>
+ * This correlation implementation is equivalent to the cosine similarity since the data it receives
+ * is assumed to be centered -- mean is 0. The correlation may be interpreted as the cosine of the angle
+ * between the two vectors defined by the users' preference values.
+ * </p>
+ *
+ * <p>
+ * For cosine similarity on uncentered data, see {@link UncenteredCosineSimilarity}.
+ * </p>
+ */
+public final class PearsonCorrelationSimilarity extends AbstractSimilarity {
+
+ /**
+ * @throws IllegalArgumentException if {@link DataModel} does not have preference values
+ */
+ public PearsonCorrelationSimilarity(DataModel dataModel) throws TasteException {
+ this(dataModel, Weighting.UNWEIGHTED);
+ }
+
+ /**
+ * @throws IllegalArgumentException if {@link DataModel} does not have preference values
+ */
+ public PearsonCorrelationSimilarity(DataModel dataModel, Weighting weighting) throws TasteException {
+ super(dataModel, weighting, true);
+ Preconditions.checkArgument(dataModel.hasPreferenceValues(), "DataModel doesn't have preference values");
+ }
+
+ @Override
+ double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2) {
+ if (n == 0) {
+ return Double.NaN;
+ }
+ // Note that sum of X and sum of Y don't appear here since they are assumed to be 0;
+ // the data is assumed to be centered.
+ double denominator = Math.sqrt(sumX2) * Math.sqrt(sumY2);
+ if (denominator == 0.0) {
+ // One or both parties has -all- the same ratings;
+ // can't really say much similarity under this measure
+ return Double.NaN;
+ }
+ return sumXY / denominator;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarity.java
new file mode 100644
index 0000000..1116368
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/SpearmanCorrelationSimilarity.java
@@ -0,0 +1,135 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * Like {@link PearsonCorrelationSimilarity}, but compares relative ranking of preference values instead of
+ * preference values themselves. That is, each user's preferences are sorted and then assign a rank as their
+ * preference value, with 1 being assigned to the least preferred item.
+ * </p>
+ */
+public final class SpearmanCorrelationSimilarity implements UserSimilarity {
+
+ private final DataModel dataModel;
+
+ public SpearmanCorrelationSimilarity(DataModel dataModel) {
+ this.dataModel = Preconditions.checkNotNull(dataModel);
+ }
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) throws TasteException {
+ PreferenceArray xPrefs = dataModel.getPreferencesFromUser(userID1);
+ PreferenceArray yPrefs = dataModel.getPreferencesFromUser(userID2);
+ int xLength = xPrefs.length();
+ int yLength = yPrefs.length();
+
+ if (xLength <= 1 || yLength <= 1) {
+ return Double.NaN;
+ }
+
+ // Copy prefs since we need to modify pref values to ranks
+ xPrefs = xPrefs.clone();
+ yPrefs = yPrefs.clone();
+
+ // First sort by values from low to high
+ xPrefs.sortByValue();
+ yPrefs.sortByValue();
+
+ // Assign ranks from low to high
+ float nextRank = 1.0f;
+ for (int i = 0; i < xLength; i++) {
+ // ... but only for items that are common to both pref arrays
+ if (yPrefs.hasPrefWithItemID(xPrefs.getItemID(i))) {
+ xPrefs.setValue(i, nextRank);
+ nextRank += 1.0f;
+ }
+ // Other values are bogus but don't matter
+ }
+ nextRank = 1.0f;
+ for (int i = 0; i < yLength; i++) {
+ if (xPrefs.hasPrefWithItemID(yPrefs.getItemID(i))) {
+ yPrefs.setValue(i, nextRank);
+ nextRank += 1.0f;
+ }
+ }
+
+ xPrefs.sortByItem();
+ yPrefs.sortByItem();
+
+ long xIndex = xPrefs.getItemID(0);
+ long yIndex = yPrefs.getItemID(0);
+ int xPrefIndex = 0;
+ int yPrefIndex = 0;
+
+ double sumXYRankDiff2 = 0.0;
+ int count = 0;
+
+ while (true) {
+ int compare = xIndex < yIndex ? -1 : xIndex > yIndex ? 1 : 0;
+ if (compare == 0) {
+ double diff = xPrefs.getValue(xPrefIndex) - yPrefs.getValue(yPrefIndex);
+ sumXYRankDiff2 += diff * diff;
+ count++;
+ }
+ if (compare <= 0) {
+ if (++xPrefIndex >= xLength) {
+ break;
+ }
+ xIndex = xPrefs.getItemID(xPrefIndex);
+ }
+ if (compare >= 0) {
+ if (++yPrefIndex >= yLength) {
+ break;
+ }
+ yIndex = yPrefs.getItemID(yPrefIndex);
+ }
+ }
+
+ if (count <= 1) {
+ return Double.NaN;
+ }
+
+ // When ranks are unique, this formula actually gives the Pearson correlation
+ return 1.0 - 6.0 * sumXYRankDiff2 / (count * (count * count - 1));
+ }
+
+ @Override
+ public void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
+ RefreshHelper.maybeRefresh(alreadyRefreshed, dataModel);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarity.java
new file mode 100644
index 0000000..0c3a0a4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/TanimotoCoefficientSimilarity.java
@@ -0,0 +1,126 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+/**
+ * <p>
+ * An implementation of a "similarity" based on the <a
+ * href="http://en.wikipedia.org/wiki/Jaccard_index#Tanimoto_coefficient_.28extended_Jaccard_coefficient.29">
+ * Tanimoto coefficient</a>, or extended <a href="http://en.wikipedia.org/wiki/Jaccard_index">Jaccard
+ * coefficient</a>.
+ * </p>
+ *
+ * <p>
+ * This is intended for "binary" data sets where a user either expresses a generic "yes" preference for an
+ * item or has no preference. The actual preference values do not matter here, only their presence or absence.
+ * </p>
+ *
+ * <p>
+ * The value returned is in [0,1].
+ * </p>
+ */
+public final class TanimotoCoefficientSimilarity extends AbstractItemSimilarity implements UserSimilarity {
+
+ public TanimotoCoefficientSimilarity(DataModel dataModel) {
+ super(dataModel);
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public void setPreferenceInferrer(PreferenceInferrer inferrer) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double userSimilarity(long userID1, long userID2) throws TasteException {
+
+ DataModel dataModel = getDataModel();
+ FastIDSet xPrefs = dataModel.getItemIDsFromUser(userID1);
+ FastIDSet yPrefs = dataModel.getItemIDsFromUser(userID2);
+
+ int xPrefsSize = xPrefs.size();
+ int yPrefsSize = yPrefs.size();
+ if (xPrefsSize == 0 && yPrefsSize == 0) {
+ return Double.NaN;
+ }
+ if (xPrefsSize == 0 || yPrefsSize == 0) {
+ return 0.0;
+ }
+
+ int intersectionSize =
+ xPrefsSize < yPrefsSize ? yPrefs.intersectionSize(xPrefs) : xPrefs.intersectionSize(yPrefs);
+ if (intersectionSize == 0) {
+ return Double.NaN;
+ }
+
+ int unionSize = xPrefsSize + yPrefsSize - intersectionSize;
+
+ return (double) intersectionSize / (double) unionSize;
+ }
+
+ @Override
+ public double itemSimilarity(long itemID1, long itemID2) throws TasteException {
+ int preferring1 = getDataModel().getNumUsersWithPreferenceFor(itemID1);
+ return doItemSimilarity(itemID1, itemID2, preferring1);
+ }
+
+ @Override
+ public double[] itemSimilarities(long itemID1, long[] itemID2s) throws TasteException {
+ int preferring1 = getDataModel().getNumUsersWithPreferenceFor(itemID1);
+ int length = itemID2s.length;
+ double[] result = new double[length];
+ for (int i = 0; i < length; i++) {
+ result[i] = doItemSimilarity(itemID1, itemID2s[i], preferring1);
+ }
+ return result;
+ }
+
+ private double doItemSimilarity(long itemID1, long itemID2, int preferring1) throws TasteException {
+ DataModel dataModel = getDataModel();
+ int preferring1and2 = dataModel.getNumUsersWithPreferenceFor(itemID1, itemID2);
+ if (preferring1and2 == 0) {
+ return Double.NaN;
+ }
+ int preferring2 = dataModel.getNumUsersWithPreferenceFor(itemID2);
+ return (double) preferring1and2 / (double) (preferring1 + preferring2 - preferring1and2);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ alreadyRefreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
+ RefreshHelper.maybeRefresh(alreadyRefreshed, getDataModel());
+ }
+
+ @Override
+ public String toString() {
+ return "TanimotoCoefficientSimilarity[dataModel:" + getDataModel() + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/UncenteredCosineSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/UncenteredCosineSimilarity.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/UncenteredCosineSimilarity.java
new file mode 100644
index 0000000..6260606
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/UncenteredCosineSimilarity.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.common.Weighting;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * An implementation of the cosine similarity. The result is the cosine of the angle formed between
+ * the two preference vectors.
+ * </p>
+ *
+ * <p>
+ * Note that this similarity does not "center" its data, shifts the user's preference values so that each of their
+ * means is 0. For this behavior, use {@link PearsonCorrelationSimilarity}, which actually is mathematically
+ * equivalent for centered data.
+ * </p>
+ */
+public final class UncenteredCosineSimilarity extends AbstractSimilarity {
+
+ /**
+ * @throws IllegalArgumentException if {@link DataModel} does not have preference values
+ */
+ public UncenteredCosineSimilarity(DataModel dataModel) throws TasteException {
+ this(dataModel, Weighting.UNWEIGHTED);
+ }
+
+ /**
+ * @throws IllegalArgumentException if {@link DataModel} does not have preference values
+ */
+ public UncenteredCosineSimilarity(DataModel dataModel, Weighting weighting) throws TasteException {
+ super(dataModel, weighting, false);
+ Preconditions.checkArgument(dataModel.hasPreferenceValues(), "DataModel doesn't have preference values");
+ }
+
+ @Override
+ double computeResult(int n, double sumXY, double sumX2, double sumY2, double sumXYdiff2) {
+ if (n == 0) {
+ return Double.NaN;
+ }
+ double denominator = Math.sqrt(sumX2) * Math.sqrt(sumY2);
+ if (denominator == 0.0) {
+ // One or both parties has -all- the same ratings;
+ // can't really say much similarity under this measure
+ return Double.NaN;
+ }
+ return sumXY / denominator;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterable.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterable.java
new file mode 100644
index 0000000..1ae45c2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterable.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.file;
+
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Iterator;
+
+/**
+ * {@link Iterable} to be able to read a file linewise into a {@link GenericItemSimilarity}
+ */
+final class FileItemItemSimilarityIterable implements Iterable<GenericItemSimilarity.ItemItemSimilarity> {
+
+ private final File similaritiesFile;
+
+ FileItemItemSimilarityIterable(File similaritiesFile) {
+ this.similaritiesFile = similaritiesFile;
+ }
+
+ @Override
+ public Iterator<GenericItemSimilarity.ItemItemSimilarity> iterator() {
+ try {
+ return new FileItemItemSimilarityIterator(similaritiesFile);
+ } catch (IOException ioe) {
+ throw new IllegalStateException("Can't read " + similaritiesFile, ioe);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterator.java
new file mode 100644
index 0000000..c071159
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/similarity/file/FileItemItemSimilarityIterator.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity.file;
+
+import com.google.common.base.Function;
+import com.google.common.collect.ForwardingIterator;
+import com.google.common.collect.Iterators;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.common.iterator.FileLineIterator;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.regex.Pattern;
+
+/**
+ * a simple iterator using a {@link FileLineIterator} internally, parsing each
+ * line into an {@link GenericItemSimilarity.ItemItemSimilarity}.
+ */
+final class FileItemItemSimilarityIterator extends ForwardingIterator<GenericItemSimilarity.ItemItemSimilarity> {
+
+ private static final Pattern SEPARATOR = Pattern.compile("[,\t]");
+
+ private final Iterator<GenericItemSimilarity.ItemItemSimilarity> delegate;
+
+ FileItemItemSimilarityIterator(File similaritiesFile) throws IOException {
+ delegate = Iterators.transform(
+ new FileLineIterator(similaritiesFile),
+ new Function<String, GenericItemSimilarity.ItemItemSimilarity>() {
+ @Override
+ public GenericItemSimilarity.ItemItemSimilarity apply(String from) {
+ String[] tokens = SEPARATOR.split(from);
+ return new GenericItemSimilarity.ItemItemSimilarity(Long.parseLong(tokens[0]),
+ Long.parseLong(tokens[1]),
+ Double.parseDouble(tokens[2]));
+ }
+ });
+ }
+
+ @Override
+ protected Iterator<GenericItemSimilarity.ItemItemSimilarity> delegate() {
+ return delegate;
+ }
+
+}
[46/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java
new file mode 100644
index 0000000..fde8958
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastByIDMap.java
@@ -0,0 +1,661 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+import java.util.AbstractCollection;
+import java.util.AbstractSet;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+
+import org.apache.mahout.common.RandomUtils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * @see FastMap
+ * @see FastIDSet
+ */
+public final class FastByIDMap<V> implements Serializable, Cloneable {
+
+ public static final int NO_MAX_SIZE = Integer.MAX_VALUE;
+ private static final float DEFAULT_LOAD_FACTOR = 1.5f;
+
+ /** Dummy object used to represent a key that has been removed. */
+ private static final long REMOVED = Long.MAX_VALUE;
+ private static final long NULL = Long.MIN_VALUE;
+
+ private long[] keys;
+ private V[] values;
+ private float loadFactor;
+ private int numEntries;
+ private int numSlotsUsed;
+ private final int maxSize;
+ private BitSet recentlyAccessed;
+ private final boolean countingAccesses;
+
+ /** Creates a new {@link FastByIDMap} with default capacity. */
+ public FastByIDMap() {
+ this(2, NO_MAX_SIZE);
+ }
+
+ public FastByIDMap(int size) {
+ this(size, NO_MAX_SIZE);
+ }
+
+ public FastByIDMap(int size, float loadFactor) {
+ this(size, NO_MAX_SIZE, loadFactor);
+ }
+
+ public FastByIDMap(int size, int maxSize) {
+ this(size, maxSize, DEFAULT_LOAD_FACTOR);
+ }
+
+ /**
+ * Creates a new {@link FastByIDMap} whose capacity can accommodate the given number of entries without rehash.
+ *
+ * @param size desired capacity
+ * @param maxSize max capacity
+ * @param loadFactor ratio of internal hash table size to current size
+ * @throws IllegalArgumentException if size is less than 0, maxSize is less than 1
+ * or at least half of {@link RandomUtils#MAX_INT_SMALLER_TWIN_PRIME}, or
+ * loadFactor is less than 1
+ */
+ public FastByIDMap(int size, int maxSize, float loadFactor) {
+ Preconditions.checkArgument(size >= 0, "size must be at least 0");
+ Preconditions.checkArgument(loadFactor >= 1.0f, "loadFactor must be at least 1.0");
+ this.loadFactor = loadFactor;
+ int max = (int) (RandomUtils.MAX_INT_SMALLER_TWIN_PRIME / loadFactor);
+ Preconditions.checkArgument(size < max, "size must be less than " + max);
+ Preconditions.checkArgument(maxSize >= 1, "maxSize must be at least 1");
+ int hashSize = RandomUtils.nextTwinPrime((int) (loadFactor * size));
+ keys = new long[hashSize];
+ Arrays.fill(keys, NULL);
+ values = (V[]) new Object[hashSize];
+ this.maxSize = maxSize;
+ this.countingAccesses = maxSize != Integer.MAX_VALUE;
+ this.recentlyAccessed = countingAccesses ? new BitSet(hashSize) : null;
+ }
+
+ /**
+ * @see #findForAdd(long)
+ */
+ private int find(long key) {
+ int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive
+ long[] keys = this.keys;
+ int hashSize = keys.length;
+ int jump = 1 + theHashCode % (hashSize - 2);
+ int index = theHashCode % hashSize;
+ long currentKey = keys[index];
+ while (currentKey != NULL && key != currentKey) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ return index;
+ }
+
+ /**
+ * @see #find(long)
+ */
+ private int findForAdd(long key) {
+ int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive
+ long[] keys = this.keys;
+ int hashSize = keys.length;
+ int jump = 1 + theHashCode % (hashSize - 2);
+ int index = theHashCode % hashSize;
+ long currentKey = keys[index];
+ while (currentKey != NULL && currentKey != REMOVED && key != currentKey) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ if (currentKey != REMOVED) {
+ return index;
+ }
+ // If we're adding, it's here, but, the key might have a value already later
+ int addIndex = index;
+ while (currentKey != NULL && key != currentKey) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ return key == currentKey ? index : addIndex;
+ }
+
+ public V get(long key) {
+ if (key == NULL) {
+ return null;
+ }
+ int index = find(key);
+ if (countingAccesses) {
+ recentlyAccessed.set(index);
+ }
+ return values[index];
+ }
+
+ public int size() {
+ return numEntries;
+ }
+
+ public boolean isEmpty() {
+ return numEntries == 0;
+ }
+
+ public boolean containsKey(long key) {
+ return key != NULL && key != REMOVED && keys[find(key)] != NULL;
+ }
+
+ public boolean containsValue(Object value) {
+ if (value == null) {
+ return false;
+ }
+ for (V theValue : values) {
+ if (theValue != null && value.equals(theValue)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ public V put(long key, V value) {
+ Preconditions.checkArgument(key != NULL && key != REMOVED);
+ Preconditions.checkNotNull(value);
+ // If less than half the slots are open, let's clear it up
+ if (numSlotsUsed * loadFactor >= keys.length) {
+ // If over half the slots used are actual entries, let's grow
+ if (numEntries * loadFactor >= numSlotsUsed) {
+ growAndRehash();
+ } else {
+ // Otherwise just rehash to clear REMOVED entries and don't grow
+ rehash();
+ }
+ }
+ // Here we may later consider implementing Brent's variation described on page 532
+ int index = findForAdd(key);
+ long keyIndex = keys[index];
+ if (keyIndex == key) {
+ V oldValue = values[index];
+ values[index] = value;
+ return oldValue;
+ }
+ // If size is limited,
+ if (countingAccesses && numEntries >= maxSize) {
+ // and we're too large, clear some old-ish entry
+ clearStaleEntry(index);
+ }
+ keys[index] = key;
+ values[index] = value;
+ numEntries++;
+ if (keyIndex == NULL) {
+ numSlotsUsed++;
+ }
+ return null;
+ }
+
+ private void clearStaleEntry(int index) {
+ while (true) {
+ long currentKey;
+ do {
+ if (index == 0) {
+ index = keys.length - 1;
+ } else {
+ index--;
+ }
+ currentKey = keys[index];
+ } while (currentKey == NULL || currentKey == REMOVED);
+ if (recentlyAccessed.get(index)) {
+ recentlyAccessed.clear(index);
+ } else {
+ break;
+ }
+ }
+ // Delete the entry
+ keys[index] = REMOVED;
+ numEntries--;
+ values[index] = null;
+ }
+
+ public V remove(long key) {
+ if (key == NULL || key == REMOVED) {
+ return null;
+ }
+ int index = find(key);
+ if (keys[index] == NULL) {
+ return null;
+ } else {
+ keys[index] = REMOVED;
+ numEntries--;
+ V oldValue = values[index];
+ values[index] = null;
+ // don't decrement numSlotsUsed
+ return oldValue;
+ }
+ // Could un-set recentlyAccessed's bit but doesn't matter
+ }
+
+ public void clear() {
+ numEntries = 0;
+ numSlotsUsed = 0;
+ Arrays.fill(keys, NULL);
+ Arrays.fill(values, null);
+ if (countingAccesses) {
+ recentlyAccessed.clear();
+ }
+ }
+
+ public LongPrimitiveIterator keySetIterator() {
+ return new KeyIterator();
+ }
+
+ public Set<Map.Entry<Long,V>> entrySet() {
+ return new EntrySet();
+ }
+
+ public Collection<V> values() {
+ return new ValueCollection();
+ }
+
+ public void rehash() {
+ rehash(RandomUtils.nextTwinPrime((int) (loadFactor * numEntries)));
+ }
+
+ private void growAndRehash() {
+ if (keys.length * loadFactor >= RandomUtils.MAX_INT_SMALLER_TWIN_PRIME) {
+ throw new IllegalStateException("Can't grow any more");
+ }
+ rehash(RandomUtils.nextTwinPrime((int) (loadFactor * keys.length)));
+ }
+
+ private void rehash(int newHashSize) {
+ long[] oldKeys = keys;
+ V[] oldValues = values;
+ numEntries = 0;
+ numSlotsUsed = 0;
+ if (countingAccesses) {
+ recentlyAccessed = new BitSet(newHashSize);
+ }
+ keys = new long[newHashSize];
+ Arrays.fill(keys, NULL);
+ values = (V[]) new Object[newHashSize];
+ int length = oldKeys.length;
+ for (int i = 0; i < length; i++) {
+ long key = oldKeys[i];
+ if (key != NULL && key != REMOVED) {
+ put(key, oldValues[i]);
+ }
+ }
+ }
+
+ void iteratorRemove(int lastNext) {
+ if (lastNext >= values.length) {
+ throw new NoSuchElementException();
+ }
+ if (lastNext < 0) {
+ throw new IllegalStateException();
+ }
+ values[lastNext] = null;
+ keys[lastNext] = REMOVED;
+ numEntries--;
+ }
+
+ @Override
+ public FastByIDMap<V> clone() {
+ FastByIDMap<V> clone;
+ try {
+ clone = (FastByIDMap<V>) super.clone();
+ } catch (CloneNotSupportedException cnse) {
+ throw new AssertionError();
+ }
+ clone.keys = keys.clone();
+ clone.values = values.clone();
+ clone.recentlyAccessed = countingAccesses ? new BitSet(keys.length) : null;
+ return clone;
+ }
+
+ @Override
+ public String toString() {
+ if (isEmpty()) {
+ return "{}";
+ }
+ StringBuilder result = new StringBuilder();
+ result.append('{');
+ for (int i = 0; i < keys.length; i++) {
+ long key = keys[i];
+ if (key != NULL && key != REMOVED) {
+ result.append(key).append('=').append(values[i]).append(',');
+ }
+ }
+ result.setCharAt(result.length() - 1, '}');
+ return result.toString();
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = 0;
+ long[] keys = this.keys;
+ int max = keys.length;
+ for (int i = 0; i < max; i++) {
+ long key = keys[i];
+ if (key != NULL && key != REMOVED) {
+ hash = 31 * hash + ((int) (key >> 32) ^ (int) key);
+ hash = 31 * hash + values[i].hashCode();
+ }
+ }
+ return hash;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof FastByIDMap)) {
+ return false;
+ }
+ FastByIDMap<V> otherMap = (FastByIDMap<V>) other;
+ long[] otherKeys = otherMap.keys;
+ V[] otherValues = otherMap.values;
+ int length = keys.length;
+ int otherLength = otherKeys.length;
+ int max = Math.min(length, otherLength);
+
+ int i = 0;
+ while (i < max) {
+ long key = keys[i];
+ long otherKey = otherKeys[i];
+ if (key == NULL || key == REMOVED) {
+ if (otherKey != NULL && otherKey != REMOVED) {
+ return false;
+ }
+ } else {
+ if (key != otherKey || !values[i].equals(otherValues[i])) {
+ return false;
+ }
+ }
+ i++;
+ }
+ while (i < length) {
+ long key = keys[i];
+ if (key != NULL && key != REMOVED) {
+ return false;
+ }
+ i++;
+ }
+ while (i < otherLength) {
+ long key = otherKeys[i];
+ if (key != NULL && key != REMOVED) {
+ return false;
+ }
+ i++;
+ }
+ return true;
+ }
+
+ private final class KeyIterator extends AbstractLongPrimitiveIterator {
+
+ private int position;
+ private int lastNext = -1;
+
+ @Override
+ public boolean hasNext() {
+ goToNext();
+ return position < keys.length;
+ }
+
+ @Override
+ public long nextLong() {
+ goToNext();
+ lastNext = position;
+ if (position >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ return keys[position++];
+ }
+
+ @Override
+ public long peek() {
+ goToNext();
+ if (position >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ return keys[position];
+ }
+
+ private void goToNext() {
+ int length = values.length;
+ while (position < length && values[position] == null) {
+ position++;
+ }
+ }
+
+ @Override
+ public void remove() {
+ iteratorRemove(lastNext);
+ }
+
+ @Override
+ public void skip(int n) {
+ position += n;
+ }
+
+ }
+
+ private final class EntrySet extends AbstractSet<Map.Entry<Long,V>> {
+
+ @Override
+ public int size() {
+ return FastByIDMap.this.size();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return FastByIDMap.this.isEmpty();
+ }
+
+ @Override
+ public boolean contains(Object o) {
+ return containsKey((Long) o);
+ }
+
+ @Override
+ public Iterator<Map.Entry<Long,V>> iterator() {
+ return new EntryIterator();
+ }
+
+ @Override
+ public boolean add(Map.Entry<Long,V> t) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean remove(Object o) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean addAll(Collection<? extends Map.Entry<Long,V>> ts) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean retainAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean removeAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear() {
+ FastByIDMap.this.clear();
+ }
+
+ private final class MapEntry implements Map.Entry<Long,V> {
+
+ private final int index;
+
+ private MapEntry(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public Long getKey() {
+ return keys[index];
+ }
+
+ @Override
+ public V getValue() {
+ return values[index];
+ }
+
+ @Override
+ public V setValue(V value) {
+ Preconditions.checkArgument(value != null);
+
+ V oldValue = values[index];
+ values[index] = value;
+ return oldValue;
+ }
+ }
+
+ private final class EntryIterator implements Iterator<Map.Entry<Long,V>> {
+
+ private int position;
+ private int lastNext = -1;
+
+ @Override
+ public boolean hasNext() {
+ goToNext();
+ return position < keys.length;
+ }
+
+ @Override
+ public Map.Entry<Long,V> next() {
+ goToNext();
+ lastNext = position;
+ if (position >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ return new MapEntry(position++);
+ }
+
+ private void goToNext() {
+ int length = values.length;
+ while (position < length && values[position] == null) {
+ position++;
+ }
+ }
+
+ @Override
+ public void remove() {
+ iteratorRemove(lastNext);
+ }
+ }
+
+ }
+
+ private final class ValueCollection extends AbstractCollection<V> {
+
+ @Override
+ public int size() {
+ return FastByIDMap.this.size();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return FastByIDMap.this.isEmpty();
+ }
+
+ @Override
+ public boolean contains(Object o) {
+ return containsValue(o);
+ }
+
+ @Override
+ public Iterator<V> iterator() {
+ return new ValueIterator();
+ }
+
+ @Override
+ public boolean add(V v) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean remove(Object o) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean addAll(Collection<? extends V> vs) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean removeAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean retainAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear() {
+ FastByIDMap.this.clear();
+ }
+
+ private final class ValueIterator implements Iterator<V> {
+
+ private int position;
+ private int lastNext = -1;
+
+ @Override
+ public boolean hasNext() {
+ goToNext();
+ return position < values.length;
+ }
+
+ @Override
+ public V next() {
+ goToNext();
+ lastNext = position;
+ if (position >= values.length) {
+ throw new NoSuchElementException();
+ }
+ return values[position++];
+ }
+
+ private void goToNext() {
+ int length = values.length;
+ while (position < length && values[position] == null) {
+ position++;
+ }
+ }
+
+ @Override
+ public void remove() {
+ iteratorRemove(lastNext);
+ }
+
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java
new file mode 100644
index 0000000..5908270
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastIDSet.java
@@ -0,0 +1,426 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.NoSuchElementException;
+
+import org.apache.mahout.common.RandomUtils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * @see FastByIDMap
+ */
+public final class FastIDSet implements Serializable, Cloneable, Iterable<Long> {
+
+ private static final float DEFAULT_LOAD_FACTOR = 1.5f;
+
+ /** Dummy object used to represent a key that has been removed. */
+ private static final long REMOVED = Long.MAX_VALUE;
+ private static final long NULL = Long.MIN_VALUE;
+
+ private long[] keys;
+ private float loadFactor;
+ private int numEntries;
+ private int numSlotsUsed;
+
+ /** Creates a new {@link FastIDSet} with default capacity. */
+ public FastIDSet() {
+ this(2);
+ }
+
+ public FastIDSet(long[] initialKeys) {
+ this(initialKeys.length);
+ addAll(initialKeys);
+ }
+
+ public FastIDSet(int size) {
+ this(size, DEFAULT_LOAD_FACTOR);
+ }
+
+ public FastIDSet(int size, float loadFactor) {
+ Preconditions.checkArgument(size >= 0, "size must be at least 0");
+ Preconditions.checkArgument(loadFactor >= 1.0f, "loadFactor must be at least 1.0");
+ this.loadFactor = loadFactor;
+ int max = (int) (RandomUtils.MAX_INT_SMALLER_TWIN_PRIME / loadFactor);
+ Preconditions.checkArgument(size < max, "size must be less than %d", max);
+ int hashSize = RandomUtils.nextTwinPrime((int) (loadFactor * size));
+ keys = new long[hashSize];
+ Arrays.fill(keys, NULL);
+ }
+
+ /**
+ * @see #findForAdd(long)
+ */
+ private int find(long key) {
+ int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive
+ long[] keys = this.keys;
+ int hashSize = keys.length;
+ int jump = 1 + theHashCode % (hashSize - 2);
+ int index = theHashCode % hashSize;
+ long currentKey = keys[index];
+ while (currentKey != NULL && key != currentKey) { // note: true when currentKey == REMOVED
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ return index;
+ }
+
+ /**
+ * @see #find(long)
+ */
+ private int findForAdd(long key) {
+ int theHashCode = (int) key & 0x7FFFFFFF; // make sure it's positive
+ long[] keys = this.keys;
+ int hashSize = keys.length;
+ int jump = 1 + theHashCode % (hashSize - 2);
+ int index = theHashCode % hashSize;
+ long currentKey = keys[index];
+ while (currentKey != NULL && currentKey != REMOVED && key != currentKey) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ if (currentKey != REMOVED) {
+ return index;
+ }
+ // If we're adding, it's here, but, the key might have a value already later
+ int addIndex = index;
+ while (currentKey != NULL && key != currentKey) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ return key == currentKey ? index : addIndex;
+ }
+
+ public int size() {
+ return numEntries;
+ }
+
+ public boolean isEmpty() {
+ return numEntries == 0;
+ }
+
+ public boolean contains(long key) {
+ return key != NULL && key != REMOVED && keys[find(key)] != NULL;
+ }
+
+ public boolean add(long key) {
+ Preconditions.checkArgument(key != NULL && key != REMOVED);
+
+ // If less than half the slots are open, let's clear it up
+ if (numSlotsUsed * loadFactor >= keys.length) {
+ // If over half the slots used are actual entries, let's grow
+ if (numEntries * loadFactor >= numSlotsUsed) {
+ growAndRehash();
+ } else {
+ // Otherwise just rehash to clear REMOVED entries and don't grow
+ rehash();
+ }
+ }
+ // Here we may later consider implementing Brent's variation described on page 532
+ int index = findForAdd(key);
+ long keyIndex = keys[index];
+ if (keyIndex != key) {
+ keys[index] = key;
+ numEntries++;
+ if (keyIndex == NULL) {
+ numSlotsUsed++;
+ }
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public LongPrimitiveIterator iterator() {
+ return new KeyIterator();
+ }
+
+ public long[] toArray() {
+ long[] result = new long[numEntries];
+ for (int i = 0, position = 0; i < result.length; i++) {
+ while (keys[position] == NULL || keys[position] == REMOVED) {
+ position++;
+ }
+ result[i] = keys[position++];
+ }
+ return result;
+ }
+
+ public boolean remove(long key) {
+ if (key == NULL || key == REMOVED) {
+ return false;
+ }
+ int index = find(key);
+ if (keys[index] == NULL) {
+ return false;
+ } else {
+ keys[index] = REMOVED;
+ numEntries--;
+ return true;
+ }
+ }
+
+ public boolean addAll(long[] c) {
+ boolean changed = false;
+ for (long k : c) {
+ if (add(k)) {
+ changed = true;
+ }
+ }
+ return changed;
+ }
+
+ public boolean addAll(FastIDSet c) {
+ boolean changed = false;
+ for (long k : c.keys) {
+ if (k != NULL && k != REMOVED && add(k)) {
+ changed = true;
+ }
+ }
+ return changed;
+ }
+
+ public boolean removeAll(long[] c) {
+ boolean changed = false;
+ for (long o : c) {
+ if (remove(o)) {
+ changed = true;
+ }
+ }
+ return changed;
+ }
+
+ public boolean removeAll(FastIDSet c) {
+ boolean changed = false;
+ for (long k : c.keys) {
+ if (k != NULL && k != REMOVED && remove(k)) {
+ changed = true;
+ }
+ }
+ return changed;
+ }
+
+ public boolean retainAll(FastIDSet c) {
+ boolean changed = false;
+ for (int i = 0; i < keys.length; i++) {
+ long k = keys[i];
+ if (k != NULL && k != REMOVED && !c.contains(k)) {
+ keys[i] = REMOVED;
+ numEntries--;
+ changed = true;
+ }
+ }
+ return changed;
+ }
+
+ public void clear() {
+ numEntries = 0;
+ numSlotsUsed = 0;
+ Arrays.fill(keys, NULL);
+ }
+
+ private void growAndRehash() {
+ if (keys.length * loadFactor >= RandomUtils.MAX_INT_SMALLER_TWIN_PRIME) {
+ throw new IllegalStateException("Can't grow any more");
+ }
+ rehash(RandomUtils.nextTwinPrime((int) (loadFactor * keys.length)));
+ }
+
+ public void rehash() {
+ rehash(RandomUtils.nextTwinPrime((int) (loadFactor * numEntries)));
+ }
+
+ private void rehash(int newHashSize) {
+ long[] oldKeys = keys;
+ numEntries = 0;
+ numSlotsUsed = 0;
+ keys = new long[newHashSize];
+ Arrays.fill(keys, NULL);
+ for (long key : oldKeys) {
+ if (key != NULL && key != REMOVED) {
+ add(key);
+ }
+ }
+ }
+
+ /**
+ * Convenience method to quickly compute just the size of the intersection with another {@link FastIDSet}.
+ *
+ * @param other
+ * {@link FastIDSet} to intersect with
+ * @return number of elements in intersection
+ */
+ public int intersectionSize(FastIDSet other) {
+ int count = 0;
+ for (long key : other.keys) {
+ if (key != NULL && key != REMOVED && keys[find(key)] != NULL) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ @Override
+ public FastIDSet clone() {
+ FastIDSet clone;
+ try {
+ clone = (FastIDSet) super.clone();
+ } catch (CloneNotSupportedException cnse) {
+ throw new AssertionError();
+ }
+ clone.keys = keys.clone();
+ return clone;
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = 0;
+ long[] keys = this.keys;
+ for (long key : keys) {
+ if (key != NULL && key != REMOVED) {
+ hash = 31 * hash + ((int) (key >> 32) ^ (int) key);
+ }
+ }
+ return hash;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof FastIDSet)) {
+ return false;
+ }
+ FastIDSet otherMap = (FastIDSet) other;
+ long[] otherKeys = otherMap.keys;
+ int length = keys.length;
+ int otherLength = otherKeys.length;
+ int max = Math.min(length, otherLength);
+
+ int i = 0;
+ while (i < max) {
+ long key = keys[i];
+ long otherKey = otherKeys[i];
+ if (key == NULL || key == REMOVED) {
+ if (otherKey != NULL && otherKey != REMOVED) {
+ return false;
+ }
+ } else {
+ if (key != otherKey) {
+ return false;
+ }
+ }
+ i++;
+ }
+ while (i < length) {
+ long key = keys[i];
+ if (key != NULL && key != REMOVED) {
+ return false;
+ }
+ i++;
+ }
+ while (i < otherLength) {
+ long key = otherKeys[i];
+ if (key != NULL && key != REMOVED) {
+ return false;
+ }
+ i++;
+ }
+ return true;
+ }
+
+ @Override
+ public String toString() {
+ if (isEmpty()) {
+ return "[]";
+ }
+ StringBuilder result = new StringBuilder();
+ result.append('[');
+ for (long key : keys) {
+ if (key != NULL && key != REMOVED) {
+ result.append(key).append(',');
+ }
+ }
+ result.setCharAt(result.length() - 1, ']');
+ return result.toString();
+ }
+
+ private final class KeyIterator extends AbstractLongPrimitiveIterator {
+
+ private int position;
+ private int lastNext = -1;
+
+ @Override
+ public boolean hasNext() {
+ goToNext();
+ return position < keys.length;
+ }
+
+ @Override
+ public long nextLong() {
+ goToNext();
+ lastNext = position;
+ if (position >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ return keys[position++];
+ }
+
+ @Override
+ public long peek() {
+ goToNext();
+ if (position >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ return keys[position];
+ }
+
+ private void goToNext() {
+ int length = keys.length;
+ while (position < length
+ && (keys[position] == NULL || keys[position] == REMOVED)) {
+ position++;
+ }
+ }
+
+ @Override
+ public void remove() {
+ if (lastNext >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ if (lastNext < 0) {
+ throw new IllegalStateException();
+ }
+ keys[lastNext] = REMOVED;
+ numEntries--;
+ }
+
+ public Iterator<Long> iterator() {
+ return new KeyIterator();
+ }
+
+ @Override
+ public void skip(int n) {
+ position += n;
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java
new file mode 100644
index 0000000..7c64b44
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FastMap.java
@@ -0,0 +1,729 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+import java.util.AbstractCollection;
+import java.util.AbstractSet;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+
+import org.apache.mahout.common.RandomUtils;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * This is an optimized {@link Map} implementation, based on algorithms described in Knuth's "Art of Computer
+ * Programming", Vol. 3, p. 529.
+ * </p>
+ *
+ * <p>
+ * It should be faster than {@link java.util.HashMap} in some cases, but not all. Its main feature is a
+ * "max size" and the ability to transparently, efficiently and semi-intelligently evict old entries when max
+ * size is exceeded.
+ * </p>
+ *
+ * <p>
+ * This class is not a bit thread-safe.
+ * </p>
+ *
+ * <p>
+ * This implementation does not allow {@code null} as a key or value.
+ * </p>
+ */
+public final class FastMap<K,V> implements Map<K,V>, Serializable, Cloneable {
+
+ public static final int NO_MAX_SIZE = Integer.MAX_VALUE;
+ private static final float DEFAULT_LOAD_FACTOR = 1.5f;
+
+ /** Dummy object used to represent a key that has been removed. */
+ private static final Object REMOVED = new Object();
+
+ private K[] keys;
+ private V[] values;
+ private float loadFactor;
+ private int numEntries;
+ private int numSlotsUsed;
+ private final int maxSize;
+ private BitSet recentlyAccessed;
+ private final boolean countingAccesses;
+
+ /** Creates a new {@link FastMap} with default capacity. */
+ public FastMap() {
+ this(2, NO_MAX_SIZE);
+ }
+
+ public FastMap(int size) {
+ this(size, NO_MAX_SIZE);
+ }
+
+ public FastMap(Map<K,V> other) {
+ this(other.size());
+ putAll(other);
+ }
+
+ public FastMap(int size, float loadFactor) {
+ this(size, NO_MAX_SIZE, loadFactor);
+ }
+
+ public FastMap(int size, int maxSize) {
+ this(size, maxSize, DEFAULT_LOAD_FACTOR);
+ }
+
+ /**
+ * Creates a new whose capacity can accommodate the given number of entries without rehash.
+ *
+ * @param size desired capacity
+ * @param maxSize max capacity
+ * @throws IllegalArgumentException if size is less than 0, maxSize is less than 1
+ * or at least half of {@link RandomUtils#MAX_INT_SMALLER_TWIN_PRIME}, or
+ * loadFactor is less than 1
+ */
+ public FastMap(int size, int maxSize, float loadFactor) {
+ Preconditions.checkArgument(size >= 0, "size must be at least 0");
+ Preconditions.checkArgument(loadFactor >= 1.0f, "loadFactor must be at least 1.0");
+ this.loadFactor = loadFactor;
+ int max = (int) (RandomUtils.MAX_INT_SMALLER_TWIN_PRIME / loadFactor);
+ Preconditions.checkArgument(size < max, "size must be less than " + max);
+ Preconditions.checkArgument(maxSize >= 1, "maxSize must be at least 1");
+ int hashSize = RandomUtils.nextTwinPrime((int) (loadFactor * size));
+ keys = (K[]) new Object[hashSize];
+ values = (V[]) new Object[hashSize];
+ this.maxSize = maxSize;
+ this.countingAccesses = maxSize != Integer.MAX_VALUE;
+ this.recentlyAccessed = countingAccesses ? new BitSet(hashSize) : null;
+ }
+
+ private int find(Object key) {
+ int theHashCode = key.hashCode() & 0x7FFFFFFF; // make sure it's positive
+ K[] keys = this.keys;
+ int hashSize = keys.length;
+ int jump = 1 + theHashCode % (hashSize - 2);
+ int index = theHashCode % hashSize;
+ K currentKey = keys[index];
+ while (currentKey != null && !key.equals(currentKey)) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ return index;
+ }
+
+ private int findForAdd(Object key) {
+ int theHashCode = key.hashCode() & 0x7FFFFFFF; // make sure it's positive
+ K[] keys = this.keys;
+ int hashSize = keys.length;
+ int jump = 1 + theHashCode % (hashSize - 2);
+ int index = theHashCode % hashSize;
+ K currentKey = keys[index];
+ while (currentKey != null && currentKey != REMOVED && key != currentKey) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ if (currentKey != REMOVED) {
+ return index;
+ }
+ // If we're adding, it's here, but, the key might have a value already later
+ int addIndex = index;
+ while (currentKey != null && key != currentKey) {
+ index -= index < jump ? jump - hashSize : jump;
+ currentKey = keys[index];
+ }
+ return key == currentKey ? index : addIndex;
+ }
+
+ @Override
+ public V get(Object key) {
+ if (key == null) {
+ return null;
+ }
+ int index = find(key);
+ if (countingAccesses) {
+ recentlyAccessed.set(index);
+ }
+ return values[index];
+ }
+
+ @Override
+ public int size() {
+ return numEntries;
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return numEntries == 0;
+ }
+
+ @Override
+ public boolean containsKey(Object key) {
+ return key != null && keys[find(key)] != null;
+ }
+
+ @Override
+ public boolean containsValue(Object value) {
+ if (value == null) {
+ return false;
+ }
+ for (V theValue : values) {
+ if (theValue != null && value.equals(theValue)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /**
+ * @throws NullPointerException
+ * if key or value is null
+ */
+ @Override
+ public V put(K key, V value) {
+ Preconditions.checkNotNull(key);
+ Preconditions.checkNotNull(value);
+ // If less than half the slots are open, let's clear it up
+ if (numSlotsUsed * loadFactor >= keys.length) {
+ // If over half the slots used are actual entries, let's grow
+ if (numEntries * loadFactor >= numSlotsUsed) {
+ growAndRehash();
+ } else {
+ // Otherwise just rehash to clear REMOVED entries and don't grow
+ rehash();
+ }
+ }
+ // Here we may later consider implementing Brent's variation described on page 532
+ int index = findForAdd(key);
+ if (keys[index] == key) {
+ V oldValue = values[index];
+ values[index] = value;
+ return oldValue;
+ }
+ // If size is limited,
+ if (countingAccesses && numEntries >= maxSize) {
+ // and we're too large, clear some old-ish entry
+ clearStaleEntry(index);
+ }
+ keys[index] = key;
+ values[index] = value;
+ numEntries++;
+ numSlotsUsed++;
+ return null;
+ }
+
+ private void clearStaleEntry(int index) {
+ while (true) {
+ K currentKey;
+ do {
+ if (index == 0) {
+ index = keys.length - 1;
+ } else {
+ index--;
+ }
+ currentKey = keys[index];
+ } while (currentKey == null || currentKey == REMOVED);
+ if (recentlyAccessed.get(index)) {
+ recentlyAccessed.clear(index);
+ } else {
+ break;
+ }
+ }
+ // Delete the entry
+ ((Object[])keys)[index] = REMOVED;
+ numEntries--;
+ values[index] = null;
+ }
+
+ @Override
+ public void putAll(Map<? extends K,? extends V> map) {
+ for (Entry<? extends K,? extends V> entry : map.entrySet()) {
+ put(entry.getKey(), entry.getValue());
+ }
+ }
+
+ @Override
+ public V remove(Object key) {
+ if (key == null) {
+ return null;
+ }
+ int index = find(key);
+ if (keys[index] == null) {
+ return null;
+ } else {
+ ((Object[])keys)[index] = REMOVED;
+ numEntries--;
+ V oldValue = values[index];
+ values[index] = null;
+ // don't decrement numSlotsUsed
+ return oldValue;
+ }
+ // Could un-set recentlyAccessed's bit but doesn't matter
+ }
+
+ @Override
+ public void clear() {
+ numEntries = 0;
+ numSlotsUsed = 0;
+ Arrays.fill(keys, null);
+ Arrays.fill(values, null);
+ if (countingAccesses) {
+ recentlyAccessed.clear();
+ }
+ }
+
+ @Override
+ public Set<K> keySet() {
+ return new KeySet();
+ }
+
+ @Override
+ public Collection<V> values() {
+ return new ValueCollection();
+ }
+
+ @Override
+ public Set<Entry<K,V>> entrySet() {
+ return new EntrySet();
+ }
+
+ public void rehash() {
+ rehash(RandomUtils.nextTwinPrime((int) (loadFactor * numEntries)));
+ }
+
+ private void growAndRehash() {
+ if (keys.length * loadFactor >= RandomUtils.MAX_INT_SMALLER_TWIN_PRIME) {
+ throw new IllegalStateException("Can't grow any more");
+ }
+ rehash(RandomUtils.nextTwinPrime((int) (loadFactor * keys.length)));
+ }
+
+ private void rehash(int newHashSize) {
+ K[] oldKeys = keys;
+ V[] oldValues = values;
+ numEntries = 0;
+ numSlotsUsed = 0;
+ if (countingAccesses) {
+ recentlyAccessed = new BitSet(newHashSize);
+ }
+ keys = (K[]) new Object[newHashSize];
+ values = (V[]) new Object[newHashSize];
+ int length = oldKeys.length;
+ for (int i = 0; i < length; i++) {
+ K key = oldKeys[i];
+ if (key != null && key != REMOVED) {
+ put(key, oldValues[i]);
+ }
+ }
+ }
+
+ void iteratorRemove(int lastNext) {
+ if (lastNext >= values.length) {
+ throw new NoSuchElementException();
+ }
+ if (lastNext < 0) {
+ throw new IllegalStateException();
+ }
+ values[lastNext] = null;
+ ((Object[])keys)[lastNext] = REMOVED;
+ numEntries--;
+ }
+
+ @Override
+ public FastMap<K,V> clone() {
+ FastMap<K,V> clone;
+ try {
+ clone = (FastMap<K,V>) super.clone();
+ } catch (CloneNotSupportedException cnse) {
+ throw new AssertionError();
+ }
+ clone.keys = keys.clone();
+ clone.values = values.clone();
+ clone.recentlyAccessed = countingAccesses ? new BitSet(keys.length) : null;
+ return clone;
+ }
+
+ @Override
+ public int hashCode() {
+ int hash = 0;
+ K[] keys = this.keys;
+ int max = keys.length;
+ for (int i = 0; i < max; i++) {
+ K key = keys[i];
+ if (key != null && key != REMOVED) {
+ hash = 31 * hash + key.hashCode();
+ hash = 31 * hash + values[i].hashCode();
+ }
+ }
+ return hash;
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof FastMap)) {
+ return false;
+ }
+ FastMap<K,V> otherMap = (FastMap<K,V>) other;
+ K[] otherKeys = otherMap.keys;
+ V[] otherValues = otherMap.values;
+ int length = keys.length;
+ int otherLength = otherKeys.length;
+ int max = Math.min(length, otherLength);
+
+ int i = 0;
+ while (i < max) {
+ K key = keys[i];
+ K otherKey = otherKeys[i];
+ if (key == null || key == REMOVED) {
+ if (otherKey != null && otherKey != REMOVED) {
+ return false;
+ }
+ } else {
+ if (key != otherKey || !values[i].equals(otherValues[i])) {
+ return false;
+ }
+ }
+ i++;
+ }
+ while (i < length) {
+ K key = keys[i];
+ if (key != null && key != REMOVED) {
+ return false;
+ }
+ i++;
+ }
+ while (i < otherLength) {
+ K key = otherKeys[i];
+ if (key != null && key != REMOVED) {
+ return false;
+ }
+ i++;
+ }
+ return true;
+ }
+
+ @Override
+ public String toString() {
+ if (isEmpty()) {
+ return "{}";
+ }
+ StringBuilder result = new StringBuilder();
+ result.append('{');
+ for (int i = 0; i < keys.length; i++) {
+ K key = keys[i];
+ if (key != null && key != REMOVED) {
+ result.append(key).append('=').append(values[i]).append(',');
+ }
+ }
+ result.setCharAt(result.length() - 1, '}');
+ return result.toString();
+ }
+
+ private final class EntrySet extends AbstractSet<Entry<K,V>> {
+
+ @Override
+ public int size() {
+ return FastMap.this.size();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return FastMap.this.isEmpty();
+ }
+
+ @Override
+ public boolean contains(Object o) {
+ return containsKey(o);
+ }
+
+ @Override
+ public Iterator<Entry<K,V>> iterator() {
+ return new EntryIterator();
+ }
+
+ @Override
+ public boolean add(Entry<K,V> t) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean remove(Object o) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean addAll(Collection<? extends Entry<K,V>> ts) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean retainAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean removeAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear() {
+ FastMap.this.clear();
+ }
+
+ private final class MapEntry implements Entry<K,V> {
+
+ private final int index;
+
+ private MapEntry(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public K getKey() {
+ return keys[index];
+ }
+
+ @Override
+ public V getValue() {
+ return values[index];
+ }
+
+ @Override
+ public V setValue(V value) {
+ Preconditions.checkArgument(value != null);
+ V oldValue = values[index];
+ values[index] = value;
+ return oldValue;
+ }
+ }
+
+ private final class EntryIterator implements Iterator<Entry<K,V>> {
+
+ private int position;
+ private int lastNext = -1;
+
+ @Override
+ public boolean hasNext() {
+ goToNext();
+ return position < keys.length;
+ }
+
+ @Override
+ public Entry<K,V> next() {
+ goToNext();
+ lastNext = position;
+ if (position >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ return new MapEntry(position++);
+ }
+
+ private void goToNext() {
+ int length = values.length;
+ while (position < length && values[position] == null) {
+ position++;
+ }
+ }
+
+ @Override
+ public void remove() {
+ iteratorRemove(lastNext);
+ }
+ }
+
+ }
+
+ private final class KeySet extends AbstractSet<K> {
+
+ @Override
+ public int size() {
+ return FastMap.this.size();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return FastMap.this.isEmpty();
+ }
+
+ @Override
+ public boolean contains(Object o) {
+ return containsKey(o);
+ }
+
+ @Override
+ public Iterator<K> iterator() {
+ return new KeyIterator();
+ }
+
+ @Override
+ public boolean add(K t) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean remove(Object o) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean addAll(Collection<? extends K> ts) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean retainAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean removeAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear() {
+ FastMap.this.clear();
+ }
+
+ private final class KeyIterator implements Iterator<K> {
+
+ private int position;
+ private int lastNext = -1;
+
+ @Override
+ public boolean hasNext() {
+ goToNext();
+ return position < keys.length;
+ }
+
+ @Override
+ public K next() {
+ goToNext();
+ lastNext = position;
+ if (position >= keys.length) {
+ throw new NoSuchElementException();
+ }
+ return keys[position++];
+ }
+
+ private void goToNext() {
+ int length = values.length;
+ while (position < length && values[position] == null) {
+ position++;
+ }
+ }
+
+ @Override
+ public void remove() {
+ iteratorRemove(lastNext);
+ }
+ }
+
+ }
+
+ private final class ValueCollection extends AbstractCollection<V> {
+
+ @Override
+ public int size() {
+ return FastMap.this.size();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return FastMap.this.isEmpty();
+ }
+
+ @Override
+ public boolean contains(Object o) {
+ return containsValue(o);
+ }
+
+ @Override
+ public Iterator<V> iterator() {
+ return new ValueIterator();
+ }
+
+ @Override
+ public boolean add(V v) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean remove(Object o) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean addAll(Collection<? extends V> vs) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean removeAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean retainAll(Collection<?> objects) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void clear() {
+ FastMap.this.clear();
+ }
+
+ private final class ValueIterator implements Iterator<V> {
+
+ private int position;
+ private int lastNext = -1;
+
+ @Override
+ public boolean hasNext() {
+ goToNext();
+ return position < values.length;
+ }
+
+ @Override
+ public V next() {
+ goToNext();
+ lastNext = position;
+ if (position >= values.length) {
+ throw new NoSuchElementException();
+ }
+ return values[position++];
+ }
+
+ private void goToNext() {
+ int length = values.length;
+ while (position < length && values[position] == null) {
+ position++;
+ }
+ }
+
+ @Override
+ public void remove() {
+ iteratorRemove(lastNext);
+ }
+
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java
new file mode 100644
index 0000000..1863d2b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverage.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+
+/**
+ * <p>
+ * A simple class that represents a fixed value of an average and count. This is useful
+ * when an API needs to return {@link RunningAverage} but is not in a position to accept
+ * updates to it.
+ * </p>
+ */
+public class FixedRunningAverage implements RunningAverage, Serializable {
+
+ private final double average;
+ private final int count;
+
+ public FixedRunningAverage(double average, int count) {
+ this.average = average;
+ this.count = count;
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public synchronized void addDatum(double datum) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public synchronized void removeDatum(double datum) {
+ throw new UnsupportedOperationException();
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public synchronized void changeDatum(double delta) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public synchronized int getCount() {
+ return count;
+ }
+
+ @Override
+ public synchronized double getAverage() {
+ return average;
+ }
+
+ @Override
+ public RunningAverage inverse() {
+ return new InvertedRunningAverage(this);
+ }
+
+ @Override
+ public synchronized String toString() {
+ return String.valueOf(average);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java
new file mode 100644
index 0000000..619b6b7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FixedRunningAverageAndStdDev.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * <p>
+ * A simple class that represents a fixed value of an average, count and standard deviation. This is useful
+ * when an API needs to return {@link RunningAverageAndStdDev} but is not in a position to accept
+ * updates to it.
+ * </p>
+ */
+public final class FixedRunningAverageAndStdDev extends FixedRunningAverage implements RunningAverageAndStdDev {
+
+ private final double stdDev;
+
+ public FixedRunningAverageAndStdDev(double average, double stdDev, int count) {
+ super(average, count);
+ this.stdDev = stdDev;
+ }
+
+ @Override
+ public RunningAverageAndStdDev inverse() {
+ return new InvertedRunningAverageAndStdDev(this);
+ }
+
+ @Override
+ public synchronized String toString() {
+ return super.toString() + ',' + stdDev;
+ }
+
+ @Override
+ public double getStandardDeviation() {
+ return stdDev;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
new file mode 100644
index 0000000..00d828f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverage.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.io.Serializable;
+
+/**
+ * <p>
+ * A simple class that can keep track of a running average of a series of numbers. One can add to or remove
+ * from the series, as well as update a datum in the series. The class does not actually keep track of the
+ * series of values, just its running average, so it doesn't even matter if you remove/change a value that
+ * wasn't added.
+ * </p>
+ */
+public class FullRunningAverage implements RunningAverage, Serializable {
+
+ private int count;
+ private double average;
+
+ public FullRunningAverage() {
+ this(0, Double.NaN);
+ }
+
+ public FullRunningAverage(int count, double average) {
+ this.count = count;
+ this.average = average;
+ }
+
+ /**
+ * @param datum
+ * new item to add to the running average
+ */
+ @Override
+ public synchronized void addDatum(double datum) {
+ if (++count == 1) {
+ average = datum;
+ } else {
+ average = average * (count - 1) / count + datum / count;
+ }
+ }
+
+ /**
+ * @param datum
+ * item to remove to the running average
+ * @throws IllegalStateException
+ * if count is 0
+ */
+ @Override
+ public synchronized void removeDatum(double datum) {
+ if (count == 0) {
+ throw new IllegalStateException();
+ }
+ if (--count == 0) {
+ average = Double.NaN;
+ } else {
+ average = average * (count + 1) / count - datum / count;
+ }
+ }
+
+ /**
+ * @param delta
+ * amount by which to change a datum in the running average
+ * @throws IllegalStateException
+ * if count is 0
+ */
+ @Override
+ public synchronized void changeDatum(double delta) {
+ if (count == 0) {
+ throw new IllegalStateException();
+ }
+ average += delta / count;
+ }
+
+ @Override
+ public synchronized int getCount() {
+ return count;
+ }
+
+ @Override
+ public synchronized double getAverage() {
+ return average;
+ }
+
+ @Override
+ public RunningAverage inverse() {
+ return new InvertedRunningAverage(this);
+ }
+
+ @Override
+ public synchronized String toString() {
+ return String.valueOf(average);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
new file mode 100644
index 0000000..6212e66
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/FullRunningAverageAndStdDev.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * <p>
+ * Extends {@link FullRunningAverage} to add a running standard deviation computation.
+ * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html
+ * </p>
+ */
+public final class FullRunningAverageAndStdDev extends FullRunningAverage implements RunningAverageAndStdDev {
+
+ private double stdDev;
+ private double mk;
+ private double sk;
+
+ public FullRunningAverageAndStdDev() {
+ mk = 0.0;
+ sk = 0.0;
+ recomputeStdDev();
+ }
+
+ public FullRunningAverageAndStdDev(int count, double average, double mk, double sk) {
+ super(count, average);
+ this.mk = mk;
+ this.sk = sk;
+ recomputeStdDev();
+ }
+
+ public double getMk() {
+ return mk;
+ }
+
+ public double getSk() {
+ return sk;
+ }
+
+ @Override
+ public synchronized double getStandardDeviation() {
+ return stdDev;
+ }
+
+ @Override
+ public synchronized void addDatum(double datum) {
+ super.addDatum(datum);
+ int count = getCount();
+ if (count == 1) {
+ mk = datum;
+ sk = 0.0;
+ } else {
+ double oldmk = mk;
+ double diff = datum - oldmk;
+ mk += diff / count;
+ sk += diff * (datum - mk);
+ }
+ recomputeStdDev();
+ }
+
+ @Override
+ public synchronized void removeDatum(double datum) {
+ int oldCount = getCount();
+ super.removeDatum(datum);
+ double oldmk = mk;
+ mk = (oldCount * oldmk - datum) / (oldCount - 1);
+ sk -= (datum - mk) * (datum - oldmk);
+ recomputeStdDev();
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public void changeDatum(double delta) {
+ throw new UnsupportedOperationException();
+ }
+
+ private synchronized void recomputeStdDev() {
+ int count = getCount();
+ stdDev = count > 1 ? Math.sqrt(sk / (count - 1)) : Double.NaN;
+ }
+
+ @Override
+ public RunningAverageAndStdDev inverse() {
+ return new InvertedRunningAverageAndStdDev(this);
+ }
+
+ @Override
+ public synchronized String toString() {
+ return String.valueOf(String.valueOf(getAverage()) + ',' + stdDev);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java
new file mode 100644
index 0000000..0f94c22
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverage.java
@@ -0,0 +1,58 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+public final class InvertedRunningAverage implements RunningAverage {
+
+ private final RunningAverage delegate;
+
+ public InvertedRunningAverage(RunningAverage delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void addDatum(double datum) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void removeDatum(double datum) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void changeDatum(double delta) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int getCount() {
+ return delegate.getCount();
+ }
+
+ @Override
+ public double getAverage() {
+ return -delegate.getAverage();
+ }
+
+ @Override
+ public RunningAverage inverse() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java
new file mode 100644
index 0000000..147012d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/InvertedRunningAverageAndStdDev.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+public final class InvertedRunningAverageAndStdDev implements RunningAverageAndStdDev {
+
+ private final RunningAverageAndStdDev delegate;
+
+ public InvertedRunningAverageAndStdDev(RunningAverageAndStdDev delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void addDatum(double datum) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void removeDatum(double datum) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void changeDatum(double delta) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int getCount() {
+ return delegate.getCount();
+ }
+
+ @Override
+ public double getAverage() {
+ return -delegate.getAverage();
+ }
+
+ @Override
+ public double getStandardDeviation() {
+ return delegate.getStandardDeviation();
+ }
+
+ @Override
+ public RunningAverageAndStdDev inverse() {
+ return delegate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java
new file mode 100644
index 0000000..5127df0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveArrayIterator.java
@@ -0,0 +1,93 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.NoSuchElementException;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * While long[] is an Iterable, it is not an Iterable<Long>. This adapter class addresses that.
+ */
+public final class LongPrimitiveArrayIterator implements LongPrimitiveIterator {
+
+ private final long[] array;
+ private int position;
+ private final int max;
+
+ /**
+ * <p>
+ * Creates an {@link LongPrimitiveArrayIterator} over an entire array.
+ * </p>
+ *
+ * @param array
+ * array to iterate over
+ */
+ public LongPrimitiveArrayIterator(long[] array) {
+ this.array = Preconditions.checkNotNull(array); // yeah, not going to copy the array here, for performance
+ this.position = 0;
+ this.max = array.length;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return position < max;
+ }
+
+ @Override
+ public Long next() {
+ return nextLong();
+ }
+
+ @Override
+ public long nextLong() {
+ if (position >= array.length) {
+ throw new NoSuchElementException();
+ }
+ return array[position++];
+ }
+
+ @Override
+ public long peek() {
+ if (position >= array.length) {
+ throw new NoSuchElementException();
+ }
+ return array[position];
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void skip(int n) {
+ if (n > 0) {
+ position += n;
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "LongPrimitiveArrayIterator";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java
new file mode 100644
index 0000000..0840749
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/LongPrimitiveIterator.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * Adds notion of iterating over {@code long} primitives in the style of an {@link java.util.Iterator} -- as
+ * opposed to iterating over {@link Long}. Implementations of this interface however also implement
+ * {@link java.util.Iterator} and {@link Iterable} over {@link Long} for convenience.
+ */
+public interface LongPrimitiveIterator extends SkippingIterator<Long> {
+
+ /**
+ * @return next {@code long} in iteration
+ * @throws java.util.NoSuchElementException
+ * if no more elements exist in the iteration
+ */
+ long nextLong();
+
+ /**
+ * @return next {@code long} in iteration without advancing iteration
+ */
+ long peek();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java
new file mode 100644
index 0000000..cc91560
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RefreshHelper.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.locks.ReentrantLock;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A helper class for implementing {@link Refreshable}. This object is typically included in an implementation
+ * {@link Refreshable} to implement {@link Refreshable#refresh(Collection)}. It execute the class's own
+ * supplied update logic, after updating all the object's dependencies. This also ensures that dependencies
+ * are not updated multiple times.
+ */
+public final class RefreshHelper implements Refreshable {
+
+ private static final Logger log = LoggerFactory.getLogger(RefreshHelper.class);
+
+ private final List<Refreshable> dependencies;
+ private final ReentrantLock refreshLock;
+ private final Callable<?> refreshRunnable;
+
+ /**
+ * @param refreshRunnable
+ * encapsulates the containing object's own refresh logic
+ */
+ public RefreshHelper(Callable<?> refreshRunnable) {
+ this.dependencies = Lists.newArrayListWithCapacity(3);
+ this.refreshLock = new ReentrantLock();
+ this.refreshRunnable = refreshRunnable;
+ }
+
+ /** Add a dependency to be refreshed first when the encapsulating object does. */
+ public void addDependency(Refreshable refreshable) {
+ if (refreshable != null) {
+ dependencies.add(refreshable);
+ }
+ }
+
+ public void removeDependency(Refreshable refreshable) {
+ if (refreshable != null) {
+ dependencies.remove(refreshable);
+ }
+ }
+
+ /**
+ * Typically this is called in {@link Refreshable#refresh(java.util.Collection)} and is the entire body of
+ * that method.
+ */
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ if (refreshLock.tryLock()) {
+ try {
+ alreadyRefreshed = buildRefreshed(alreadyRefreshed);
+ for (Refreshable dependency : dependencies) {
+ maybeRefresh(alreadyRefreshed, dependency);
+ }
+ if (refreshRunnable != null) {
+ try {
+ refreshRunnable.call();
+ } catch (Exception e) {
+ log.warn("Unexpected exception while refreshing", e);
+ }
+ }
+ } finally {
+ refreshLock.unlock();
+ }
+ }
+ }
+
+ /**
+ * Creates a new and empty {@link Collection} if the method parameter is {@code null}.
+ *
+ * @param currentAlreadyRefreshed
+ * {@link Refreshable}s to refresh later on
+ * @return an empty {@link Collection} if the method param was {@code null} or the unmodified method
+ * param.
+ */
+ public static Collection<Refreshable> buildRefreshed(Collection<Refreshable> currentAlreadyRefreshed) {
+ return currentAlreadyRefreshed == null ? Sets.<Refreshable>newHashSetWithExpectedSize(3) : currentAlreadyRefreshed;
+ }
+
+ /**
+ * Adds the specified {@link Refreshable} to the given collection of {@link Refreshable}s if it is not
+ * already there and immediately refreshes it.
+ *
+ * @param alreadyRefreshed
+ * the collection of {@link Refreshable}s
+ * @param refreshable
+ * the {@link Refreshable} to potentially add and refresh
+ */
+ public static void maybeRefresh(Collection<Refreshable> alreadyRefreshed, Refreshable refreshable) {
+ if (!alreadyRefreshed.contains(refreshable)) {
+ alreadyRefreshed.add(refreshable);
+ log.info("Added refreshable: {}", refreshable);
+ refreshable.refresh(alreadyRefreshed);
+ log.info("Refreshed: {}", alreadyRefreshed);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java
new file mode 100644
index 0000000..40da9de
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/Retriever.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+
+/**
+ * <p>
+ * Implementations can retrieve a value for a given key.
+ * </p>
+ */
+public interface Retriever<K,V> {
+
+ /**
+ * @param key key for which a value should be retrieved
+ * @return value for key
+ * @throws TasteException if an error occurs while retrieving the value
+ */
+ V get(K key) throws TasteException;
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java
new file mode 100644
index 0000000..bf8e39c
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverage.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * <p>
+ * Interface for classes that can keep track of a running average of a series of numbers. One can add to or
+ * remove from the series, as well as update a datum in the series. The class does not actually keep track of
+ * the series of values, just its running average, so it doesn't even matter if you remove/change a value that
+ * wasn't added.
+ * </p>
+ */
+public interface RunningAverage {
+
+ /**
+ * @param datum
+ * new item to add to the running average
+ * @throws IllegalArgumentException
+ * if datum is {@link Double#NaN}
+ */
+ void addDatum(double datum);
+
+ /**
+ * @param datum
+ * item to remove to the running average
+ * @throws IllegalArgumentException
+ * if datum is {@link Double#NaN}
+ * @throws IllegalStateException
+ * if count is 0
+ */
+ void removeDatum(double datum);
+
+ /**
+ * @param delta
+ * amount by which to change a datum in the running average
+ * @throws IllegalArgumentException
+ * if delta is {@link Double#NaN}
+ * @throws IllegalStateException
+ * if count is 0
+ */
+ void changeDatum(double delta);
+
+ int getCount();
+
+ double getAverage();
+
+ /**
+ * @return a (possibly immutable) object whose average is the negative of this object's
+ */
+ RunningAverage inverse();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java
new file mode 100644
index 0000000..4ac6108
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/RunningAverageAndStdDev.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+/**
+ * <p>
+ * Extends {@link RunningAverage} by adding standard deviation too.
+ * </p>
+ */
+public interface RunningAverageAndStdDev extends RunningAverage {
+
+ /** @return standard deviation of data */
+ double getStandardDeviation();
+
+ /**
+ * @return a (possibly immutable) object whose average is the negative of this object's
+ */
+ @Override
+ RunningAverageAndStdDev inverse();
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
new file mode 100644
index 0000000..6da709d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SamplingLongPrimitiveIterator.java
@@ -0,0 +1,111 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.NoSuchElementException;
+
+import com.google.common.base.Preconditions;
+import org.apache.commons.math3.distribution.PascalDistribution;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+
+/**
+ * Wraps a {@link LongPrimitiveIterator} and returns only some subset of the elements that it would,
+ * as determined by a sampling rate parameter.
+ */
+public final class SamplingLongPrimitiveIterator extends AbstractLongPrimitiveIterator {
+
+ private final PascalDistribution geometricDistribution;
+ private final LongPrimitiveIterator delegate;
+ private long next;
+ private boolean hasNext;
+
+ public SamplingLongPrimitiveIterator(LongPrimitiveIterator delegate, double samplingRate) {
+ this(RandomUtils.getRandom(), delegate, samplingRate);
+ }
+
+ public SamplingLongPrimitiveIterator(RandomWrapper random, LongPrimitiveIterator delegate, double samplingRate) {
+ Preconditions.checkNotNull(delegate);
+ Preconditions.checkArgument(samplingRate > 0.0 && samplingRate <= 1.0, "Must be: 0.0 < samplingRate <= 1.0");
+ // Geometric distribution is special case of negative binomial (aka Pascal) with r=1:
+ geometricDistribution = new PascalDistribution(random.getRandomGenerator(), 1, samplingRate);
+ this.delegate = delegate;
+ this.hasNext = true;
+ doNext();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return hasNext;
+ }
+
+ @Override
+ public long nextLong() {
+ if (hasNext) {
+ long result = next;
+ doNext();
+ return result;
+ }
+ throw new NoSuchElementException();
+ }
+
+ @Override
+ public long peek() {
+ if (hasNext) {
+ return next;
+ }
+ throw new NoSuchElementException();
+ }
+
+ private void doNext() {
+ int toSkip = geometricDistribution.sample();
+ delegate.skip(toSkip);
+ if (delegate.hasNext()) {
+ next = delegate.next();
+ } else {
+ hasNext = false;
+ }
+ }
+
+ /**
+ * @throws UnsupportedOperationException
+ */
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void skip(int n) {
+ int toSkip = 0;
+ for (int i = 0; i < n; i++) {
+ toSkip += geometricDistribution.sample();
+ }
+ delegate.skip(toSkip);
+ if (delegate.hasNext()) {
+ next = delegate.next();
+ } else {
+ hasNext = false;
+ }
+ }
+
+ public static LongPrimitiveIterator maybeWrapIterator(LongPrimitiveIterator delegate, double samplingRate) {
+ return samplingRate >= 1.0 ? delegate : new SamplingLongPrimitiveIterator(delegate, samplingRate);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
new file mode 100644
index 0000000..e88f98a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/common/SkippingIterator.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.common;
+
+import java.util.Iterator;
+
+/**
+ * Adds ability to skip ahead in an iterator, perhaps more efficiently than by calling {@link #next()}
+ * repeatedly.
+ */
+public interface SkippingIterator<V> extends Iterator<V> {
+
+ /**
+ * Skip the next n elements supplied by this {@link Iterator}. If there are less than n elements remaining,
+ * this skips all remaining elements in the {@link Iterator}. This method has the same effect as calling
+ * {@link #next()} n times, except that it will never throw {@link java.util.NoSuchElementException}.
+ */
+ void skip(int n);
+
+}
[32/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
new file mode 100644
index 0000000..0b2c41b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/AbstractOnlineLogisticRegression.java
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Generic definition of a 1 of n logistic regression classifier that returns probabilities in
+ * response to a feature vector. This classifier uses 1 of n-1 coding where the 0-th category
+ * is not stored explicitly.
+ * <p/>
+ * Provides the SGD based algorithm for learning a logistic regression, but omits all
+ * annealing of learning rates. Any extension of this abstract class must define the overall
+ * and per-term annealing for themselves.
+ */
+public abstract class AbstractOnlineLogisticRegression extends AbstractVectorClassifier implements OnlineLearner {
+ // coefficients for the classification. This is a dense matrix
+ // that is (numCategories-1) x numFeatures
+ protected Matrix beta;
+
+ // number of categories we are classifying. This should the number of rows of beta plus one.
+ protected int numCategories;
+
+ protected int step;
+
+ // information about how long since coefficient rows were updated. This allows lazy regularization.
+ protected Vector updateSteps;
+
+ // information about how many updates we have had on a location. This allows per-term
+ // annealing a la confidence weighted learning.
+ protected Vector updateCounts;
+
+ // weight of the prior on beta
+ private double lambda = 1.0e-5;
+ protected PriorFunction prior;
+
+ // can we ignore any further regularization when doing classification?
+ private boolean sealed;
+
+ // by default we don't do any fancy training
+ private Gradient gradient = new DefaultGradient();
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param lambda New value of lambda, the weighting factor for the prior distribution.
+ * @return This, so other configurations can be chained.
+ */
+ public AbstractOnlineLogisticRegression lambda(double lambda) {
+ this.lambda = lambda;
+ return this;
+ }
+
+ /**
+ * Computes the inverse link function, by default the logistic link function.
+ *
+ * @param v The output of the linear combination in a GLM. Note that the value
+ * of v is disturbed.
+ * @return A version of v with the link function applied.
+ */
+ public static Vector link(Vector v) {
+ double max = v.maxValue();
+ if (max >= 40) {
+ // if max > 40, we subtract the large offset first
+ // the size of the max means that 1+sum(exp(v)) = sum(exp(v)) to within round-off
+ v.assign(Functions.minus(max)).assign(Functions.EXP);
+ return v.divide(v.norm(1));
+ } else {
+ v.assign(Functions.EXP);
+ return v.divide(1 + v.norm(1));
+ }
+ }
+
+ /**
+ * Computes the binomial logistic inverse link function.
+ *
+ * @param r The value to transform.
+ * @return The logit of r.
+ */
+ public static double link(double r) {
+ if (r < 0.0) {
+ double s = Math.exp(r);
+ return s / (1.0 + s);
+ } else {
+ double s = Math.exp(-r);
+ return 1.0 / (1.0 + s);
+ }
+ }
+
+ @Override
+ public Vector classifyNoLink(Vector instance) {
+ // apply pending regularization to whichever coefficients matter
+ regularize(instance);
+ return beta.times(instance);
+ }
+
+ public double classifyScalarNoLink(Vector instance) {
+ return beta.viewRow(0).dot(instance);
+ }
+
+ /**
+ * Returns n-1 probabilities, one for each category but the 0-th. The probability of the 0-th
+ * category is 1 - sum(this result).
+ *
+ * @param instance A vector of features to be classified.
+ * @return A vector of probabilities, one for each of the first n-1 categories.
+ */
+ @Override
+ public Vector classify(Vector instance) {
+ return link(classifyNoLink(instance));
+ }
+
+ /**
+ * Returns a single scalar probability in the case where we have two categories. Using this
+ * method avoids an extra vector allocation as opposed to calling classify() or an extra two
+ * vector allocations relative to classifyFull().
+ *
+ * @param instance The vector of features to be classified.
+ * @return The probability of the first of two categories.
+ * @throws IllegalArgumentException If the classifier doesn't have two categories.
+ */
+ @Override
+ public double classifyScalar(Vector instance) {
+ Preconditions.checkArgument(numCategories() == 2, "Can only call classifyScalar with two categories");
+
+ // apply pending regularization to whichever coefficients matter
+ regularize(instance);
+
+ // result is a vector with one element so we can just use dot product
+ return link(classifyScalarNoLink(instance));
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ unseal();
+
+ double learningRate = currentLearningRate();
+
+ // push coefficients back to zero based on the prior
+ regularize(instance);
+
+ // update each row of coefficients according to result
+ Vector gradient = this.gradient.apply(groupKey, actual, instance, this);
+ for (int i = 0; i < numCategories - 1; i++) {
+ double gradientBase = gradient.get(i);
+
+ // then we apply the gradientBase to the resulting element.
+ for (Element updateLocation : instance.nonZeroes()) {
+ int j = updateLocation.index();
+
+ double newValue = beta.getQuick(i, j) + gradientBase * learningRate * perTermLearningRate(j) * instance.get(j);
+ beta.setQuick(i, j, newValue);
+ }
+ }
+
+ // remember that these elements got updated
+ for (Element element : instance.nonZeroes()) {
+ int j = element.index();
+ updateSteps.setQuick(j, getStep());
+ updateCounts.incrementQuick(j, 1);
+ }
+ nextStep();
+
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ train(trackingKey, null, actual, instance);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ train(0, null, actual, instance);
+ }
+
+ public void regularize(Vector instance) {
+ if (updateSteps == null || isSealed()) {
+ return;
+ }
+
+ // anneal learning rate
+ double learningRate = currentLearningRate();
+
+ // here we lazily apply the prior to make up for our neglect
+ for (int i = 0; i < numCategories - 1; i++) {
+ for (Element updateLocation : instance.nonZeroes()) {
+ int j = updateLocation.index();
+ double missingUpdates = getStep() - updateSteps.get(j);
+ if (missingUpdates > 0) {
+ double rate = getLambda() * learningRate * perTermLearningRate(j);
+ double newValue = prior.age(beta.get(i, j), missingUpdates, rate);
+ beta.set(i, j, newValue);
+ updateSteps.set(j, getStep());
+ }
+ }
+ }
+ }
+
+ // these two abstract methods are how extensions can modify the basic learning behavior of this object.
+
+ public abstract double perTermLearningRate(int j);
+
+ public abstract double currentLearningRate();
+
+ public void setPrior(PriorFunction prior) {
+ this.prior = prior;
+ }
+
+ public void setGradient(Gradient gradient) {
+ this.gradient = gradient;
+ }
+
+ public PriorFunction getPrior() {
+ return prior;
+ }
+
+ public Matrix getBeta() {
+ close();
+ return beta;
+ }
+
+ public void setBeta(int i, int j, double betaIJ) {
+ beta.set(i, j, betaIJ);
+ }
+
+ @Override
+ public int numCategories() {
+ return numCategories;
+ }
+
+ public int numFeatures() {
+ return beta.numCols();
+ }
+
+ public double getLambda() {
+ return lambda;
+ }
+
+ public int getStep() {
+ return step;
+ }
+
+ protected void nextStep() {
+ step++;
+ }
+
+ public boolean isSealed() {
+ return sealed;
+ }
+
+ protected void unseal() {
+ sealed = false;
+ }
+
+ private void regularizeAll() {
+ Vector all = new DenseVector(beta.numCols());
+ all.assign(1);
+ regularize(all);
+ }
+
+ @Override
+ public void close() {
+ if (!sealed) {
+ step++;
+ regularizeAll();
+ sealed = true;
+ }
+ }
+
+ public void copyFrom(AbstractOnlineLogisticRegression other) {
+ // number of categories we are classifying. This should the number of rows of beta plus one.
+ Preconditions.checkArgument(numCategories == other.numCategories,
+ "Can't copy unless number of target categories is the same");
+
+ beta.assign(other.beta);
+
+ step = other.step;
+
+ updateSteps.assign(other.updateSteps);
+ updateCounts.assign(other.updateCounts);
+ }
+
+ public boolean validModel() {
+ double k = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+ @Override
+ public double apply(double v) {
+ return Double.isNaN(v) || Double.isInfinite(v) ? 1 : 0;
+ }
+ });
+ return k < 1;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
new file mode 100644
index 0000000..d00b021
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegression.java
@@ -0,0 +1,586 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.ep.EvolutionaryProcess;
+import org.apache.mahout.ep.Mapping;
+import org.apache.mahout.ep.Payload;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.stats.OnlineAuc;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.ExecutionException;
+
+/**
+ * This is a meta-learner that maintains a pool of ordinary
+ * {@link org.apache.mahout.classifier.sgd.OnlineLogisticRegression} learners. Each
+ * member of the pool has different learning rates. Whichever of the learners in the pool falls
+ * behind in terms of average log-likelihood will be tossed out and replaced with variants of the
+ * survivors. This will let us automatically derive an annealing schedule that optimizes learning
+ * speed. Since on-line learners tend to be IO bound anyway, it doesn't cost as much as it might
+ * seem that it would to maintain multiple learners in memory. Doing this adaptation on-line as we
+ * learn also decreases the number of learning rate parameters required and replaces the normal
+ * hyper-parameter search.
+ * <p/>
+ * One wrinkle is that the pool of learners that we maintain is actually a pool of
+ * {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} which themselves contain several OnlineLogisticRegression
+ * objects. These pools allow estimation
+ * of performance on the fly even if we make many passes through the data. This does, however,
+ * increase the cost of training since if we are using 5-fold cross-validation, each vector is used
+ * 4 times for training and once for classification. If this becomes a problem, then we should
+ * probably use a 2-way unbalanced train/test split rather than full cross validation. With the
+ * current default settings, we have 100 learners running. This is better than the alternative of
+ * running hundreds of training passes to find good hyper-parameters because we only have to parse
+ * and feature-ize our inputs once. If you already have good hyper-parameters, then you might
+ * prefer to just run one CrossFoldLearner with those settings.
+ * <p/>
+ * The fitness used here is AUC. Another alternative would be to try log-likelihood, but it is much
+ * easier to get bogus values of log-likelihood than with AUC and the results seem to accord pretty
+ * well. It would be nice to allow the fitness function to be pluggable. This use of AUC means that
+ * AdaptiveLogisticRegression is mostly suited for binary target variables. This will be fixed
+ * before long by extending OnlineAuc to handle non-binary cases or by using a different fitness
+ * value in non-binary cases.
+ */
+public class AdaptiveLogisticRegression implements OnlineLearner, Writable {
+ public static final int DEFAULT_THREAD_COUNT = 20;
+ public static final int DEFAULT_POOL_SIZE = 20;
+ private static final int SURVIVORS = 2;
+
+ private int record;
+ private int cutoff = 1000;
+ private int minInterval = 1000;
+ private int maxInterval = 1000;
+ private int currentStep = 1000;
+ private int bufferSize = 1000;
+
+ private List<TrainingExample> buffer = Lists.newArrayList();
+ private EvolutionaryProcess<Wrapper, CrossFoldLearner> ep;
+ private State<Wrapper, CrossFoldLearner> best;
+ private int threadCount = DEFAULT_THREAD_COUNT;
+ private int poolSize = DEFAULT_POOL_SIZE;
+ private State<Wrapper, CrossFoldLearner> seed;
+ private int numFeatures;
+
+ private boolean freezeSurvivors = true;
+
+ private static final Logger log = LoggerFactory.getLogger(AdaptiveLogisticRegression.class);
+
+ public AdaptiveLogisticRegression() {}
+
+ /**
+ * Uses {@link #DEFAULT_THREAD_COUNT} and {@link #DEFAULT_POOL_SIZE}
+ * @param numCategories The number of categories (labels) to train on
+ * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+ * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+ *
+ * @see #AdaptiveLogisticRegression(int, int, org.apache.mahout.classifier.sgd.PriorFunction, int, int)
+ */
+ public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior) {
+ this(numCategories, numFeatures, prior, DEFAULT_THREAD_COUNT, DEFAULT_POOL_SIZE);
+ }
+
+ /**
+ *
+ * @param numCategories The number of categories (labels) to train on
+ * @param numFeatures The number of features used in creating the vectors (i.e. the cardinality of the vector)
+ * @param prior The {@link org.apache.mahout.classifier.sgd.PriorFunction} to use
+ * @param threadCount The number of threads to use for training
+ * @param poolSize The number of {@link org.apache.mahout.classifier.sgd.CrossFoldLearner} to use.
+ */
+ public AdaptiveLogisticRegression(int numCategories, int numFeatures, PriorFunction prior, int threadCount,
+ int poolSize) {
+ this.numFeatures = numFeatures;
+ this.threadCount = threadCount;
+ this.poolSize = poolSize;
+ seed = new State<Wrapper, CrossFoldLearner>(new double[2], 10);
+ Wrapper w = new Wrapper(numCategories, numFeatures, prior);
+ seed.setPayload(w);
+
+ Wrapper.setMappings(seed);
+ seed.setPayload(w);
+ setPoolSize(this.poolSize);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ train(record, null, actual, instance);
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ train(trackingKey, null, actual, instance);
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ record++;
+
+ buffer.add(new TrainingExample(trackingKey, groupKey, actual, instance));
+ //don't train until we have enough examples
+ if (buffer.size() > bufferSize) {
+ trainWithBufferedExamples();
+ }
+ }
+
+ private void trainWithBufferedExamples() {
+ try {
+ this.best = ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+ @Override
+ public double apply(Payload<CrossFoldLearner> z, double[] params) {
+ Wrapper x = (Wrapper) z;
+ for (TrainingExample example : buffer) {
+ x.train(example);
+ }
+ if (x.getLearner().validModel()) {
+ if (x.getLearner().numCategories() == 2) {
+ return x.wrapped.auc();
+ } else {
+ return x.wrapped.logLikelihood();
+ }
+ } else {
+ return Double.NaN;
+ }
+ }
+ });
+ } catch (InterruptedException e) {
+ // ignore ... shouldn't happen
+ log.warn("Ignoring exception", e);
+ } catch (ExecutionException e) {
+ throw new IllegalStateException(e.getCause());
+ }
+ buffer.clear();
+
+ if (record > cutoff) {
+ cutoff = nextStep(record);
+
+ // evolve based on new fitness
+ ep.mutatePopulation(SURVIVORS);
+
+ if (freezeSurvivors) {
+ // now grossly hack the top survivors so they stick around. Set their
+ // mutation rates small and also hack their learning rate to be small
+ // as well.
+ for (State<Wrapper, CrossFoldLearner> state : ep.getPopulation().subList(0, SURVIVORS)) {
+ Wrapper.freeze(state);
+ }
+ }
+ }
+
+ }
+
+ public int nextStep(int recordNumber) {
+ int stepSize = stepSize(recordNumber, 2.6);
+ if (stepSize < minInterval) {
+ stepSize = minInterval;
+ }
+
+ if (stepSize > maxInterval) {
+ stepSize = maxInterval;
+ }
+
+ int newCutoff = stepSize * (recordNumber / stepSize + 1);
+ if (newCutoff < cutoff + currentStep) {
+ newCutoff = cutoff + currentStep;
+ } else {
+ this.currentStep = stepSize;
+ }
+ return newCutoff;
+ }
+
+ public static int stepSize(int recordNumber, double multiplier) {
+ int[] bumps = {1, 2, 5};
+ double log = Math.floor(multiplier * Math.log10(recordNumber));
+ int bump = bumps[(int) log % bumps.length];
+ int scale = (int) Math.pow(10, Math.floor(log / bumps.length));
+
+ return bump * scale;
+ }
+
+ @Override
+ public void close() {
+ trainWithBufferedExamples();
+ try {
+ ep.parallelDo(new EvolutionaryProcess.Function<Payload<CrossFoldLearner>>() {
+ @Override
+ public double apply(Payload<CrossFoldLearner> payload, double[] params) {
+ CrossFoldLearner learner = ((Wrapper) payload).getLearner();
+ learner.close();
+ return learner.logLikelihood();
+ }
+ });
+ } catch (InterruptedException e) {
+ log.warn("Ignoring exception", e);
+ } catch (ExecutionException e) {
+ throw new IllegalStateException(e);
+ } finally {
+ ep.close();
+ }
+ }
+
+ /**
+ * How often should the evolutionary optimization of learning parameters occur?
+ *
+ * @param interval Number of training examples to use in each epoch of optimization.
+ */
+ public void setInterval(int interval) {
+ setInterval(interval, interval);
+ }
+
+ /**
+ * Starts optimization using the shorter interval and progresses to the longer using the specified
+ * number of steps per decade. Note that values < 200 are not accepted. Values even that small
+ * are unlikely to be useful.
+ *
+ * @param minInterval The minimum epoch length for the evolutionary optimization
+ * @param maxInterval The maximum epoch length
+ */
+ public void setInterval(int minInterval, int maxInterval) {
+ this.minInterval = Math.max(200, minInterval);
+ this.maxInterval = Math.max(200, maxInterval);
+ this.cutoff = minInterval * (record / minInterval + 1);
+ this.currentStep = minInterval;
+ bufferSize = Math.min(minInterval, bufferSize);
+ }
+
+ public final void setPoolSize(int poolSize) {
+ this.poolSize = poolSize;
+ setupOptimizer(poolSize);
+ }
+
+ public void setThreadCount(int threadCount) {
+ this.threadCount = threadCount;
+ setupOptimizer(poolSize);
+ }
+
+ public void setAucEvaluator(OnlineAuc auc) {
+ seed.getPayload().setAucEvaluator(auc);
+ setupOptimizer(poolSize);
+ }
+
+ private void setupOptimizer(int poolSize) {
+ ep = new EvolutionaryProcess<Wrapper, CrossFoldLearner>(threadCount, poolSize, seed);
+ }
+
+ /**
+ * Returns the size of the internal feature vector. Note that this is not the same as the number
+ * of distinct features, especially if feature hashing is being used.
+ *
+ * @return The internal feature vector size.
+ */
+ public int numFeatures() {
+ return numFeatures;
+ }
+
+ /**
+ * What is the AUC for the current best member of the population. If no member is best, usually
+ * because we haven't done any training yet, then the result is set to NaN.
+ *
+ * @return The AUC of the best member of the population or NaN if we can't figure that out.
+ */
+ public double auc() {
+ if (best == null) {
+ return Double.NaN;
+ } else {
+ Wrapper payload = best.getPayload();
+ return payload.getLearner().auc();
+ }
+ }
+
+ public State<Wrapper, CrossFoldLearner> getBest() {
+ return best;
+ }
+
+ public void setBest(State<Wrapper, CrossFoldLearner> best) {
+ this.best = best;
+ }
+
+ public int getRecord() {
+ return record;
+ }
+
+ public void setRecord(int record) {
+ this.record = record;
+ }
+
+ public int getMinInterval() {
+ return minInterval;
+ }
+
+ public int getMaxInterval() {
+ return maxInterval;
+ }
+
+ public int getNumCategories() {
+ return seed.getPayload().getLearner().numCategories();
+ }
+
+ public PriorFunction getPrior() {
+ return seed.getPayload().getLearner().getPrior();
+ }
+
+ public void setBuffer(List<TrainingExample> buffer) {
+ this.buffer = buffer;
+ }
+
+ public List<TrainingExample> getBuffer() {
+ return buffer;
+ }
+
+ public EvolutionaryProcess<Wrapper, CrossFoldLearner> getEp() {
+ return ep;
+ }
+
+ public void setEp(EvolutionaryProcess<Wrapper, CrossFoldLearner> ep) {
+ this.ep = ep;
+ }
+
+ public State<Wrapper, CrossFoldLearner> getSeed() {
+ return seed;
+ }
+
+ public void setSeed(State<Wrapper, CrossFoldLearner> seed) {
+ this.seed = seed;
+ }
+
+ public int getNumFeatures() {
+ return numFeatures;
+ }
+
+ public void setAveragingWindow(int averagingWindow) {
+ seed.getPayload().getLearner().setWindowSize(averagingWindow);
+ setupOptimizer(poolSize);
+ }
+
+ public void setFreezeSurvivors(boolean freezeSurvivors) {
+ this.freezeSurvivors = freezeSurvivors;
+ }
+
+ /**
+ * Provides a shim between the EP optimization stuff and the CrossFoldLearner. The most important
+ * interface has to do with the parameters of the optimization. These are taken from the double[]
+ * params in the following order <ul> <li> regularization constant lambda <li> learningRate </ul>.
+ * All other parameters are set in such a way so as to defeat annealing to the extent possible.
+ * This lets the evolutionary algorithm handle the annealing.
+ * <p/>
+ * Note that per coefficient annealing is still done and no optimization of the per coefficient
+ * offset is done.
+ */
+ public static class Wrapper implements Payload<CrossFoldLearner> {
+ private CrossFoldLearner wrapped;
+
+ public Wrapper() {
+ }
+
+ public Wrapper(int numCategories, int numFeatures, PriorFunction prior) {
+ wrapped = new CrossFoldLearner(5, numCategories, numFeatures, prior);
+ }
+
+ @Override
+ public Wrapper copy() {
+ Wrapper r = new Wrapper();
+ r.wrapped = wrapped.copy();
+ return r;
+ }
+
+ @Override
+ public void update(double[] params) {
+ int i = 0;
+ wrapped.lambda(params[i++]);
+ wrapped.learningRate(params[i]);
+
+ wrapped.stepOffset(1);
+ wrapped.alpha(1);
+ wrapped.decayExponent(0);
+ }
+
+ public static void freeze(State<Wrapper, CrossFoldLearner> s) {
+ // radically decrease learning rate
+ double[] params = s.getParams();
+ params[1] -= 10;
+
+ // and cause evolution to hold (almost)
+ s.setOmni(s.getOmni() / 20);
+ double[] step = s.getStep();
+ for (int i = 0; i < step.length; i++) {
+ step[i] /= 20;
+ }
+ }
+
+ public static void setMappings(State<Wrapper, CrossFoldLearner> x) {
+ int i = 0;
+ // set the range for regularization (lambda)
+ x.setMap(i++, Mapping.logLimit(1.0e-8, 0.1));
+ // set the range for learning rate (mu)
+ x.setMap(i, Mapping.logLimit(1.0e-8, 1));
+ }
+
+ public void train(TrainingExample example) {
+ wrapped.train(example.getKey(), example.getGroupKey(), example.getActual(), example.getInstance());
+ }
+
+ public CrossFoldLearner getLearner() {
+ return wrapped;
+ }
+
+ @Override
+ public String toString() {
+ return String.format(Locale.ENGLISH, "auc=%.2f", wrapped.auc());
+ }
+
+ public void setAucEvaluator(OnlineAuc auc) {
+ wrapped.setAucEvaluator(auc);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ wrapped.write(out);
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ wrapped = new CrossFoldLearner();
+ wrapped.readFields(input);
+ }
+ }
+
+ public static class TrainingExample implements Writable {
+ private long key;
+ private String groupKey;
+ private int actual;
+ private Vector instance;
+
+ private TrainingExample() {
+ }
+
+ public TrainingExample(long key, String groupKey, int actual, Vector instance) {
+ this.key = key;
+ this.groupKey = groupKey;
+ this.actual = actual;
+ this.instance = instance;
+ }
+
+ public long getKey() {
+ return key;
+ }
+
+ public int getActual() {
+ return actual;
+ }
+
+ public Vector getInstance() {
+ return instance;
+ }
+
+ public String getGroupKey() {
+ return groupKey;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeLong(key);
+ if (groupKey != null) {
+ out.writeBoolean(true);
+ out.writeUTF(groupKey);
+ } else {
+ out.writeBoolean(false);
+ }
+ out.writeInt(actual);
+ VectorWritable.writeVector(out, instance, true);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ key = in.readLong();
+ if (in.readBoolean()) {
+ groupKey = in.readUTF();
+ }
+ actual = in.readInt();
+ instance = VectorWritable.readVector(in);
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(record);
+ out.writeInt(cutoff);
+ out.writeInt(minInterval);
+ out.writeInt(maxInterval);
+ out.writeInt(currentStep);
+ out.writeInt(bufferSize);
+
+ out.writeInt(buffer.size());
+ for (TrainingExample example : buffer) {
+ example.write(out);
+ }
+
+ ep.write(out);
+
+ best.write(out);
+
+ out.writeInt(threadCount);
+ out.writeInt(poolSize);
+ seed.write(out);
+ out.writeInt(numFeatures);
+
+ out.writeBoolean(freezeSurvivors);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ record = in.readInt();
+ cutoff = in.readInt();
+ minInterval = in.readInt();
+ maxInterval = in.readInt();
+ currentStep = in.readInt();
+ bufferSize = in.readInt();
+
+ int n = in.readInt();
+ buffer = Lists.newArrayList();
+ for (int i = 0; i < n; i++) {
+ TrainingExample example = new TrainingExample();
+ example.readFields(in);
+ buffer.add(example);
+ }
+
+ ep = new EvolutionaryProcess<Wrapper, CrossFoldLearner>();
+ ep.readFields(in);
+
+ best = new State<Wrapper, CrossFoldLearner>();
+ best.readFields(in);
+
+ threadCount = in.readInt();
+ poolSize = in.readInt();
+ seed = new State<Wrapper, CrossFoldLearner>();
+ seed.readFields(in);
+
+ numFeatures = in.readInt();
+ freezeSurvivors = in.readBoolean();
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
new file mode 100644
index 0000000..36bcae0
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/CrossFoldLearner.java
@@ -0,0 +1,334 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleDoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.stats.GlobalOnlineAuc;
+import org.apache.mahout.math.stats.OnlineAuc;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Does cross-fold validation of log-likelihood and AUC on several online logistic regression
+ * models. Each record is passed to all but one of the models for training and to the remaining
+ * model for evaluation. In order to maintain proper segregation between the different folds across
+ * training data iterations, data should either be passed to this learner in the same order each
+ * time the training data is traversed or a tracking key such as the file offset of the training
+ * record should be passed with each training example.
+ */
+public class CrossFoldLearner extends AbstractVectorClassifier implements OnlineLearner, Writable {
+ private int record;
+ // minimum score to be used for computing log likelihood
+ private static final double MIN_SCORE = 1.0e-50;
+ private OnlineAuc auc = new GlobalOnlineAuc();
+ private double logLikelihood;
+ private final List<OnlineLogisticRegression> models = Lists.newArrayList();
+
+ // lambda, learningRate, perTermOffset, perTermExponent
+ private double[] parameters = new double[4];
+ private int numFeatures;
+ private PriorFunction prior;
+ private double percentCorrect;
+
+ private int windowSize = Integer.MAX_VALUE;
+
+ public CrossFoldLearner() {
+ }
+
+ public CrossFoldLearner(int folds, int numCategories, int numFeatures, PriorFunction prior) {
+ this.numFeatures = numFeatures;
+ this.prior = prior;
+ for (int i = 0; i < folds; i++) {
+ OnlineLogisticRegression model = new OnlineLogisticRegression(numCategories, numFeatures, prior);
+ model.alpha(1).stepOffset(0).decayExponent(0);
+ models.add(model);
+ }
+ }
+
+ // -------- builder-like configuration methods
+
+ public CrossFoldLearner lambda(double v) {
+ for (OnlineLogisticRegression model : models) {
+ model.lambda(v);
+ }
+ return this;
+ }
+
+ public CrossFoldLearner learningRate(double x) {
+ for (OnlineLogisticRegression model : models) {
+ model.learningRate(x);
+ }
+ return this;
+ }
+
+ public CrossFoldLearner stepOffset(int x) {
+ for (OnlineLogisticRegression model : models) {
+ model.stepOffset(x);
+ }
+ return this;
+ }
+
+ public CrossFoldLearner decayExponent(double x) {
+ for (OnlineLogisticRegression model : models) {
+ model.decayExponent(x);
+ }
+ return this;
+ }
+
+ public CrossFoldLearner alpha(double alpha) {
+ for (OnlineLogisticRegression model : models) {
+ model.alpha(alpha);
+ }
+ return this;
+ }
+
+ // -------- training methods
+ @Override
+ public void train(int actual, Vector instance) {
+ train(record, null, actual, instance);
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ train(trackingKey, null, actual, instance);
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ record++;
+ int k = 0;
+ for (OnlineLogisticRegression model : models) {
+ if (k == mod(trackingKey, models.size())) {
+ Vector v = model.classifyFull(instance);
+ double score = Math.max(v.get(actual), MIN_SCORE);
+ logLikelihood += (Math.log(score) - logLikelihood) / Math.min(record, windowSize);
+
+ int correct = v.maxValueIndex() == actual ? 1 : 0;
+ percentCorrect += (correct - percentCorrect) / Math.min(record, windowSize);
+ if (numCategories() == 2) {
+ auc.addSample(actual, groupKey, v.get(1));
+ }
+ } else {
+ model.train(trackingKey, groupKey, actual, instance);
+ }
+ k++;
+ }
+ }
+
+ private static long mod(long x, int y) {
+ long r = x % y;
+ return r < 0 ? r + y : r;
+ }
+
+ @Override
+ public void close() {
+ for (OnlineLogisticRegression m : models) {
+ m.close();
+ }
+ }
+
+ public void resetLineCounter() {
+ record = 0;
+ }
+
+ public boolean validModel() {
+ boolean r = true;
+ for (OnlineLogisticRegression model : models) {
+ r &= model.validModel();
+ }
+ return r;
+ }
+
+ // -------- classification methods
+
+ @Override
+ public Vector classify(Vector instance) {
+ Vector r = new DenseVector(numCategories() - 1);
+ DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
+ for (OnlineLogisticRegression model : models) {
+ r.assign(model.classify(instance), scale);
+ }
+ return r;
+ }
+
+ @Override
+ public Vector classifyNoLink(Vector instance) {
+ Vector r = new DenseVector(numCategories() - 1);
+ DoubleDoubleFunction scale = Functions.plusMult(1.0 / models.size());
+ for (OnlineLogisticRegression model : models) {
+ r.assign(model.classifyNoLink(instance), scale);
+ }
+ return r;
+ }
+
+ @Override
+ public double classifyScalar(Vector instance) {
+ double r = 0;
+ int n = 0;
+ for (OnlineLogisticRegression model : models) {
+ n++;
+ r += model.classifyScalar(instance);
+ }
+ return r / n;
+ }
+
+ // -------- status reporting methods
+
+ @Override
+ public int numCategories() {
+ return models.get(0).numCategories();
+ }
+
+ public double auc() {
+ return auc.auc();
+ }
+
+ public double logLikelihood() {
+ return logLikelihood;
+ }
+
+ public double percentCorrect() {
+ return percentCorrect;
+ }
+
+ // -------- evolutionary optimization
+
+ public CrossFoldLearner copy() {
+ CrossFoldLearner r = new CrossFoldLearner(models.size(), numCategories(), numFeatures, prior);
+ r.models.clear();
+ for (OnlineLogisticRegression model : models) {
+ model.close();
+ OnlineLogisticRegression newModel =
+ new OnlineLogisticRegression(model.numCategories(), model.numFeatures(), model.prior);
+ newModel.copyFrom(model);
+ r.models.add(newModel);
+ }
+ return r;
+ }
+
+ public int getRecord() {
+ return record;
+ }
+
+ public void setRecord(int record) {
+ this.record = record;
+ }
+
+ public OnlineAuc getAucEvaluator() {
+ return auc;
+ }
+
+ public void setAucEvaluator(OnlineAuc auc) {
+ this.auc = auc;
+ }
+
+ public double getLogLikelihood() {
+ return logLikelihood;
+ }
+
+ public void setLogLikelihood(double logLikelihood) {
+ this.logLikelihood = logLikelihood;
+ }
+
+ public List<OnlineLogisticRegression> getModels() {
+ return models;
+ }
+
+ public void addModel(OnlineLogisticRegression model) {
+ models.add(model);
+ }
+
+ public double[] getParameters() {
+ return parameters;
+ }
+
+ public void setParameters(double[] parameters) {
+ this.parameters = parameters;
+ }
+
+ public int getNumFeatures() {
+ return numFeatures;
+ }
+
+ public void setNumFeatures(int numFeatures) {
+ this.numFeatures = numFeatures;
+ }
+
+ public void setWindowSize(int windowSize) {
+ this.windowSize = windowSize;
+ auc.setWindowSize(windowSize);
+ }
+
+ public PriorFunction getPrior() {
+ return prior;
+ }
+
+ public void setPrior(PriorFunction prior) {
+ this.prior = prior;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(record);
+ PolymorphicWritable.write(out, auc);
+ out.writeDouble(logLikelihood);
+ out.writeInt(models.size());
+ for (OnlineLogisticRegression model : models) {
+ model.write(out);
+ }
+
+ for (double x : parameters) {
+ out.writeDouble(x);
+ }
+ out.writeInt(numFeatures);
+ PolymorphicWritable.write(out, prior);
+ out.writeDouble(percentCorrect);
+ out.writeInt(windowSize);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ record = in.readInt();
+ auc = PolymorphicWritable.read(in, OnlineAuc.class);
+ logLikelihood = in.readDouble();
+ int n = in.readInt();
+ for (int i = 0; i < n; i++) {
+ OnlineLogisticRegression olr = new OnlineLogisticRegression();
+ olr.readFields(in);
+ models.add(olr);
+ }
+ parameters = new double[4];
+ for (int i = 0; i < 4; i++) {
+ parameters[i] = in.readDouble();
+ }
+ numFeatures = in.readInt();
+ prior = PolymorphicWritable.read(in, PriorFunction.class);
+ percentCorrect = in.readDouble();
+ windowSize = in.readInt();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
new file mode 100644
index 0000000..b21860f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/CsvRecordFactory.java
@@ -0,0 +1,393 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Collections2;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import org.apache.commons.csv.CSVUtils;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.ConstantValueEncoder;
+import org.apache.mahout.vectorizer.encoders.ContinuousValueEncoder;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.apache.mahout.vectorizer.encoders.FeatureVectorEncoder;
+import org.apache.mahout.vectorizer.encoders.StaticWordValueEncoder;
+import org.apache.mahout.vectorizer.encoders.TextValueEncoder;
+
+import java.io.IOException;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Converts CSV data lines to vectors.
+ *
+ * Use of this class proceeds in a few steps.
+ * <ul>
+ * <li> At construction time, you tell the class about the target variable and provide
+ * a dictionary of the types of the predictor values. At this point,
+ * the class yet cannot decode inputs because it doesn't know the fields that are in the
+ * data records, nor their order.
+ * <li> Optionally, you tell the parser object about the possible values of the target
+ * variable. If you don't do this then you probably should set the number of distinct
+ * values so that the target variable values will be taken from a restricted range.
+ * <li> Later, when you get a list of the fields, typically from the first line of a CSV
+ * file, you tell the factory about these fields and it builds internal data structures
+ * that allow it to decode inputs. The most important internal state is the field numbers
+ * for various fields. After this point, you can use the factory for decoding data.
+ * <li> To encode data as a vector, you present a line of input to the factory and it
+ * mutates a vector that you provide. The factory also retains trace information so
+ * that it can approximately reverse engineer vectors later.
+ * <li> After converting data, you can ask for an explanation of the data in terms of
+ * terms and weights. In order to explain a vector accurately, the factory needs to
+ * have seen the particular values of categorical fields (typically during encoding vectors)
+ * and needs to have a reasonably small number of collisions in the vector encoding.
+ * </ul>
+ */
+public class CsvRecordFactory implements RecordFactory {
+ private static final String INTERCEPT_TERM = "Intercept Term";
+
+ private static final Map<String, Class<? extends FeatureVectorEncoder>> TYPE_DICTIONARY =
+ ImmutableMap.<String, Class<? extends FeatureVectorEncoder>>builder()
+ .put("continuous", ContinuousValueEncoder.class)
+ .put("numeric", ContinuousValueEncoder.class)
+ .put("n", ContinuousValueEncoder.class)
+ .put("word", StaticWordValueEncoder.class)
+ .put("w", StaticWordValueEncoder.class)
+ .put("text", TextValueEncoder.class)
+ .put("t", TextValueEncoder.class)
+ .build();
+
+ private final Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
+
+ private int target;
+ private final Dictionary targetDictionary;
+
+ //Which column is used for identify a CSV file line
+ private String idName;
+ private int id = -1;
+
+ private List<Integer> predictors;
+ private Map<Integer, FeatureVectorEncoder> predictorEncoders;
+ private int maxTargetValue = Integer.MAX_VALUE;
+ private final String targetName;
+ private final Map<String, String> typeMap;
+ private List<String> variableNames;
+ private boolean includeBiasTerm;
+ private static final String CANNOT_CONSTRUCT_CONVERTER =
+ "Unable to construct type converter... shouldn't be possible";
+
+ /**
+ * Parse a single line of CSV-formatted text.
+ *
+ * Separated to make changing this functionality for the entire class easier
+ * in the future.
+ * @param line - CSV formatted text
+ * @return List<String>
+ */
+ private List<String> parseCsvLine(String line) {
+ try {
+ return Arrays.asList(CSVUtils.parseLine(line));
+ }
+ catch (IOException e) {
+ List<String> list = Lists.newArrayList();
+ list.add(line);
+ return list;
+ }
+ }
+
+ private List<String> parseCsvLine(CharSequence line) {
+ return parseCsvLine(line.toString());
+ }
+
+ /**
+ * Construct a parser for CSV lines that encodes the parsed data in vector form.
+ * @param targetName The name of the target variable.
+ * @param typeMap A map describing the types of the predictor variables.
+ */
+ public CsvRecordFactory(String targetName, Map<String, String> typeMap) {
+ this.targetName = targetName;
+ this.typeMap = typeMap;
+ targetDictionary = new Dictionary();
+ }
+
+ public CsvRecordFactory(String targetName, String idName, Map<String, String> typeMap) {
+ this(targetName, typeMap);
+ this.idName = idName;
+ }
+
+ /**
+ * Defines the values and thus the encoding of values of the target variables. Note
+ * that any values of the target variable not present in this list will be given the
+ * value of the last member of the list.
+ * @param values The values the target variable can have.
+ */
+ @Override
+ public void defineTargetCategories(List<String> values) {
+ Preconditions.checkArgument(
+ values.size() <= maxTargetValue,
+ "Must have less than or equal to " + maxTargetValue + " categories for target variable, but found "
+ + values.size());
+ if (maxTargetValue == Integer.MAX_VALUE) {
+ maxTargetValue = values.size();
+ }
+
+ for (String value : values) {
+ targetDictionary.intern(value);
+ }
+ }
+
+ /**
+ * Defines the number of target variable categories, but allows this parser to
+ * pick encodings for them as they appear.
+ * @param max The number of categories that will be expected. Once this many have been
+ * seen, all others will get the encoding max-1.
+ */
+ @Override
+ public CsvRecordFactory maxTargetValue(int max) {
+ maxTargetValue = max;
+ return this;
+ }
+
+ @Override
+ public boolean usesFirstLineAsSchema() {
+ return true;
+ }
+
+ /**
+ * Processes the first line of a file (which should contain the variable names). The target and
+ * predictor column numbers are set from the names on this line.
+ *
+ * @param line Header line for the file.
+ */
+ @Override
+ public void firstLine(String line) {
+ // read variable names, build map of name -> column
+ final Map<String, Integer> vars = Maps.newHashMap();
+ variableNames = parseCsvLine(line);
+ int column = 0;
+ for (String var : variableNames) {
+ vars.put(var, column++);
+ }
+
+ // record target column and establish dictionary for decoding target
+ target = vars.get(targetName);
+
+ // record id column
+ if (idName != null) {
+ id = vars.get(idName);
+ }
+
+ // create list of predictor column numbers
+ predictors = Lists.newArrayList(Collections2.transform(typeMap.keySet(), new Function<String, Integer>() {
+ @Override
+ public Integer apply(String from) {
+ Integer r = vars.get(from);
+ Preconditions.checkArgument(r != null, "Can't find variable %s, only know about %s", from, vars);
+ return r;
+ }
+ }));
+
+ if (includeBiasTerm) {
+ predictors.add(-1);
+ }
+ Collections.sort(predictors);
+
+ // and map from column number to type encoder for each column that is a predictor
+ predictorEncoders = Maps.newHashMap();
+ for (Integer predictor : predictors) {
+ String name;
+ Class<? extends FeatureVectorEncoder> c;
+ if (predictor == -1) {
+ name = INTERCEPT_TERM;
+ c = ConstantValueEncoder.class;
+ } else {
+ name = variableNames.get(predictor);
+ c = TYPE_DICTIONARY.get(typeMap.get(name));
+ }
+ try {
+ Preconditions.checkArgument(c != null, "Invalid type of variable %s, wanted one of %s",
+ typeMap.get(name), TYPE_DICTIONARY.keySet());
+ Constructor<? extends FeatureVectorEncoder> constructor = c.getConstructor(String.class);
+ Preconditions.checkArgument(constructor != null, "Can't find correct constructor for %s", typeMap.get(name));
+ FeatureVectorEncoder encoder = constructor.newInstance(name);
+ predictorEncoders.put(predictor, encoder);
+ encoder.setTraceDictionary(traceDictionary);
+ } catch (InstantiationException e) {
+ throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+ } catch (IllegalAccessException e) {
+ throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+ } catch (InvocationTargetException e) {
+ throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+ } catch (NoSuchMethodException e) {
+ throw new IllegalStateException(CANNOT_CONSTRUCT_CONVERTER, e);
+ }
+ }
+ }
+
+
+ /**
+ * Decodes a single line of CSV data and records the target and predictor variables in a record.
+ * As a side effect, features are added into the featureVector. Returns the value of the target
+ * variable.
+ *
+ * @param line The raw data.
+ * @param featureVector Where to fill in the features. Should be zeroed before calling
+ * processLine.
+ * @return The value of the target variable.
+ */
+ @Override
+ public int processLine(String line, Vector featureVector) {
+ List<String> values = parseCsvLine(line);
+
+ int targetValue = targetDictionary.intern(values.get(target));
+ if (targetValue >= maxTargetValue) {
+ targetValue = maxTargetValue - 1;
+ }
+
+ for (Integer predictor : predictors) {
+ String value;
+ if (predictor >= 0) {
+ value = values.get(predictor);
+ } else {
+ value = null;
+ }
+ predictorEncoders.get(predictor).addToVector(value, featureVector);
+ }
+ return targetValue;
+ }
+
+ /***
+ * Decodes a single line of CSV data and records the target(if retrunTarget is true)
+ * and predictor variables in a record. As a side effect, features are added into the featureVector.
+ * Returns the value of the target variable. When used during classify against production data without
+ * target value, the method will be called with returnTarget = false.
+ * @param line The raw data.
+ * @param featureVector Where to fill in the features. Should be zeroed before calling
+ * processLine.
+ * @param returnTarget whether process and return target value, -1 will be returned if false.
+ * @return The value of the target variable.
+ */
+ public int processLine(CharSequence line, Vector featureVector, boolean returnTarget) {
+ List<String> values = parseCsvLine(line);
+ int targetValue = -1;
+ if (returnTarget) {
+ targetValue = targetDictionary.intern(values.get(target));
+ if (targetValue >= maxTargetValue) {
+ targetValue = maxTargetValue - 1;
+ }
+ }
+
+ for (Integer predictor : predictors) {
+ String value = predictor >= 0 ? values.get(predictor) : null;
+ predictorEncoders.get(predictor).addToVector(value, featureVector);
+ }
+ return targetValue;
+ }
+
+ /***
+ * Extract the raw target string from a line read from a CSV file.
+ * @param line the line of content read from CSV file
+ * @return the raw target value in the corresponding column of CSV line
+ */
+ public String getTargetString(CharSequence line) {
+ List<String> values = parseCsvLine(line);
+ return values.get(target);
+
+ }
+
+ /***
+ * Extract the corresponding raw target label according to a code
+ * @param code the integer code encoded during training process
+ * @return the raw target label
+ */
+ public String getTargetLabel(int code) {
+ for (String key : targetDictionary.values()) {
+ if (targetDictionary.intern(key) == code) {
+ return key;
+ }
+ }
+ return null;
+ }
+
+ /***
+ * Extract the id column value from the CSV record
+ * @param line the line of content read from CSV file
+ * @return the id value of the CSV record
+ */
+ public String getIdString(CharSequence line) {
+ List<String> values = parseCsvLine(line);
+ return values.get(id);
+ }
+
+ /**
+ * Returns a list of the names of the predictor variables.
+ *
+ * @return A list of variable names.
+ */
+ @Override
+ public Iterable<String> getPredictors() {
+ return Lists.transform(predictors, new Function<Integer, String>() {
+ @Override
+ public String apply(Integer v) {
+ if (v >= 0) {
+ return variableNames.get(v);
+ } else {
+ return INTERCEPT_TERM;
+ }
+ }
+ });
+ }
+
+ @Override
+ public Map<String, Set<Integer>> getTraceDictionary() {
+ return traceDictionary;
+ }
+
+ @Override
+ public CsvRecordFactory includeBiasTerm(boolean useBias) {
+ includeBiasTerm = useBias;
+ return this;
+ }
+
+ @Override
+ public List<String> getTargetCategories() {
+ List<String> r = targetDictionary.values();
+ if (r.size() > maxTargetValue) {
+ r.subList(maxTargetValue, r.size()).clear();
+ }
+ return r;
+ }
+
+ public String getIdName() {
+ return idName;
+ }
+
+ public void setIdName(String idName) {
+ this.idName = idName;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
new file mode 100644
index 0000000..f81d8ce
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/DefaultGradient.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Implements the basic logistic training law.
+ */
+public class DefaultGradient implements Gradient {
+ /**
+ * Provides a default gradient computation useful for logistic regression.
+ *
+ * @param groupKey A grouping key to allow per-something AUC loss to be used for training.
+ * @param actual The target variable value.
+ * @param instance The current feature vector to use for gradient computation
+ * @param classifier The classifier that can compute scores
+ * @return The gradient to be applied to beta
+ */
+ @Override
+ public final Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+ // what does the current model say?
+ Vector v = classifier.classify(instance);
+
+ Vector r = v.like();
+ if (actual != 0) {
+ r.setQuick(actual - 1, 1);
+ }
+ r.assign(v, Functions.MINUS);
+ return r;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
new file mode 100644
index 0000000..8128370
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/ElasticBandPrior.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements a linear combination of L1 and L2 priors. This can give an
+ * interesting mixture of sparsity and load-sharing between redundant predictors.
+ */
+public class ElasticBandPrior implements PriorFunction {
+ private double alphaByLambda;
+ private L1 l1;
+ private L2 l2;
+
+ // Exists for Writable
+ public ElasticBandPrior() {
+ this(0.0);
+ }
+
+ public ElasticBandPrior(double alphaByLambda) {
+ this.alphaByLambda = alphaByLambda;
+ l1 = new L1();
+ l2 = new L2(1);
+ }
+
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ oldValue *= Math.pow(1 - alphaByLambda * learningRate, generations);
+ double newValue = oldValue - Math.signum(oldValue) * learningRate * generations;
+ if (newValue * oldValue < 0.0) {
+ // don't allow the value to change sign
+ return 0.0;
+ } else {
+ return newValue;
+ }
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return l1.logP(betaIJ) + alphaByLambda * l2.logP(betaIJ);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(alphaByLambda);
+ l1.write(out);
+ l2.write(out);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ alphaByLambda = in.readDouble();
+ l1 = new L1();
+ l1.readFields(in);
+ l2 = new L2();
+ l2.readFields(in);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
new file mode 100644
index 0000000..524fc06
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/Gradient.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Provides the ability to inject a gradient into the SGD logistic regresion.
+ * Typical uses of this are to use a ranking score such as AUC instead of a
+ * normal loss function.
+ */
+public interface Gradient {
+ Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
new file mode 100644
index 0000000..d158f4d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/GradientMachine.java
@@ -0,0 +1,405 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.Sets;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Random;
+
+/**
+ * Online gradient machine learner that tries to minimize the label ranking hinge loss.
+ * Implements a gradient machine with one sigmpod hidden layer.
+ * It tries to minimize the ranking loss of some given set of labels,
+ * so this can be used for multi-class, multi-label
+ * or auto-encoding of sparse data (e.g. text).
+ */
+public class GradientMachine extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+ public static final int WRITABLE_VERSION = 1;
+
+ // the learning rate of the algorithm
+ private double learningRate = 0.1;
+
+ // the regularization term, a positive number that controls the size of the weight vector
+ private double regularization = 0.1;
+
+ // the sparsity term, a positive number that controls the sparsity of the hidden layer. (0 - 1)
+ private double sparsity = 0.1;
+
+ // the sparsity learning rate.
+ private double sparsityLearningRate = 0.1;
+
+ // the number of features
+ private int numFeatures = 10;
+ // the number of hidden nodes
+ private int numHidden = 100;
+ // the number of output nodes
+ private int numOutput = 2;
+
+ // coefficients for the input to hidden layer.
+ // There are numHidden Vectors of dimension numFeatures.
+ private Vector[] hiddenWeights;
+
+ // coefficients for the hidden to output layer.
+ // There are numOuput Vectors of dimension numHidden.
+ private Vector[] outputWeights;
+
+ // hidden unit bias
+ private Vector hiddenBias;
+
+ // output unit bias
+ private Vector outputBias;
+
+ private final Random rnd;
+
+ public GradientMachine(int numFeatures, int numHidden, int numOutput) {
+ this.numFeatures = numFeatures;
+ this.numHidden = numHidden;
+ this.numOutput = numOutput;
+ hiddenWeights = new DenseVector[numHidden];
+ for (int i = 0; i < numHidden; i++) {
+ hiddenWeights[i] = new DenseVector(numFeatures);
+ hiddenWeights[i].assign(0);
+ }
+ hiddenBias = new DenseVector(numHidden);
+ hiddenBias.assign(0);
+ outputWeights = new DenseVector[numOutput];
+ for (int i = 0; i < numOutput; i++) {
+ outputWeights[i] = new DenseVector(numHidden);
+ outputWeights[i].assign(0);
+ }
+ outputBias = new DenseVector(numOutput);
+ outputBias.assign(0);
+ rnd = RandomUtils.getRandom();
+ }
+
+ /**
+ * Initialize weights.
+ *
+ * @param gen random number generator.
+ */
+ public void initWeights(Random gen) {
+ double hiddenFanIn = 1.0 / Math.sqrt(numFeatures);
+ for (int i = 0; i < numHidden; i++) {
+ for (int j = 0; j < numFeatures; j++) {
+ double val = (2.0 * gen.nextDouble() - 1.0) * hiddenFanIn;
+ hiddenWeights[i].setQuick(j, val);
+ }
+ }
+ double outputFanIn = 1.0 / Math.sqrt(numHidden);
+ for (int i = 0; i < numOutput; i++) {
+ for (int j = 0; j < numHidden; j++) {
+ double val = (2.0 * gen.nextDouble() - 1.0) * outputFanIn;
+ outputWeights[i].setQuick(j, val);
+ }
+ }
+ }
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param learningRate New value of initial learning rate.
+ * @return This, so other configurations can be chained.
+ */
+ public GradientMachine learningRate(double learningRate) {
+ this.learningRate = learningRate;
+ return this;
+ }
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param regularization A positive value that controls the weight vector size.
+ * @return This, so other configurations can be chained.
+ */
+ public GradientMachine regularization(double regularization) {
+ this.regularization = regularization;
+ return this;
+ }
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param sparsity A value between zero and one that controls the fraction of hidden units
+ * that are activated on average.
+ * @return This, so other configurations can be chained.
+ */
+ public GradientMachine sparsity(double sparsity) {
+ this.sparsity = sparsity;
+ return this;
+ }
+
+ /**
+ * Chainable configuration option.
+ *
+ * @param sparsityLearningRate New value of initial learning rate for sparsity.
+ * @return This, so other configurations can be chained.
+ */
+ public GradientMachine sparsityLearningRate(double sparsityLearningRate) {
+ this.sparsityLearningRate = sparsityLearningRate;
+ return this;
+ }
+
+ public void copyFrom(GradientMachine other) {
+ numFeatures = other.numFeatures;
+ numHidden = other.numHidden;
+ numOutput = other.numOutput;
+ learningRate = other.learningRate;
+ regularization = other.regularization;
+ sparsity = other.sparsity;
+ sparsityLearningRate = other.sparsityLearningRate;
+ hiddenWeights = new DenseVector[numHidden];
+ for (int i = 0; i < numHidden; i++) {
+ hiddenWeights[i] = other.hiddenWeights[i].clone();
+ }
+ hiddenBias = other.hiddenBias.clone();
+ outputWeights = new DenseVector[numOutput];
+ for (int i = 0; i < numOutput; i++) {
+ outputWeights[i] = other.outputWeights[i].clone();
+ }
+ outputBias = other.outputBias.clone();
+ }
+
+ @Override
+ public int numCategories() {
+ return numOutput;
+ }
+
+ public int numFeatures() {
+ return numFeatures;
+ }
+
+ public int numHidden() {
+ return numHidden;
+ }
+
+ /**
+ * Feeds forward from input to hidden unit..
+ *
+ * @return Hidden unit activations.
+ */
+ public DenseVector inputToHidden(Vector input) {
+ DenseVector activations = new DenseVector(numHidden);
+ for (int i = 0; i < numHidden; i++) {
+ activations.setQuick(i, hiddenWeights[i].dot(input));
+ }
+ activations.assign(hiddenBias, Functions.PLUS);
+ activations.assign(Functions.min(40.0)).assign(Functions.max(-40));
+ activations.assign(Functions.SIGMOID);
+ return activations;
+ }
+
+ /**
+ * Feeds forward from hidden to output
+ *
+ * @return Output unit activations.
+ */
+ public DenseVector hiddenToOutput(Vector hiddenActivation) {
+ DenseVector activations = new DenseVector(numOutput);
+ for (int i = 0; i < numOutput; i++) {
+ activations.setQuick(i, outputWeights[i].dot(hiddenActivation));
+ }
+ activations.assign(outputBias, Functions.PLUS);
+ return activations;
+ }
+
+ /**
+ * Updates using ranking loss.
+ *
+ * @param hiddenActivation the hidden unit's activation
+ * @param goodLabels the labels you want ranked above others.
+ * @param numTrials how many times you want to search for the highest scoring bad label.
+ * @param gen Random number generator.
+ */
+ public void updateRanking(Vector hiddenActivation,
+ Collection<Integer> goodLabels,
+ int numTrials,
+ Random gen) {
+ // All the labels are good, do nothing.
+ if (goodLabels.size() >= numOutput) {
+ return;
+ }
+ for (Integer good : goodLabels) {
+ double goodScore = outputWeights[good].dot(hiddenActivation);
+ int highestBad = -1;
+ double highestBadScore = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < numTrials; i++) {
+ int bad = gen.nextInt(numOutput);
+ while (goodLabels.contains(bad)) {
+ bad = gen.nextInt(numOutput);
+ }
+ double badScore = outputWeights[bad].dot(hiddenActivation);
+ if (badScore > highestBadScore) {
+ highestBadScore = badScore;
+ highestBad = bad;
+ }
+ }
+ int bad = highestBad;
+ double loss = 1.0 - goodScore + highestBadScore;
+ if (loss < 0.0) {
+ continue;
+ }
+ // Note from the loss above the gradient dloss/dy , y being the label is -1 for good
+ // and +1 for bad.
+ // dy / dw is just w since y = x' * w + b.
+ // Hence by the chain rule, dloss / dw = dloss / dy * dy / dw = -w.
+ // For the regularization part, 0.5 * lambda * w' w, the gradient is lambda * w.
+ // dy / db = 1.
+ Vector gradGood = outputWeights[good].clone();
+ gradGood.assign(Functions.NEGATE);
+ Vector propHidden = gradGood.clone();
+ Vector gradBad = outputWeights[bad].clone();
+ propHidden.assign(gradBad, Functions.PLUS);
+ gradGood.assign(Functions.mult(-learningRate * (1.0 - regularization)));
+ outputWeights[good].assign(gradGood, Functions.PLUS);
+ gradBad.assign(Functions.mult(-learningRate * (1.0 + regularization)));
+ outputWeights[bad].assign(gradBad, Functions.PLUS);
+ outputBias.setQuick(good, outputBias.get(good) + learningRate);
+ outputBias.setQuick(bad, outputBias.get(bad) - learningRate);
+ // Gradient of sigmoid is s * (1 -s).
+ Vector gradSig = hiddenActivation.clone();
+ gradSig.assign(Functions.SIGMOIDGRADIENT);
+ // Multiply by the change caused by the ranking loss.
+ for (int i = 0; i < numHidden; i++) {
+ gradSig.setQuick(i, gradSig.get(i) * propHidden.get(i));
+ }
+ for (int i = 0; i < numHidden; i++) {
+ for (int j = 0; j < numFeatures; j++) {
+ double v = hiddenWeights[i].get(j);
+ v -= learningRate * (gradSig.get(i) + regularization * v);
+ hiddenWeights[i].setQuick(j, v);
+ }
+ }
+ }
+ }
+
+ @Override
+ public Vector classify(Vector instance) {
+ Vector result = classifyNoLink(instance);
+ // Find the max value's index.
+ int max = result.maxValueIndex();
+ result.assign(0);
+ result.setQuick(max, 1.0);
+ return result.viewPart(1, result.size() - 1);
+ }
+
+ @Override
+ public Vector classifyNoLink(Vector instance) {
+ DenseVector hidden = inputToHidden(instance);
+ return hiddenToOutput(hidden);
+ }
+
+ @Override
+ public double classifyScalar(Vector instance) {
+ Vector output = classifyNoLink(instance);
+ if (output.get(0) > output.get(1)) {
+ return 0;
+ }
+ return 1;
+ }
+
+ public GradientMachine copy() {
+ close();
+ GradientMachine r = new GradientMachine(numFeatures(), numHidden(), numCategories());
+ r.copyFrom(this);
+ return r;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(WRITABLE_VERSION);
+ out.writeDouble(learningRate);
+ out.writeDouble(regularization);
+ out.writeDouble(sparsity);
+ out.writeDouble(sparsityLearningRate);
+ out.writeInt(numFeatures);
+ out.writeInt(numHidden);
+ out.writeInt(numOutput);
+ VectorWritable.writeVector(out, hiddenBias);
+ for (int i = 0; i < numHidden; i++) {
+ VectorWritable.writeVector(out, hiddenWeights[i]);
+ }
+ VectorWritable.writeVector(out, outputBias);
+ for (int i = 0; i < numOutput; i++) {
+ VectorWritable.writeVector(out, outputWeights[i]);
+ }
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int version = in.readInt();
+ if (version == WRITABLE_VERSION) {
+ learningRate = in.readDouble();
+ regularization = in.readDouble();
+ sparsity = in.readDouble();
+ sparsityLearningRate = in.readDouble();
+ numFeatures = in.readInt();
+ numHidden = in.readInt();
+ numOutput = in.readInt();
+ hiddenWeights = new DenseVector[numHidden];
+ hiddenBias = VectorWritable.readVector(in);
+ for (int i = 0; i < numHidden; i++) {
+ hiddenWeights[i] = VectorWritable.readVector(in);
+ }
+ outputWeights = new DenseVector[numOutput];
+ outputBias = VectorWritable.readVector(in);
+ for (int i = 0; i < numOutput; i++) {
+ outputWeights[i] = VectorWritable.readVector(in);
+ }
+ } else {
+ throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+ }
+ }
+
+ @Override
+ public void close() {
+ // This is an online classifier, nothing to do.
+ }
+
+ @Override
+ public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+ Vector hiddenActivation = inputToHidden(instance);
+ hiddenToOutput(hiddenActivation);
+ Collection<Integer> goodLabels = Sets.newHashSet();
+ goodLabels.add(actual);
+ updateRanking(hiddenActivation, goodLabels, 2, rnd);
+ }
+
+ @Override
+ public void train(long trackingKey, int actual, Vector instance) {
+ train(trackingKey, null, actual, instance);
+ }
+
+ @Override
+ public void train(int actual, Vector instance) {
+ train(0, null, actual, instance);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
new file mode 100644
index 0000000..28a05f2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/L1.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements the Laplacian or bi-exponential prior. This prior has a strong tendency to set coefficients to zero
+ * and thus is useful as an alternative to variable selection. This version implements truncation which prevents
+ * a coefficient from changing sign. If a correction would change the sign, the coefficient is truncated to zero.
+ *
+ * Note that it doesn't matter to have a scale for this distribution because after taking the derivative of the logP,
+ * the lambda coefficient used to combine the prior with the observations has the same effect. If we had a scale here,
+ * then it would be the same effect as just changing lambda.
+ */
+public class L1 implements PriorFunction {
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ double newValue = oldValue - Math.signum(oldValue) * learningRate * generations;
+ if (newValue * oldValue < 0) {
+ // don't allow the value to change sign
+ return 0;
+ } else {
+ return newValue;
+ }
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return -Math.abs(betaIJ);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ // stateless class has nothing to serialize
+ }
+
+ @Override
+ public void readFields(DataInput dataInput) throws IOException {
+ // stateless class has nothing to serialize
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
new file mode 100644
index 0000000..3dfb9fc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/L2.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Implements the Gaussian prior. This prior has a tendency to decrease large coefficients toward zero, but
+ * doesn't tend to set them to exactly zero.
+ */
+public class L2 implements PriorFunction {
+
+ private static final double HALF_LOG_2PI = Math.log(2.0 * Math.PI) / 2.0;
+
+ private double s2;
+ private double s;
+
+ public L2(double scale) {
+ s = scale;
+ s2 = scale * scale;
+ }
+
+ public L2() {
+ s = 1.0;
+ s2 = 1.0;
+ }
+
+ @Override
+ public double age(double oldValue, double generations, double learningRate) {
+ return oldValue * Math.pow(1.0 - learningRate / s2, generations);
+ }
+
+ @Override
+ public double logP(double betaIJ) {
+ return -betaIJ * betaIJ / s2 / 2.0 - Math.log(s) - HALF_LOG_2PI;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(s2);
+ out.writeDouble(s);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ s2 = in.readDouble();
+ s = in.readDouble();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java b/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
new file mode 100644
index 0000000..a290b22
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/classifier/sgd/MixedGradient.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+
+import java.util.Random;
+
+/**
+ * <p>Provides a stochastic mixture of ranking updates and normal logistic updates. This uses a
+ * combination of AUC driven learning to improve ranking performance and traditional log-loss driven
+ * learning to improve log-likelihood.</p>
+ *
+ * <p>See www.eecs.tufts.edu/~dsculley/papers/combined-ranking-and-regression.pdf</p>
+ *
+ * <p>This implementation only makes sense for the binomial case.</p>
+ */
+public class MixedGradient implements Gradient {
+
+ private final double alpha;
+ private final RankingGradient rank;
+ private final Gradient basic;
+ private final Random random = RandomUtils.getRandom();
+ private boolean hasZero;
+ private boolean hasOne;
+
+ public MixedGradient(double alpha, int window) {
+ this.alpha = alpha;
+ this.rank = new RankingGradient(window);
+ this.basic = this.rank.getBaseGradient();
+ }
+
+ @Override
+ public Vector apply(String groupKey, int actual, Vector instance, AbstractVectorClassifier classifier) {
+ if (random.nextDouble() < alpha) {
+ // one option is to apply a ranking update relative to our recent history
+ if (!hasZero || !hasOne) {
+ throw new IllegalStateException();
+ }
+ return rank.apply(groupKey, actual, instance, classifier);
+ } else {
+ hasZero |= actual == 0;
+ hasOne |= actual == 1;
+ // the other option is a normal update, but we have to update our history on the way
+ rank.addToHistory(actual, instance);
+ return basic.apply(groupKey, actual, instance, classifier);
+ }
+ }
+}
[44/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArray.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArray.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArray.java
new file mode 100644
index 0000000..6db5807
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanItemPreferenceArray.java
@@ -0,0 +1,234 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.CountingIterator;
+
+/**
+ * <p>
+ * Like {@link BooleanUserPreferenceArray} but stores preferences for one item (all item IDs the same) rather
+ * than one user.
+ * </p>
+ *
+ * @see BooleanPreference
+ * @see BooleanUserPreferenceArray
+ * @see GenericItemPreferenceArray
+ */
+public final class BooleanItemPreferenceArray implements PreferenceArray {
+
+ private final long[] ids;
+ private long id;
+
+ public BooleanItemPreferenceArray(int size) {
+ this.ids = new long[size];
+ this.id = Long.MIN_VALUE; // as a sort of 'unspecified' value
+ }
+
+ public BooleanItemPreferenceArray(List<? extends Preference> prefs, boolean forOneUser) {
+ this(prefs.size());
+ int size = prefs.size();
+ for (int i = 0; i < size; i++) {
+ Preference pref = prefs.get(i);
+ ids[i] = forOneUser ? pref.getItemID() : pref.getUserID();
+ }
+ if (size > 0) {
+ id = forOneUser ? prefs.get(0).getUserID() : prefs.get(0).getItemID();
+ }
+ }
+
+ /**
+ * This is a private copy constructor for clone().
+ */
+ private BooleanItemPreferenceArray(long[] ids, long id) {
+ this.ids = ids;
+ this.id = id;
+ }
+
+ @Override
+ public int length() {
+ return ids.length;
+ }
+
+ @Override
+ public Preference get(int i) {
+ return new PreferenceView(i);
+ }
+
+ @Override
+ public void set(int i, Preference pref) {
+ id = pref.getItemID();
+ ids[i] = pref.getUserID();
+ }
+
+ @Override
+ public long getUserID(int i) {
+ return ids[i];
+ }
+
+ @Override
+ public void setUserID(int i, long userID) {
+ ids[i] = userID;
+ }
+
+ @Override
+ public long getItemID(int i) {
+ return id;
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * Note that this method will actually set the item ID for <em>all</em> preferences.
+ */
+ @Override
+ public void setItemID(int i, long itemID) {
+ id = itemID;
+ }
+
+ /**
+ * @return all user IDs
+ */
+ @Override
+ public long[] getIDs() {
+ return ids;
+ }
+
+ @Override
+ public float getValue(int i) {
+ return 1.0f;
+ }
+
+ @Override
+ public void setValue(int i, float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void sortByUser() {
+ Arrays.sort(ids);
+ }
+
+ @Override
+ public void sortByItem() { }
+
+ @Override
+ public void sortByValue() { }
+
+ @Override
+ public void sortByValueReversed() { }
+
+ @Override
+ public boolean hasPrefWithUserID(long userID) {
+ for (long id : ids) {
+ if (userID == id) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public boolean hasPrefWithItemID(long itemID) {
+ return id == itemID;
+ }
+
+ @Override
+ public BooleanItemPreferenceArray clone() {
+ return new BooleanItemPreferenceArray(ids.clone(), id);
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) (id >> 32) ^ (int) id ^ Arrays.hashCode(ids);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof BooleanItemPreferenceArray)) {
+ return false;
+ }
+ BooleanItemPreferenceArray otherArray = (BooleanItemPreferenceArray) other;
+ return id == otherArray.id && Arrays.equals(ids, otherArray.ids);
+ }
+
+ @Override
+ public Iterator<Preference> iterator() {
+ return Iterators.transform(new CountingIterator(length()),
+ new Function<Integer, Preference>() {
+ @Override
+ public Preference apply(Integer from) {
+ return new PreferenceView(from);
+ }
+ });
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder result = new StringBuilder(10 * ids.length);
+ result.append("BooleanItemPreferenceArray[itemID:");
+ result.append(id);
+ result.append(",{");
+ for (int i = 0; i < ids.length; i++) {
+ if (i > 0) {
+ result.append(',');
+ }
+ result.append(ids[i]);
+ }
+ result.append("}]");
+ return result.toString();
+ }
+
+ private final class PreferenceView implements Preference {
+
+ private final int i;
+
+ private PreferenceView(int i) {
+ this.i = i;
+ }
+
+ @Override
+ public long getUserID() {
+ return BooleanItemPreferenceArray.this.getUserID(i);
+ }
+
+ @Override
+ public long getItemID() {
+ return BooleanItemPreferenceArray.this.getItemID(i);
+ }
+
+ @Override
+ public float getValue() {
+ return 1.0f;
+ }
+
+ @Override
+ public void setValue(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanPreference.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanPreference.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanPreference.java
new file mode 100644
index 0000000..2093af8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanPreference.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.io.Serializable;
+
+import org.apache.mahout.cf.taste.model.Preference;
+
+/**
+ * Encapsulates a simple boolean "preference" for an item whose value does not matter (is fixed at 1.0). This
+ * is appropriate in situations where users conceptually have only a general "yes" preference for items,
+ * rather than a spectrum of preference values.
+ */
+public final class BooleanPreference implements Preference, Serializable {
+
+ private final long userID;
+ private final long itemID;
+
+ public BooleanPreference(long userID, long itemID) {
+ this.userID = userID;
+ this.itemID = itemID;
+ }
+
+ @Override
+ public long getUserID() {
+ return userID;
+ }
+
+ @Override
+ public long getItemID() {
+ return itemID;
+ }
+
+ @Override
+ public float getValue() {
+ return 1.0f;
+ }
+
+ @Override
+ public void setValue(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public String toString() {
+ return "BooleanPreference[userID: " + userID + ", itemID:" + itemID + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArray.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArray.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArray.java
new file mode 100644
index 0000000..629e0cf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/BooleanUserPreferenceArray.java
@@ -0,0 +1,234 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.CountingIterator;
+
+/**
+ * <p>
+ * Like {@link GenericUserPreferenceArray} but stores, conceptually, {@link BooleanPreference} objects which
+ * have no associated preference value.
+ * </p>
+ *
+ * @see BooleanPreference
+ * @see BooleanItemPreferenceArray
+ * @see GenericUserPreferenceArray
+ */
+public final class BooleanUserPreferenceArray implements PreferenceArray {
+
+ private final long[] ids;
+ private long id;
+
+ public BooleanUserPreferenceArray(int size) {
+ this.ids = new long[size];
+ this.id = Long.MIN_VALUE; // as a sort of 'unspecified' value
+ }
+
+ public BooleanUserPreferenceArray(List<? extends Preference> prefs) {
+ this(prefs.size());
+ int size = prefs.size();
+ for (int i = 0; i < size; i++) {
+ Preference pref = prefs.get(i);
+ ids[i] = pref.getItemID();
+ }
+ if (size > 0) {
+ id = prefs.get(0).getUserID();
+ }
+ }
+
+ /**
+ * This is a private copy constructor for clone().
+ */
+ private BooleanUserPreferenceArray(long[] ids, long id) {
+ this.ids = ids;
+ this.id = id;
+ }
+
+ @Override
+ public int length() {
+ return ids.length;
+ }
+
+ @Override
+ public Preference get(int i) {
+ return new PreferenceView(i);
+ }
+
+ @Override
+ public void set(int i, Preference pref) {
+ id = pref.getUserID();
+ ids[i] = pref.getItemID();
+ }
+
+ @Override
+ public long getUserID(int i) {
+ return id;
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * Note that this method will actually set the user ID for <em>all</em> preferences.
+ */
+ @Override
+ public void setUserID(int i, long userID) {
+ id = userID;
+ }
+
+ @Override
+ public long getItemID(int i) {
+ return ids[i];
+ }
+
+ @Override
+ public void setItemID(int i, long itemID) {
+ ids[i] = itemID;
+ }
+
+ /**
+ * @return all item IDs
+ */
+ @Override
+ public long[] getIDs() {
+ return ids;
+ }
+
+ @Override
+ public float getValue(int i) {
+ return 1.0f;
+ }
+
+ @Override
+ public void setValue(int i, float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void sortByUser() { }
+
+ @Override
+ public void sortByItem() {
+ Arrays.sort(ids);
+ }
+
+ @Override
+ public void sortByValue() { }
+
+ @Override
+ public void sortByValueReversed() { }
+
+ @Override
+ public boolean hasPrefWithUserID(long userID) {
+ return id == userID;
+ }
+
+ @Override
+ public boolean hasPrefWithItemID(long itemID) {
+ for (long id : ids) {
+ if (itemID == id) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public BooleanUserPreferenceArray clone() {
+ return new BooleanUserPreferenceArray(ids.clone(), id);
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) (id >> 32) ^ (int) id ^ Arrays.hashCode(ids);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof BooleanUserPreferenceArray)) {
+ return false;
+ }
+ BooleanUserPreferenceArray otherArray = (BooleanUserPreferenceArray) other;
+ return id == otherArray.id && Arrays.equals(ids, otherArray.ids);
+ }
+
+ @Override
+ public Iterator<Preference> iterator() {
+ return Iterators.transform(new CountingIterator(length()),
+ new Function<Integer, Preference>() {
+ @Override
+ public Preference apply(Integer from) {
+ return new PreferenceView(from);
+ }
+ });
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder result = new StringBuilder(10 * ids.length);
+ result.append("BooleanUserPreferenceArray[userID:");
+ result.append(id);
+ result.append(",{");
+ for (int i = 0; i < ids.length; i++) {
+ if (i > 0) {
+ result.append(',');
+ }
+ result.append(ids[i]);
+ }
+ result.append("}]");
+ return result.toString();
+ }
+
+ private final class PreferenceView implements Preference {
+
+ private final int i;
+
+ private PreferenceView(int i) {
+ this.i = i;
+ }
+
+ @Override
+ public long getUserID() {
+ return BooleanUserPreferenceArray.this.getUserID(i);
+ }
+
+ @Override
+ public long getItemID() {
+ return BooleanUserPreferenceArray.this.getItemID(i);
+ }
+
+ @Override
+ public float getValue() {
+ return 1.0f;
+ }
+
+ @Override
+ public void setValue(float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericBooleanPrefDataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericBooleanPrefDataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericBooleanPrefDataModel.java
new file mode 100644
index 0000000..2c1ff4d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericBooleanPrefDataModel.java
@@ -0,0 +1,320 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Map;
+
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveArrayIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple {@link DataModel} which uses given user data as its data source. This implementation
+ * is mostly useful for small experiments and is not recommended for contexts where performance is important.
+ * </p>
+ */
+public final class GenericBooleanPrefDataModel extends AbstractDataModel {
+
+ private final long[] userIDs;
+ private final FastByIDMap<FastIDSet> preferenceFromUsers;
+ private final long[] itemIDs;
+ private final FastByIDMap<FastIDSet> preferenceForItems;
+ private final FastByIDMap<FastByIDMap<Long>> timestamps;
+
+ /**
+ * <p>
+ * Creates a new {@link GenericDataModel} from the given users (and their preferences). This
+ * {@link DataModel} retains all this information in memory and is effectively immutable.
+ * </p>
+ *
+ * @param userData users to include
+ */
+ public GenericBooleanPrefDataModel(FastByIDMap<FastIDSet> userData) {
+ this(userData, null);
+ }
+
+ /**
+ * <p>
+ * Creates a new {@link GenericDataModel} from the given users (and their preferences). This
+ * {@link DataModel} retains all this information in memory and is effectively immutable.
+ * </p>
+ *
+ * @param userData users to include
+ * @param timestamps optionally, provided timestamps of preferences as milliseconds since the epoch.
+ * User IDs are mapped to maps of item IDs to Long timestamps.
+ */
+ public GenericBooleanPrefDataModel(FastByIDMap<FastIDSet> userData, FastByIDMap<FastByIDMap<Long>> timestamps) {
+ Preconditions.checkArgument(userData != null, "userData is null");
+
+ this.preferenceFromUsers = userData;
+ this.preferenceForItems = new FastByIDMap<>();
+ FastIDSet itemIDSet = new FastIDSet();
+ for (Map.Entry<Long, FastIDSet> entry : preferenceFromUsers.entrySet()) {
+ long userID = entry.getKey();
+ FastIDSet itemIDs = entry.getValue();
+ itemIDSet.addAll(itemIDs);
+ LongPrimitiveIterator it = itemIDs.iterator();
+ while (it.hasNext()) {
+ long itemID = it.nextLong();
+ FastIDSet userIDs = preferenceForItems.get(itemID);
+ if (userIDs == null) {
+ userIDs = new FastIDSet(2);
+ preferenceForItems.put(itemID, userIDs);
+ }
+ userIDs.add(userID);
+ }
+ }
+
+ this.itemIDs = itemIDSet.toArray();
+ itemIDSet = null; // Might help GC -- this is big
+ Arrays.sort(itemIDs);
+
+ this.userIDs = new long[userData.size()];
+ int i = 0;
+ LongPrimitiveIterator it = userData.keySetIterator();
+ while (it.hasNext()) {
+ userIDs[i++] = it.next();
+ }
+ Arrays.sort(userIDs);
+
+ this.timestamps = timestamps;
+ }
+
+ /**
+ * <p>
+ * Creates a new {@link GenericDataModel} containing an immutable copy of the data from another given
+ * {@link DataModel}.
+ * </p>
+ *
+ * @param dataModel
+ * {@link DataModel} to copy
+ * @throws TasteException
+ * if an error occurs while retrieving the other {@link DataModel}'s users
+ * @deprecated without direct replacement.
+ * Consider {@link #toDataMap(DataModel)} with {@link #GenericBooleanPrefDataModel(FastByIDMap)}
+ */
+ @Deprecated
+ public GenericBooleanPrefDataModel(DataModel dataModel) throws TasteException {
+ this(toDataMap(dataModel));
+ }
+
+ /**
+ * Exports the simple user IDs and associated item IDs in the data model.
+ *
+ * @return a {@link FastByIDMap} mapping user IDs to {@link FastIDSet}s representing
+ * that user's associated items
+ */
+ public static FastByIDMap<FastIDSet> toDataMap(DataModel dataModel) throws TasteException {
+ FastByIDMap<FastIDSet> data = new FastByIDMap<>(dataModel.getNumUsers());
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ long userID = it.nextLong();
+ data.put(userID, dataModel.getItemIDsFromUser(userID));
+ }
+ return data;
+ }
+
+ public static FastByIDMap<FastIDSet> toDataMap(FastByIDMap<PreferenceArray> data) {
+ for (Map.Entry<Long,Object> entry : ((FastByIDMap<Object>) (FastByIDMap<?>) data).entrySet()) {
+ PreferenceArray prefArray = (PreferenceArray) entry.getValue();
+ int size = prefArray.length();
+ FastIDSet itemIDs = new FastIDSet(size);
+ for (int i = 0; i < size; i++) {
+ itemIDs.add(prefArray.getItemID(i));
+ }
+ entry.setValue(itemIDs);
+ }
+ return (FastByIDMap<FastIDSet>) (FastByIDMap<?>) data;
+ }
+
+ /**
+ * This is used mostly internally to the framework, and shouldn't be relied upon otherwise.
+ */
+ public FastByIDMap<FastIDSet> getRawUserData() {
+ return this.preferenceFromUsers;
+ }
+
+ /**
+ * This is used mostly internally to the framework, and shouldn't be relied upon otherwise.
+ */
+ public FastByIDMap<FastIDSet> getRawItemData() {
+ return this.preferenceForItems;
+ }
+
+ @Override
+ public LongPrimitiveArrayIterator getUserIDs() {
+ return new LongPrimitiveArrayIterator(userIDs);
+ }
+
+ /**
+ * @throws NoSuchUserException
+ * if there is no such user
+ */
+ @Override
+ public PreferenceArray getPreferencesFromUser(long userID) throws NoSuchUserException {
+ FastIDSet itemIDs = preferenceFromUsers.get(userID);
+ if (itemIDs == null) {
+ throw new NoSuchUserException(userID);
+ }
+ PreferenceArray prefArray = new BooleanUserPreferenceArray(itemIDs.size());
+ int i = 0;
+ LongPrimitiveIterator it = itemIDs.iterator();
+ while (it.hasNext()) {
+ prefArray.setUserID(i, userID);
+ prefArray.setItemID(i, it.nextLong());
+ i++;
+ }
+ return prefArray;
+ }
+
+ @Override
+ public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
+ FastIDSet itemIDs = preferenceFromUsers.get(userID);
+ if (itemIDs == null) {
+ throw new NoSuchUserException(userID);
+ }
+ return itemIDs;
+ }
+
+ @Override
+ public LongPrimitiveArrayIterator getItemIDs() {
+ return new LongPrimitiveArrayIterator(itemIDs);
+ }
+
+ @Override
+ public PreferenceArray getPreferencesForItem(long itemID) throws NoSuchItemException {
+ FastIDSet userIDs = preferenceForItems.get(itemID);
+ if (userIDs == null) {
+ throw new NoSuchItemException(itemID);
+ }
+ PreferenceArray prefArray = new BooleanItemPreferenceArray(userIDs.size());
+ int i = 0;
+ LongPrimitiveIterator it = userIDs.iterator();
+ while (it.hasNext()) {
+ prefArray.setUserID(i, it.nextLong());
+ prefArray.setItemID(i, itemID);
+ i++;
+ }
+ return prefArray;
+ }
+
+ @Override
+ public Float getPreferenceValue(long userID, long itemID) throws NoSuchUserException {
+ FastIDSet itemIDs = preferenceFromUsers.get(userID);
+ if (itemIDs == null) {
+ throw new NoSuchUserException(userID);
+ }
+ if (itemIDs.contains(itemID)) {
+ return 1.0f;
+ }
+ return null;
+ }
+
+ @Override
+ public Long getPreferenceTime(long userID, long itemID) throws TasteException {
+ if (timestamps == null) {
+ return null;
+ }
+ FastByIDMap<Long> itemTimestamps = timestamps.get(userID);
+ if (itemTimestamps == null) {
+ throw new NoSuchUserException(userID);
+ }
+ return itemTimestamps.get(itemID);
+ }
+
+ @Override
+ public int getNumItems() {
+ return itemIDs.length;
+ }
+
+ @Override
+ public int getNumUsers() {
+ return userIDs.length;
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID) {
+ FastIDSet userIDs1 = preferenceForItems.get(itemID);
+ return userIDs1 == null ? 0 : userIDs1.size();
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) {
+ FastIDSet userIDs1 = preferenceForItems.get(itemID1);
+ if (userIDs1 == null) {
+ return 0;
+ }
+ FastIDSet userIDs2 = preferenceForItems.get(itemID2);
+ if (userIDs2 == null) {
+ return 0;
+ }
+ return userIDs1.size() < userIDs2.size()
+ ? userIDs2.intersectionSize(userIDs1)
+ : userIDs1.intersectionSize(userIDs2);
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // Does nothing
+ }
+
+ @Override
+ public boolean hasPreferenceValues() {
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder result = new StringBuilder(200);
+ result.append("GenericBooleanPrefDataModel[users:");
+ for (int i = 0; i < Math.min(3, userIDs.length); i++) {
+ if (i > 0) {
+ result.append(',');
+ }
+ result.append(userIDs[i]);
+ }
+ if (userIDs.length > 3) {
+ result.append("...");
+ }
+ result.append(']');
+ return result.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericDataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericDataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericDataModel.java
new file mode 100644
index 0000000..f58d349
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericDataModel.java
@@ -0,0 +1,361 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveArrayIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple {@link DataModel} which uses a given {@link List} of users as its data source. This implementation
+ * is mostly useful for small experiments and is not recommended for contexts where performance is important.
+ * </p>
+ */
+public final class GenericDataModel extends AbstractDataModel {
+
+ private static final Logger log = LoggerFactory.getLogger(GenericDataModel.class);
+
+ private final long[] userIDs;
+ private final FastByIDMap<PreferenceArray> preferenceFromUsers;
+ private final long[] itemIDs;
+ private final FastByIDMap<PreferenceArray> preferenceForItems;
+ private final FastByIDMap<FastByIDMap<Long>> timestamps;
+
+ /**
+ * <p>
+ * Creates a new {@link GenericDataModel} from the given users (and their preferences). This
+ * {@link DataModel} retains all this information in memory and is effectively immutable.
+ * </p>
+ *
+ * @param userData users to include; (see also {@link #toDataMap(FastByIDMap, boolean)})
+ */
+ public GenericDataModel(FastByIDMap<PreferenceArray> userData) {
+ this(userData, null);
+ }
+
+ /**
+ * <p>
+ * Creates a new {@link GenericDataModel} from the given users (and their preferences). This
+ * {@link DataModel} retains all this information in memory and is effectively immutable.
+ * </p>
+ *
+ * @param userData users to include; (see also {@link #toDataMap(FastByIDMap, boolean)})
+ * @param timestamps optionally, provided timestamps of preferences as milliseconds since the epoch.
+ * User IDs are mapped to maps of item IDs to Long timestamps.
+ */
+ public GenericDataModel(FastByIDMap<PreferenceArray> userData, FastByIDMap<FastByIDMap<Long>> timestamps) {
+ Preconditions.checkArgument(userData != null, "userData is null");
+
+ this.preferenceFromUsers = userData;
+ FastByIDMap<Collection<Preference>> prefsForItems = new FastByIDMap<>();
+ FastIDSet itemIDSet = new FastIDSet();
+ int currentCount = 0;
+ float maxPrefValue = Float.NEGATIVE_INFINITY;
+ float minPrefValue = Float.POSITIVE_INFINITY;
+ for (Map.Entry<Long, PreferenceArray> entry : preferenceFromUsers.entrySet()) {
+ PreferenceArray prefs = entry.getValue();
+ prefs.sortByItem();
+ for (Preference preference : prefs) {
+ long itemID = preference.getItemID();
+ itemIDSet.add(itemID);
+ Collection<Preference> prefsForItem = prefsForItems.get(itemID);
+ if (prefsForItem == null) {
+ prefsForItem = Lists.newArrayListWithCapacity(2);
+ prefsForItems.put(itemID, prefsForItem);
+ }
+ prefsForItem.add(preference);
+ float value = preference.getValue();
+ if (value > maxPrefValue) {
+ maxPrefValue = value;
+ }
+ if (value < minPrefValue) {
+ minPrefValue = value;
+ }
+ }
+ if (++currentCount % 10000 == 0) {
+ log.info("Processed {} users", currentCount);
+ }
+ }
+ log.info("Processed {} users", currentCount);
+
+ setMinPreference(minPrefValue);
+ setMaxPreference(maxPrefValue);
+
+ this.itemIDs = itemIDSet.toArray();
+ itemIDSet = null; // Might help GC -- this is big
+ Arrays.sort(itemIDs);
+
+ this.preferenceForItems = toDataMap(prefsForItems, false);
+
+ for (Map.Entry<Long, PreferenceArray> entry : preferenceForItems.entrySet()) {
+ entry.getValue().sortByUser();
+ }
+
+ this.userIDs = new long[userData.size()];
+ int i = 0;
+ LongPrimitiveIterator it = userData.keySetIterator();
+ while (it.hasNext()) {
+ userIDs[i++] = it.next();
+ }
+ Arrays.sort(userIDs);
+
+ this.timestamps = timestamps;
+ }
+
+ /**
+ * <p>
+ * Creates a new {@link GenericDataModel} containing an immutable copy of the data from another given
+ * {@link DataModel}.
+ * </p>
+ *
+ * @param dataModel {@link DataModel} to copy
+ * @throws TasteException if an error occurs while retrieving the other {@link DataModel}'s users
+ * @deprecated without direct replacement.
+ * Consider {@link #toDataMap(DataModel)} with {@link #GenericDataModel(FastByIDMap)}
+ */
+ @Deprecated
+ public GenericDataModel(DataModel dataModel) throws TasteException {
+ this(toDataMap(dataModel));
+ }
+
+ /**
+ * Swaps, in-place, {@link List}s for arrays in {@link Map} values .
+ *
+ * @return input value
+ */
+ public static FastByIDMap<PreferenceArray> toDataMap(FastByIDMap<Collection<Preference>> data,
+ boolean byUser) {
+ for (Map.Entry<Long,Object> entry : ((FastByIDMap<Object>) (FastByIDMap<?>) data).entrySet()) {
+ List<Preference> prefList = (List<Preference>) entry.getValue();
+ entry.setValue(byUser ? new GenericUserPreferenceArray(prefList) : new GenericItemPreferenceArray(
+ prefList));
+ }
+ return (FastByIDMap<PreferenceArray>) (FastByIDMap<?>) data;
+ }
+
+ /**
+ * Exports the simple user IDs and preferences in the data model.
+ *
+ * @return a {@link FastByIDMap} mapping user IDs to {@link PreferenceArray}s representing
+ * that user's preferences
+ */
+ public static FastByIDMap<PreferenceArray> toDataMap(DataModel dataModel) throws TasteException {
+ FastByIDMap<PreferenceArray> data = new FastByIDMap<>(dataModel.getNumUsers());
+ LongPrimitiveIterator it = dataModel.getUserIDs();
+ while (it.hasNext()) {
+ long userID = it.nextLong();
+ data.put(userID, dataModel.getPreferencesFromUser(userID));
+ }
+ return data;
+ }
+
+ /**
+ * This is used mostly internally to the framework, and shouldn't be relied upon otherwise.
+ */
+ public FastByIDMap<PreferenceArray> getRawUserData() {
+ return this.preferenceFromUsers;
+ }
+
+ /**
+ * This is used mostly internally to the framework, and shouldn't be relied upon otherwise.
+ */
+ public FastByIDMap<PreferenceArray> getRawItemData() {
+ return this.preferenceForItems;
+ }
+
+ @Override
+ public LongPrimitiveArrayIterator getUserIDs() {
+ return new LongPrimitiveArrayIterator(userIDs);
+ }
+
+ /**
+ * @throws NoSuchUserException
+ * if there is no such user
+ */
+ @Override
+ public PreferenceArray getPreferencesFromUser(long userID) throws NoSuchUserException {
+ PreferenceArray prefs = preferenceFromUsers.get(userID);
+ if (prefs == null) {
+ throw new NoSuchUserException(userID);
+ }
+ return prefs;
+ }
+
+ @Override
+ public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
+ PreferenceArray prefs = getPreferencesFromUser(userID);
+ int size = prefs.length();
+ FastIDSet result = new FastIDSet(size);
+ for (int i = 0; i < size; i++) {
+ result.add(prefs.getItemID(i));
+ }
+ return result;
+ }
+
+ @Override
+ public LongPrimitiveArrayIterator getItemIDs() {
+ return new LongPrimitiveArrayIterator(itemIDs);
+ }
+
+ @Override
+ public PreferenceArray getPreferencesForItem(long itemID) throws NoSuchItemException {
+ PreferenceArray prefs = preferenceForItems.get(itemID);
+ if (prefs == null) {
+ throw new NoSuchItemException(itemID);
+ }
+ return prefs;
+ }
+
+ @Override
+ public Float getPreferenceValue(long userID, long itemID) throws TasteException {
+ PreferenceArray prefs = getPreferencesFromUser(userID);
+ int size = prefs.length();
+ for (int i = 0; i < size; i++) {
+ if (prefs.getItemID(i) == itemID) {
+ return prefs.getValue(i);
+ }
+ }
+ return null;
+ }
+
+ @Override
+ public Long getPreferenceTime(long userID, long itemID) throws TasteException {
+ if (timestamps == null) {
+ return null;
+ }
+ FastByIDMap<Long> itemTimestamps = timestamps.get(userID);
+ if (itemTimestamps == null) {
+ throw new NoSuchUserException(userID);
+ }
+ return itemTimestamps.get(itemID);
+ }
+
+ @Override
+ public int getNumItems() {
+ return itemIDs.length;
+ }
+
+ @Override
+ public int getNumUsers() {
+ return userIDs.length;
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID) {
+ PreferenceArray prefs1 = preferenceForItems.get(itemID);
+ return prefs1 == null ? 0 : prefs1.length();
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) {
+ PreferenceArray prefs1 = preferenceForItems.get(itemID1);
+ if (prefs1 == null) {
+ return 0;
+ }
+ PreferenceArray prefs2 = preferenceForItems.get(itemID2);
+ if (prefs2 == null) {
+ return 0;
+ }
+
+ int size1 = prefs1.length();
+ int size2 = prefs2.length();
+ int count = 0;
+ int i = 0;
+ int j = 0;
+ long userID1 = prefs1.getUserID(0);
+ long userID2 = prefs2.getUserID(0);
+ while (true) {
+ if (userID1 < userID2) {
+ if (++i == size1) {
+ break;
+ }
+ userID1 = prefs1.getUserID(i);
+ } else if (userID1 > userID2) {
+ if (++j == size2) {
+ break;
+ }
+ userID2 = prefs2.getUserID(j);
+ } else {
+ count++;
+ if (++i == size1 || ++j == size2) {
+ break;
+ }
+ userID1 = prefs1.getUserID(i);
+ userID2 = prefs2.getUserID(j);
+ }
+ }
+ return count;
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ // Does nothing
+ }
+
+ @Override
+ public boolean hasPreferenceValues() {
+ return true;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder result = new StringBuilder(200);
+ result.append("GenericDataModel[users:");
+ for (int i = 0; i < Math.min(3, userIDs.length); i++) {
+ if (i > 0) {
+ result.append(',');
+ }
+ result.append(userIDs[i]);
+ }
+ if (userIDs.length > 3) {
+ result.append("...");
+ }
+ result.append(']');
+ return result.toString();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArray.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArray.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArray.java
new file mode 100644
index 0000000..fde9314
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericItemPreferenceArray.java
@@ -0,0 +1,301 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.CountingIterator;
+
+/**
+ * <p>
+ * Like {@link GenericUserPreferenceArray} but stores preferences for one item (all item IDs the same) rather
+ * than one user.
+ * </p>
+ *
+ * @see BooleanItemPreferenceArray
+ * @see GenericUserPreferenceArray
+ * @see GenericPreference
+ */
+public final class GenericItemPreferenceArray implements PreferenceArray {
+
+ private static final int USER = 0;
+ private static final int VALUE = 2;
+ private static final int VALUE_REVERSED = 3;
+
+ private final long[] ids;
+ private long id;
+ private final float[] values;
+
+ public GenericItemPreferenceArray(int size) {
+ this.ids = new long[size];
+ values = new float[size];
+ this.id = Long.MIN_VALUE; // as a sort of 'unspecified' value
+ }
+
+ public GenericItemPreferenceArray(List<? extends Preference> prefs) {
+ this(prefs.size());
+ int size = prefs.size();
+ long itemID = Long.MIN_VALUE;
+ for (int i = 0; i < size; i++) {
+ Preference pref = prefs.get(i);
+ ids[i] = pref.getUserID();
+ if (i == 0) {
+ itemID = pref.getItemID();
+ } else {
+ if (itemID != pref.getItemID()) {
+ throw new IllegalArgumentException("Not all item IDs are the same");
+ }
+ }
+ values[i] = pref.getValue();
+ }
+ id = itemID;
+ }
+
+ /**
+ * This is a private copy constructor for clone().
+ */
+ private GenericItemPreferenceArray(long[] ids, long id, float[] values) {
+ this.ids = ids;
+ this.id = id;
+ this.values = values;
+ }
+
+ @Override
+ public int length() {
+ return ids.length;
+ }
+
+ @Override
+ public Preference get(int i) {
+ return new PreferenceView(i);
+ }
+
+ @Override
+ public void set(int i, Preference pref) {
+ id = pref.getItemID();
+ ids[i] = pref.getUserID();
+ values[i] = pref.getValue();
+ }
+
+ @Override
+ public long getUserID(int i) {
+ return ids[i];
+ }
+
+ @Override
+ public void setUserID(int i, long userID) {
+ ids[i] = userID;
+ }
+
+ @Override
+ public long getItemID(int i) {
+ return id;
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * Note that this method will actually set the item ID for <em>all</em> preferences.
+ */
+ @Override
+ public void setItemID(int i, long itemID) {
+ id = itemID;
+ }
+
+ /**
+ * @return all user IDs
+ */
+ @Override
+ public long[] getIDs() {
+ return ids;
+ }
+
+ @Override
+ public float getValue(int i) {
+ return values[i];
+ }
+
+ @Override
+ public void setValue(int i, float value) {
+ values[i] = value;
+ }
+
+ @Override
+ public void sortByUser() {
+ lateralSort(USER);
+ }
+
+ @Override
+ public void sortByItem() { }
+
+ @Override
+ public void sortByValue() {
+ lateralSort(VALUE);
+ }
+
+ @Override
+ public void sortByValueReversed() {
+ lateralSort(VALUE_REVERSED);
+ }
+
+ @Override
+ public boolean hasPrefWithUserID(long userID) {
+ for (long id : ids) {
+ if (userID == id) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public boolean hasPrefWithItemID(long itemID) {
+ return id == itemID;
+ }
+
+ private void lateralSort(int type) {
+ //Comb sort: http://en.wikipedia.org/wiki/Comb_sort
+ int length = length();
+ int gap = length;
+ boolean swapped = false;
+ while (gap > 1 || swapped) {
+ if (gap > 1) {
+ gap /= 1.247330950103979; // = 1 / (1 - 1/e^phi)
+ }
+ swapped = false;
+ int max = length - gap;
+ for (int i = 0; i < max; i++) {
+ int other = i + gap;
+ if (isLess(other, i, type)) {
+ swap(i, other);
+ swapped = true;
+ }
+ }
+ }
+ }
+
+ private boolean isLess(int i, int j, int type) {
+ switch (type) {
+ case USER:
+ return ids[i] < ids[j];
+ case VALUE:
+ return values[i] < values[j];
+ case VALUE_REVERSED:
+ return values[i] > values[j];
+ default:
+ throw new IllegalStateException();
+ }
+ }
+
+ private void swap(int i, int j) {
+ long temp1 = ids[i];
+ float temp2 = values[i];
+ ids[i] = ids[j];
+ values[i] = values[j];
+ ids[j] = temp1;
+ values[j] = temp2;
+ }
+
+ @Override
+ public GenericItemPreferenceArray clone() {
+ return new GenericItemPreferenceArray(ids.clone(), id, values.clone());
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) (id >> 32) ^ (int) id ^ Arrays.hashCode(ids) ^ Arrays.hashCode(values);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof GenericItemPreferenceArray)) {
+ return false;
+ }
+ GenericItemPreferenceArray otherArray = (GenericItemPreferenceArray) other;
+ return id == otherArray.id && Arrays.equals(ids, otherArray.ids) && Arrays.equals(values, otherArray.values);
+ }
+
+ @Override
+ public Iterator<Preference> iterator() {
+ return Iterators.transform(new CountingIterator(length()),
+ new Function<Integer, Preference>() {
+ @Override
+ public Preference apply(Integer from) {
+ return new PreferenceView(from);
+ }
+ });
+ }
+
+ @Override
+ public String toString() {
+ if (ids == null || ids.length == 0) {
+ return "GenericItemPreferenceArray[{}]";
+ }
+ StringBuilder result = new StringBuilder(20 * ids.length);
+ result.append("GenericItemPreferenceArray[itemID:");
+ result.append(id);
+ result.append(",{");
+ for (int i = 0; i < ids.length; i++) {
+ if (i > 0) {
+ result.append(',');
+ }
+ result.append(ids[i]);
+ result.append('=');
+ result.append(values[i]);
+ }
+ result.append("}]");
+ return result.toString();
+ }
+
+ private final class PreferenceView implements Preference {
+
+ private final int i;
+
+ private PreferenceView(int i) {
+ this.i = i;
+ }
+
+ @Override
+ public long getUserID() {
+ return GenericItemPreferenceArray.this.getUserID(i);
+ }
+
+ @Override
+ public long getItemID() {
+ return GenericItemPreferenceArray.this.getItemID(i);
+ }
+
+ @Override
+ public float getValue() {
+ return values[i];
+ }
+
+ @Override
+ public void setValue(float value) {
+ values[i] = value;
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericPreference.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericPreference.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericPreference.java
new file mode 100644
index 0000000..e6c7f43
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericPreference.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.io.Serializable;
+
+import org.apache.mahout.cf.taste.model.Preference;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A simple {@link Preference} encapsulating an item and preference value.
+ * </p>
+ */
+public class GenericPreference implements Preference, Serializable {
+
+ private final long userID;
+ private final long itemID;
+ private float value;
+
+ public GenericPreference(long userID, long itemID, float value) {
+ Preconditions.checkArgument(!Float.isNaN(value), "NaN value");
+ this.userID = userID;
+ this.itemID = itemID;
+ this.value = value;
+ }
+
+ @Override
+ public long getUserID() {
+ return userID;
+ }
+
+ @Override
+ public long getItemID() {
+ return itemID;
+ }
+
+ @Override
+ public float getValue() {
+ return value;
+ }
+
+ @Override
+ public void setValue(float value) {
+ Preconditions.checkArgument(!Float.isNaN(value), "NaN value");
+ this.value = value;
+ }
+
+ @Override
+ public String toString() {
+ return "GenericPreference[userID: " + userID + ", itemID:" + itemID + ", value:" + value + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArray.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArray.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArray.java
new file mode 100644
index 0000000..647feeb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/GenericUserPreferenceArray.java
@@ -0,0 +1,307 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterators;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.CountingIterator;
+
+/**
+ * <p>
+ * Like {@link GenericItemPreferenceArray} but stores preferences for one user (all user IDs the same) rather
+ * than one item.
+ * </p>
+ *
+ * <p>
+ * This implementation maintains two parallel arrays, of item IDs and values. The idea is to save allocating
+ * {@link Preference} objects themselves. This saves the overhead of {@link Preference} objects but also
+ * duplicating the user ID value.
+ * </p>
+ *
+ * @see BooleanUserPreferenceArray
+ * @see GenericItemPreferenceArray
+ * @see GenericPreference
+ */
+public final class GenericUserPreferenceArray implements PreferenceArray {
+
+ private static final int ITEM = 1;
+ private static final int VALUE = 2;
+ private static final int VALUE_REVERSED = 3;
+
+ private final long[] ids;
+ private long id;
+ private final float[] values;
+
+ public GenericUserPreferenceArray(int size) {
+ this.ids = new long[size];
+ values = new float[size];
+ this.id = Long.MIN_VALUE; // as a sort of 'unspecified' value
+ }
+
+ public GenericUserPreferenceArray(List<? extends Preference> prefs) {
+ this(prefs.size());
+ int size = prefs.size();
+ long userID = Long.MIN_VALUE;
+ for (int i = 0; i < size; i++) {
+ Preference pref = prefs.get(i);
+ if (i == 0) {
+ userID = pref.getUserID();
+ } else {
+ if (userID != pref.getUserID()) {
+ throw new IllegalArgumentException("Not all user IDs are the same");
+ }
+ }
+ ids[i] = pref.getItemID();
+ values[i] = pref.getValue();
+ }
+ id = userID;
+ }
+
+ /**
+ * This is a private copy constructor for clone().
+ */
+ private GenericUserPreferenceArray(long[] ids, long id, float[] values) {
+ this.ids = ids;
+ this.id = id;
+ this.values = values;
+ }
+
+ @Override
+ public int length() {
+ return ids.length;
+ }
+
+ @Override
+ public Preference get(int i) {
+ return new PreferenceView(i);
+ }
+
+ @Override
+ public void set(int i, Preference pref) {
+ id = pref.getUserID();
+ ids[i] = pref.getItemID();
+ values[i] = pref.getValue();
+ }
+
+ @Override
+ public long getUserID(int i) {
+ return id;
+ }
+
+ /**
+ * {@inheritDoc}
+ *
+ * Note that this method will actually set the user ID for <em>all</em> preferences.
+ */
+ @Override
+ public void setUserID(int i, long userID) {
+ id = userID;
+ }
+
+ @Override
+ public long getItemID(int i) {
+ return ids[i];
+ }
+
+ @Override
+ public void setItemID(int i, long itemID) {
+ ids[i] = itemID;
+ }
+
+ /**
+ * @return all item IDs
+ */
+ @Override
+ public long[] getIDs() {
+ return ids;
+ }
+
+ @Override
+ public float getValue(int i) {
+ return values[i];
+ }
+
+ @Override
+ public void setValue(int i, float value) {
+ values[i] = value;
+ }
+
+ @Override
+ public void sortByUser() { }
+
+ @Override
+ public void sortByItem() {
+ lateralSort(ITEM);
+ }
+
+ @Override
+ public void sortByValue() {
+ lateralSort(VALUE);
+ }
+
+ @Override
+ public void sortByValueReversed() {
+ lateralSort(VALUE_REVERSED);
+ }
+
+ @Override
+ public boolean hasPrefWithUserID(long userID) {
+ return id == userID;
+ }
+
+ @Override
+ public boolean hasPrefWithItemID(long itemID) {
+ for (long id : ids) {
+ if (itemID == id) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ private void lateralSort(int type) {
+ //Comb sort: http://en.wikipedia.org/wiki/Comb_sort
+ int length = length();
+ int gap = length;
+ boolean swapped = false;
+ while (gap > 1 || swapped) {
+ if (gap > 1) {
+ gap /= 1.247330950103979; // = 1 / (1 - 1/e^phi)
+ }
+ swapped = false;
+ int max = length - gap;
+ for (int i = 0; i < max; i++) {
+ int other = i + gap;
+ if (isLess(other, i, type)) {
+ swap(i, other);
+ swapped = true;
+ }
+ }
+ }
+ }
+
+ private boolean isLess(int i, int j, int type) {
+ switch (type) {
+ case ITEM:
+ return ids[i] < ids[j];
+ case VALUE:
+ return values[i] < values[j];
+ case VALUE_REVERSED:
+ return values[i] > values[j];
+ default:
+ throw new IllegalStateException();
+ }
+ }
+
+ private void swap(int i, int j) {
+ long temp1 = ids[i];
+ float temp2 = values[i];
+ ids[i] = ids[j];
+ values[i] = values[j];
+ ids[j] = temp1;
+ values[j] = temp2;
+ }
+
+ @Override
+ public GenericUserPreferenceArray clone() {
+ return new GenericUserPreferenceArray(ids.clone(), id, values.clone());
+ }
+
+ @Override
+ public int hashCode() {
+ return (int) (id >> 32) ^ (int) id ^ Arrays.hashCode(ids) ^ Arrays.hashCode(values);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (!(other instanceof GenericUserPreferenceArray)) {
+ return false;
+ }
+ GenericUserPreferenceArray otherArray = (GenericUserPreferenceArray) other;
+ return id == otherArray.id && Arrays.equals(ids, otherArray.ids) && Arrays.equals(values, otherArray.values);
+ }
+
+ @Override
+ public Iterator<Preference> iterator() {
+ return Iterators.transform(new CountingIterator(length()),
+ new Function<Integer, Preference>() {
+ @Override
+ public Preference apply(Integer from) {
+ return new PreferenceView(from);
+ }
+ });
+ }
+
+ @Override
+ public String toString() {
+ if (ids == null || ids.length == 0) {
+ return "GenericUserPreferenceArray[{}]";
+ }
+ StringBuilder result = new StringBuilder(20 * ids.length);
+ result.append("GenericUserPreferenceArray[userID:");
+ result.append(id);
+ result.append(",{");
+ for (int i = 0; i < ids.length; i++) {
+ if (i > 0) {
+ result.append(',');
+ }
+ result.append(ids[i]);
+ result.append('=');
+ result.append(values[i]);
+ }
+ result.append("}]");
+ return result.toString();
+ }
+
+ private final class PreferenceView implements Preference {
+
+ private final int i;
+
+ private PreferenceView(int i) {
+ this.i = i;
+ }
+
+ @Override
+ public long getUserID() {
+ return GenericUserPreferenceArray.this.getUserID(i);
+ }
+
+ @Override
+ public long getItemID() {
+ return GenericUserPreferenceArray.this.getItemID(i);
+ }
+
+ @Override
+ public float getValue() {
+ return values[i];
+ }
+
+ @Override
+ public void setValue(float value) {
+ values[i] = value;
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigrator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigrator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigrator.java
new file mode 100644
index 0000000..3463ff5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MemoryIDMigrator.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.model.UpdatableIDMigrator;
+
+/**
+ * Implementation which stores the reverse long-to-String mapping in memory.
+ */
+public final class MemoryIDMigrator extends AbstractIDMigrator implements UpdatableIDMigrator {
+
+ private final FastByIDMap<String> longToString;
+
+ public MemoryIDMigrator() {
+ this.longToString = new FastByIDMap<>(100);
+ }
+
+ @Override
+ public void storeMapping(long longID, String stringID) {
+ synchronized (longToString) {
+ longToString.put(longID, stringID);
+ }
+ }
+
+ @Override
+ public String toStringID(long longID) {
+ synchronized (longToString) {
+ return longToString.get(longID);
+ }
+ }
+
+ @Override
+ public void initialize(Iterable<String> stringIDs) {
+ for (String stringID : stringIDs) {
+ storeMapping(toLongID(stringID), stringID);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MySQLJDBCIDMigrator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MySQLJDBCIDMigrator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MySQLJDBCIDMigrator.java
new file mode 100644
index 0000000..b134598
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/MySQLJDBCIDMigrator.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import javax.sql.DataSource;
+
+/**
+ * <p>
+ * An implementation for MySQL. The following statement would create a table suitable for use with this class:
+ * </p>
+ *
+ * <p>
+ *
+ * <pre>
+ * CREATE TABLE taste_id_migration (
+ * long_id BIGINT NOT NULL PRIMARY KEY,
+ * string_id VARCHAR(255) NOT NULL UNIQUE
+ * )
+ * </pre>
+ *
+ * </p>
+ *
+ * <p>
+ * Separately, note that in a MySQL database, the following function calls will convert a string value into a
+ * numeric value in the same way that the standard implementations in this package do. This may be useful in
+ * writing SQL statements for use with
+ * {@code AbstractJDBCDataModel} subclasses which convert string
+ * column values to appropriate numeric values -- though this should be viewed as a temporary arrangement
+ * since it will impact performance:
+ * </p>
+ *
+ * <p>
+ * {@code cast(conv(substring(md5([column name]), 1, 16),16,10) as signed)}
+ * </p>
+ */
+public final class MySQLJDBCIDMigrator extends AbstractJDBCIDMigrator {
+
+ public MySQLJDBCIDMigrator(DataSource dataSource) {
+ this(dataSource, DEFAULT_MAPPING_TABLE,
+ DEFAULT_LONG_ID_COLUMN, DEFAULT_STRING_ID_COLUMN);
+ }
+
+ public MySQLJDBCIDMigrator(DataSource dataSource,
+ String mappingTable,
+ String longIDColumn,
+ String stringIDColumn) {
+ super(dataSource,
+ "SELECT " + stringIDColumn + " FROM " + mappingTable + " WHERE " + longIDColumn + "=?",
+ "INSERT IGNORE INTO " + mappingTable + " (" + longIDColumn + ',' + stringIDColumn + ") VALUES (?,?)");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModel.java
new file mode 100644
index 0000000..c97a545
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousConcurrentUserDataModel.java
@@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import com.google.common.base.Preconditions;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>
+ * This is a special thread-safe version of {@link PlusAnonymousUserDataModel}
+ * which allow multiple concurrent anonymous requests.
+ * </p>
+ *
+ * <p>
+ * To use it, you have to estimate the number of concurrent anonymous users of your application.
+ * The pool of users with the given size will be created. For each anonymous recommendations request,
+ * a user has to be taken from the pool and returned back immediately afterwards.
+ * </p>
+ *
+ * <p>
+ * If no more users are available in the pool, anonymous recommendations cannot be produced.
+ * </p>
+ *
+ * </p>
+ *
+ * Setup:
+ * <pre>
+ * int concurrentUsers = 100;
+ * DataModel realModel = ..
+ * PlusAnonymousConcurrentUserDataModel plusModel =
+ * new PlusAnonymousConcurrentUserDataModel(realModel, concurrentUsers);
+ * Recommender recommender = ...;
+ * </pre>
+ *
+ * Real-time recommendation:
+ * <pre>
+ * PlusAnonymousConcurrentUserDataModel plusModel =
+ * (PlusAnonymousConcurrentUserDataModel) recommender.getDataModel();
+ *
+ * // Take the next available anonymous user from the pool
+ * Long anonymousUserID = plusModel.takeAvailableUser();
+ *
+ * PreferenceArray tempPrefs = ..
+ * tempPrefs.setUserID(0, anonymousUserID);
+ * tempPrefs.setItemID(0, itemID);
+ * plusModel.setTempPrefs(tempPrefs, anonymousUserID);
+ *
+ * // Produce recommendations
+ * recommender.recommend(anonymousUserID, howMany);
+ *
+ * // It is very IMPORTANT to release user back to the pool
+ * plusModel.releaseUser(anonymousUserID);
+ * </pre>
+ *
+ * </p>
+ */
+public final class PlusAnonymousConcurrentUserDataModel extends PlusAnonymousUserDataModel {
+
+ /** Preferences for all anonymous users */
+ private final Map<Long,PreferenceArray> tempPrefs;
+ /** Item IDs set for all anonymous users */
+ private final Map<Long,FastIDSet> prefItemIDs;
+ /** Pool of the users (FIFO) */
+ private Queue<Long> usersPool;
+
+ private static final Logger log = LoggerFactory.getLogger(PlusAnonymousUserDataModel.class);
+
+ /**
+ * @param delegate Real model where anonymous users will be added to
+ * @param maxConcurrentUsers Maximum allowed number of concurrent anonymous users
+ */
+ public PlusAnonymousConcurrentUserDataModel(DataModel delegate, int maxConcurrentUsers) {
+ super(delegate);
+
+ tempPrefs = new ConcurrentHashMap<>();
+ prefItemIDs = new ConcurrentHashMap<>();
+
+ initializeUsersPools(maxConcurrentUsers);
+ }
+
+ /**
+ * Initialize the pool of concurrent anonymous users.
+ *
+ * @param usersPoolSize Maximum allowed number of concurrent anonymous user. Depends on the consumer system.
+ */
+ private void initializeUsersPools(int usersPoolSize) {
+ usersPool = new ConcurrentLinkedQueue<>();
+ for (int i = 0; i < usersPoolSize; i++) {
+ usersPool.add(TEMP_USER_ID + i);
+ }
+ }
+
+ /**
+ * Take the next available concurrent anonymous users from the pool.
+ *
+ * @return User ID or null if no more users are available
+ */
+ public Long takeAvailableUser() {
+ Long takenUserID = usersPool.poll();
+ if (takenUserID != null) {
+ // Initialize the preferences array to indicate that the user is taken.
+ tempPrefs.put(takenUserID, new GenericUserPreferenceArray(0));
+ return takenUserID;
+ }
+ return null;
+ }
+
+ /**
+ * Release previously taken anonymous user and return it to the pool.
+ *
+ * @param userID ID of a previously taken anonymous user
+ * @return true if the user was previously taken, false otherwise
+ */
+ public boolean releaseUser(Long userID) {
+ if (tempPrefs.containsKey(userID)) {
+ this.clearTempPrefs(userID);
+ // Return previously taken user to the pool
+ usersPool.offer(userID);
+ return true;
+ }
+ return false;
+ }
+
+ /**
+ * Checks whether a given user is a valid previously acquired anonymous user.
+ */
+ private boolean isAnonymousUser(long userID) {
+ return tempPrefs.containsKey(userID);
+ }
+
+ /**
+ * Sets temporary preferences for a given anonymous user.
+ */
+ public void setTempPrefs(PreferenceArray prefs, long anonymousUserID) {
+ Preconditions.checkArgument(prefs != null && prefs.length() > 0, "prefs is null or empty");
+
+ this.tempPrefs.put(anonymousUserID, prefs);
+ FastIDSet userPrefItemIDs = new FastIDSet();
+
+ for (int i = 0; i < prefs.length(); i++) {
+ userPrefItemIDs.add(prefs.getItemID(i));
+ }
+
+ this.prefItemIDs.put(anonymousUserID, userPrefItemIDs);
+ }
+
+ /**
+ * Clears temporary preferences for a given anonymous user.
+ */
+ public void clearTempPrefs(long anonymousUserID) {
+ this.tempPrefs.remove(anonymousUserID);
+ this.prefItemIDs.remove(anonymousUserID);
+ }
+
+ @Override
+ public LongPrimitiveIterator getUserIDs() throws TasteException {
+ // Anonymous users have short lifetime and should not be included into the neighbohoods of the real users.
+ // Thus exclude them from the universe.
+ return getDelegate().getUserIDs();
+ }
+
+ @Override
+ public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {
+ if (isAnonymousUser(userID)) {
+ return tempPrefs.get(userID);
+ }
+ return getDelegate().getPreferencesFromUser(userID);
+ }
+
+ @Override
+ public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
+ if (isAnonymousUser(userID)) {
+ return prefItemIDs.get(userID);
+ }
+ return getDelegate().getItemIDsFromUser(userID);
+ }
+
+ @Override
+ public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
+ if (tempPrefs.isEmpty()) {
+ return getDelegate().getPreferencesForItem(itemID);
+ }
+
+ PreferenceArray delegatePrefs = null;
+
+ try {
+ delegatePrefs = getDelegate().getPreferencesForItem(itemID);
+ } catch (NoSuchItemException nsie) {
+ // OK. Probably an item that only the anonymous user has
+ if (log.isDebugEnabled()) {
+ log.debug("Item {} unknown", itemID);
+ }
+ }
+
+ List<Preference> anonymousPreferences = Lists.newArrayList();
+
+ for (Map.Entry<Long, PreferenceArray> prefsMap : tempPrefs.entrySet()) {
+ PreferenceArray singleUserTempPrefs = prefsMap.getValue();
+ for (int i = 0; i < singleUserTempPrefs.length(); i++) {
+ if (singleUserTempPrefs.getItemID(i) == itemID) {
+ anonymousPreferences.add(singleUserTempPrefs.get(i));
+ }
+ }
+ }
+
+ int delegateLength = delegatePrefs == null ? 0 : delegatePrefs.length();
+ int anonymousPrefsLength = anonymousPreferences.size();
+ int prefsCounter = 0;
+
+ // Merge the delegate and anonymous preferences into a single array
+ PreferenceArray newPreferenceArray = new GenericItemPreferenceArray(delegateLength + anonymousPrefsLength);
+
+ for (int i = 0; i < delegateLength; i++) {
+ newPreferenceArray.set(prefsCounter++, delegatePrefs.get(i));
+ }
+
+ for (Preference anonymousPreference : anonymousPreferences) {
+ newPreferenceArray.set(prefsCounter++, anonymousPreference);
+ }
+
+ if (newPreferenceArray.length() == 0) {
+ // No, didn't find it among the anonymous user prefs
+ throw new NoSuchItemException(itemID);
+ }
+
+ return newPreferenceArray;
+ }
+
+ @Override
+ public Float getPreferenceValue(long userID, long itemID) throws TasteException {
+ if (isAnonymousUser(userID)) {
+ PreferenceArray singleUserTempPrefs = tempPrefs.get(userID);
+ for (int i = 0; i < singleUserTempPrefs.length(); i++) {
+ if (singleUserTempPrefs.getItemID(i) == itemID) {
+ return singleUserTempPrefs.getValue(i);
+ }
+ }
+ return null;
+ }
+ return getDelegate().getPreferenceValue(userID, itemID);
+ }
+
+ @Override
+ public Long getPreferenceTime(long userID, long itemID) throws TasteException {
+ if (isAnonymousUser(userID)) {
+ // Timestamps are not saved for anonymous preferences
+ return null;
+ }
+ return getDelegate().getPreferenceTime(userID, itemID);
+ }
+
+ @Override
+ public int getNumUsers() throws TasteException {
+ // Anonymous users have short lifetime and should not be included into the neighbohoods of the real users.
+ // Thus exclude them from the universe.
+ return getDelegate().getNumUsers();
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
+ if (tempPrefs.isEmpty()) {
+ return getDelegate().getNumUsersWithPreferenceFor(itemID);
+ }
+
+ int countAnonymousUsersWithPreferenceFor = 0;
+
+ for (Map.Entry<Long, PreferenceArray> singleUserTempPrefs : tempPrefs.entrySet()) {
+ for (int i = 0; i < singleUserTempPrefs.getValue().length(); i++) {
+ if (singleUserTempPrefs.getValue().getItemID(i) == itemID) {
+ countAnonymousUsersWithPreferenceFor++;
+ break;
+ }
+ }
+ }
+ return getDelegate().getNumUsersWithPreferenceFor(itemID) + countAnonymousUsersWithPreferenceFor;
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
+ if (tempPrefs.isEmpty()) {
+ return getDelegate().getNumUsersWithPreferenceFor(itemID1, itemID2);
+ }
+
+ int countAnonymousUsersWithPreferenceFor = 0;
+
+ for (Map.Entry<Long, PreferenceArray> singleUserTempPrefs : tempPrefs.entrySet()) {
+ boolean found1 = false;
+ boolean found2 = false;
+ for (int i = 0; i < singleUserTempPrefs.getValue().length() && !(found1 && found2); i++) {
+ long itemID = singleUserTempPrefs.getValue().getItemID(i);
+ if (itemID == itemID1) {
+ found1 = true;
+ }
+ if (itemID == itemID2) {
+ found2 = true;
+ }
+ }
+
+ if (found1 && found2) {
+ countAnonymousUsersWithPreferenceFor++;
+ }
+ }
+
+ return getDelegate().getNumUsersWithPreferenceFor(itemID1, itemID2) + countAnonymousUsersWithPreferenceFor;
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) throws TasteException {
+ if (isAnonymousUser(userID)) {
+ throw new UnsupportedOperationException();
+ }
+ getDelegate().setPreference(userID, itemID, value);
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) throws TasteException {
+ if (isAnonymousUser(userID)) {
+ throw new UnsupportedOperationException();
+ }
+ getDelegate().removePreference(userID, itemID);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserDataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserDataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserDataModel.java
new file mode 100644
index 0000000..546349b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserDataModel.java
@@ -0,0 +1,320 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.NoSuchItemException;
+import org.apache.mahout.cf.taste.common.NoSuchUserException;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+
+import com.google.common.base.Preconditions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>
+ * This {@link DataModel} decorator class is useful in a situation where you wish to recommend to a user that
+ * doesn't really exist yet in your actual {@link DataModel}. For example maybe you wish to recommend DVDs to
+ * a user who has browsed a few titles on your DVD store site, but, the user is not yet registered.
+ * </p>
+ *
+ * <p>
+ * This enables you to temporarily add a temporary user to an existing {@link DataModel} in a way that
+ * recommenders can then produce recommendations anyway. To do so, wrap your real implementation in this
+ * class:
+ * </p>
+ *
+ * <p>
+ *
+ * <pre>
+ * DataModel realModel = ...;
+ * DataModel plusModel = new PlusAnonymousUserDataModel(realModel);
+ * ...
+ * ItemSimilarity similarity = new LogLikelihoodSimilarity(realModel); // not plusModel
+ * </pre>
+ *
+ * </p>
+ *
+ * <p>
+ * But, you may continue to use {@code realModel} as input to other components. To recommend, first construct and
+ * set the temporary user information on the model and then simply call the recommender. The
+ * {@code synchronized} block exists to remind you that this is of course not thread-safe. Only one set
+ * of temp data can be inserted into the model and used at one time.
+ * </p>
+ *
+ * <p>
+ *
+ * <pre>
+ * Recommender recommender = ...;
+ * ...
+ * synchronized(...) {
+ * PreferenceArray tempPrefs = ...;
+ * plusModel.setTempPrefs(tempPrefs);
+ * recommender.recommend(PlusAnonymousUserDataModel.TEMP_USER_ID, 10);
+ * plusModel.setTempPrefs(null);
+ * }
+ * </pre>
+ *
+ * </p>
+ */
+public class PlusAnonymousUserDataModel implements DataModel {
+
+ public static final long TEMP_USER_ID = Long.MIN_VALUE;
+
+ private final DataModel delegate;
+ private PreferenceArray tempPrefs;
+ private final FastIDSet prefItemIDs;
+
+ private static final Logger log = LoggerFactory.getLogger(PlusAnonymousUserDataModel.class);
+
+ public PlusAnonymousUserDataModel(DataModel delegate) {
+ this.delegate = delegate;
+ this.prefItemIDs = new FastIDSet();
+ }
+
+ protected DataModel getDelegate() {
+ return delegate;
+ }
+
+ public void setTempPrefs(PreferenceArray prefs) {
+ Preconditions.checkArgument(prefs != null && prefs.length() > 0, "prefs is null or empty");
+ this.tempPrefs = prefs;
+ this.prefItemIDs.clear();
+ for (int i = 0; i < prefs.length(); i++) {
+ this.prefItemIDs.add(prefs.getItemID(i));
+ }
+ }
+
+ public void clearTempPrefs() {
+ tempPrefs = null;
+ prefItemIDs.clear();
+ }
+
+ @Override
+ public LongPrimitiveIterator getUserIDs() throws TasteException {
+ if (tempPrefs == null) {
+ return delegate.getUserIDs();
+ }
+ return new PlusAnonymousUserLongPrimitiveIterator(delegate.getUserIDs(), TEMP_USER_ID);
+ }
+
+ @Override
+ public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {
+ if (userID == TEMP_USER_ID) {
+ if (tempPrefs == null) {
+ throw new NoSuchUserException(TEMP_USER_ID);
+ }
+ return tempPrefs;
+ }
+ return delegate.getPreferencesFromUser(userID);
+ }
+
+ @Override
+ public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
+ if (userID == TEMP_USER_ID) {
+ if (tempPrefs == null) {
+ throw new NoSuchUserException(TEMP_USER_ID);
+ }
+ return prefItemIDs;
+ }
+ return delegate.getItemIDsFromUser(userID);
+ }
+
+ @Override
+ public LongPrimitiveIterator getItemIDs() throws TasteException {
+ return delegate.getItemIDs();
+ // Yeah ignoring items that only the plus-one user knows about... can't really happen
+ }
+
+ @Override
+ public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
+ if (tempPrefs == null) {
+ return delegate.getPreferencesForItem(itemID);
+ }
+ PreferenceArray delegatePrefs = null;
+ try {
+ delegatePrefs = delegate.getPreferencesForItem(itemID);
+ } catch (NoSuchItemException nsie) {
+ // OK. Probably an item that only the anonymous user has
+ if (log.isDebugEnabled()) {
+ log.debug("Item {} unknown", itemID);
+ }
+ }
+ for (int i = 0; i < tempPrefs.length(); i++) {
+ if (tempPrefs.getItemID(i) == itemID) {
+ return cloneAndMergeInto(delegatePrefs, itemID, tempPrefs.getUserID(i), tempPrefs.getValue(i));
+ }
+ }
+ if (delegatePrefs == null) {
+ // No, didn't find it among the anonymous user prefs
+ throw new NoSuchItemException(itemID);
+ }
+ return delegatePrefs;
+ }
+
+ private static PreferenceArray cloneAndMergeInto(PreferenceArray delegatePrefs,
+ long itemID,
+ long newUserID,
+ float value) {
+
+ int length = delegatePrefs == null ? 0 : delegatePrefs.length();
+ int newLength = length + 1;
+ PreferenceArray newPreferenceArray = new GenericItemPreferenceArray(newLength);
+
+ // Set item ID once
+ newPreferenceArray.setItemID(0, itemID);
+
+ int positionToInsert = 0;
+ while (positionToInsert < length && newUserID > delegatePrefs.getUserID(positionToInsert)) {
+ positionToInsert++;
+ }
+
+ for (int i = 0; i < positionToInsert; i++) {
+ newPreferenceArray.setUserID(i, delegatePrefs.getUserID(i));
+ newPreferenceArray.setValue(i, delegatePrefs.getValue(i));
+ }
+ newPreferenceArray.setUserID(positionToInsert, newUserID);
+ newPreferenceArray.setValue(positionToInsert, value);
+ for (int i = positionToInsert + 1; i < newLength; i++) {
+ newPreferenceArray.setUserID(i, delegatePrefs.getUserID(i - 1));
+ newPreferenceArray.setValue(i, delegatePrefs.getValue(i - 1));
+ }
+
+ return newPreferenceArray;
+ }
+
+ @Override
+ public Float getPreferenceValue(long userID, long itemID) throws TasteException {
+ if (userID == TEMP_USER_ID) {
+ if (tempPrefs == null) {
+ throw new NoSuchUserException(TEMP_USER_ID);
+ }
+ for (int i = 0; i < tempPrefs.length(); i++) {
+ if (tempPrefs.getItemID(i) == itemID) {
+ return tempPrefs.getValue(i);
+ }
+ }
+ return null;
+ }
+ return delegate.getPreferenceValue(userID, itemID);
+ }
+
+ @Override
+ public Long getPreferenceTime(long userID, long itemID) throws TasteException {
+ if (userID == TEMP_USER_ID) {
+ if (tempPrefs == null) {
+ throw new NoSuchUserException(TEMP_USER_ID);
+ }
+ return null;
+ }
+ return delegate.getPreferenceTime(userID, itemID);
+ }
+
+ @Override
+ public int getNumItems() throws TasteException {
+ return delegate.getNumItems();
+ }
+
+ @Override
+ public int getNumUsers() throws TasteException {
+ return delegate.getNumUsers() + (tempPrefs == null ? 0 : 1);
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
+ if (tempPrefs == null) {
+ return delegate.getNumUsersWithPreferenceFor(itemID);
+ }
+ boolean found = false;
+ for (int i = 0; i < tempPrefs.length(); i++) {
+ if (tempPrefs.getItemID(i) == itemID) {
+ found = true;
+ break;
+ }
+ }
+ return delegate.getNumUsersWithPreferenceFor(itemID) + (found ? 1 : 0);
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
+ if (tempPrefs == null) {
+ return delegate.getNumUsersWithPreferenceFor(itemID1, itemID2);
+ }
+ boolean found1 = false;
+ boolean found2 = false;
+ for (int i = 0; i < tempPrefs.length() && !(found1 && found2); i++) {
+ long itemID = tempPrefs.getItemID(i);
+ if (itemID == itemID1) {
+ found1 = true;
+ }
+ if (itemID == itemID2) {
+ found2 = true;
+ }
+ }
+ return delegate.getNumUsersWithPreferenceFor(itemID1, itemID2) + (found1 && found2 ? 1 : 0);
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) throws TasteException {
+ if (userID == TEMP_USER_ID) {
+ if (tempPrefs == null) {
+ throw new NoSuchUserException(TEMP_USER_ID);
+ }
+ throw new UnsupportedOperationException();
+ }
+ delegate.setPreference(userID, itemID, value);
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) throws TasteException {
+ if (userID == TEMP_USER_ID) {
+ if (tempPrefs == null) {
+ throw new NoSuchUserException(TEMP_USER_ID);
+ }
+ throw new UnsupportedOperationException();
+ }
+ delegate.removePreference(userID, itemID);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ delegate.refresh(alreadyRefreshed);
+ }
+
+ @Override
+ public boolean hasPreferenceValues() {
+ return delegate.hasPreferenceValues();
+ }
+
+ @Override
+ public float getMaxPreference() {
+ return delegate.getMaxPreference();
+ }
+
+ @Override
+ public float getMinPreference() {
+ return delegate.getMinPreference();
+ }
+
+}
[14/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java
new file mode 100644
index 0000000..d414416
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java
@@ -0,0 +1,170 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.stats.LogLikelihood;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Reducer for pass 2 of the collocation discovery job. Collects ngram and sub-ngram frequencies and performs
+ * the Log-likelihood ratio calculation.
+ */
+public class LLRReducer extends Reducer<Gram, Gram, Text, DoubleWritable> {
+
+ /** Counter to track why a particlar entry was skipped */
+ public enum Skipped {
+ EXTRA_HEAD, EXTRA_TAIL, MISSING_HEAD, MISSING_TAIL, LESS_THAN_MIN_LLR, LLR_CALCULATION_ERROR,
+ }
+
+ private static final Logger log = LoggerFactory.getLogger(LLRReducer.class);
+
+ public static final String NGRAM_TOTAL = "ngramTotal";
+ public static final String MIN_LLR = "minLLR";
+ public static final float DEFAULT_MIN_LLR = 1.0f;
+
+ private long ngramTotal;
+ private float minLLRValue;
+ private boolean emitUnigrams;
+ private final LLCallback ll;
+
+ /**
+ * Perform LLR calculation, input is: k:ngram:ngramFreq v:(h_|t_)subgram:subgramfreq N = ngram total
+ *
+ * Each ngram will have 2 subgrams, a head and a tail, referred to as A and B respectively below.
+ *
+ * A+ B: number of times a+b appear together: ngramFreq A+!B: number of times A appears without B:
+ * hSubgramFreq - ngramFreq !A+ B: number of times B appears without A: tSubgramFreq - ngramFreq !A+!B:
+ * number of times neither A or B appears (in that order): N - (subgramFreqA + subgramFreqB - ngramFreq)
+ */
+ @Override
+ protected void reduce(Gram ngram, Iterable<Gram> values, Context context) throws IOException, InterruptedException {
+
+ int[] gramFreq = {-1, -1};
+
+ if (ngram.getType() == Gram.Type.UNIGRAM && emitUnigrams) {
+ DoubleWritable dd = new DoubleWritable(ngram.getFrequency());
+ Text t = new Text(ngram.getString());
+ context.write(t, dd);
+ return;
+ }
+ // TODO better way to handle errors? Wouldn't an exception thrown here
+ // cause hadoop to re-try the job?
+ String[] gram = new String[2];
+ for (Gram value : values) {
+
+ int pos = value.getType() == Gram.Type.HEAD ? 0 : 1;
+
+ if (gramFreq[pos] != -1) {
+ log.warn("Extra {} for {}, skipping", value.getType(), ngram);
+ if (value.getType() == Gram.Type.HEAD) {
+ context.getCounter(Skipped.EXTRA_HEAD).increment(1);
+ } else {
+ context.getCounter(Skipped.EXTRA_TAIL).increment(1);
+ }
+ return;
+ }
+
+ gram[pos] = value.getString();
+ gramFreq[pos] = value.getFrequency();
+ }
+
+ if (gramFreq[0] == -1) {
+ log.warn("Missing head for {}, skipping.", ngram);
+ context.getCounter(Skipped.MISSING_HEAD).increment(1);
+ return;
+ }
+ if (gramFreq[1] == -1) {
+ log.warn("Missing tail for {}, skipping", ngram);
+ context.getCounter(Skipped.MISSING_TAIL).increment(1);
+ return;
+ }
+
+ long k11 = ngram.getFrequency(); /* a&b */
+ long k12 = gramFreq[0] - ngram.getFrequency(); /* a&!b */
+ long k21 = gramFreq[1] - ngram.getFrequency(); /* !b&a */
+ long k22 = ngramTotal - (gramFreq[0] + gramFreq[1] - ngram.getFrequency()); /* !a&!b */
+
+ double llr;
+ try {
+ llr = ll.logLikelihoodRatio(k11, k12, k21, k22);
+ } catch (IllegalArgumentException ex) {
+ context.getCounter(Skipped.LLR_CALCULATION_ERROR).increment(1);
+ log.warn("Problem calculating LLR ratio for ngram {}, HEAD {}:{}, TAIL {}:{}, k11/k12/k21/k22: {}/{}/{}/{}",
+ ngram, gram[0], gramFreq[0], gram[1], gramFreq[1], k11, k12, k21, k22, ex);
+ return;
+ }
+ if (llr < minLLRValue) {
+ context.getCounter(Skipped.LESS_THAN_MIN_LLR).increment(1);
+ } else {
+ context.write(new Text(ngram.getString()), new DoubleWritable(llr));
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ this.ngramTotal = conf.getLong(NGRAM_TOTAL, -1);
+ this.minLLRValue = conf.getFloat(MIN_LLR, DEFAULT_MIN_LLR);
+
+ this.emitUnigrams = conf.getBoolean(CollocDriver.EMIT_UNIGRAMS, CollocDriver.DEFAULT_EMIT_UNIGRAMS);
+
+ log.info("NGram Total: {}, Min LLR value: {}, Emit Unigrams: {}",
+ ngramTotal, minLLRValue, emitUnigrams);
+
+ if (ngramTotal == -1) {
+ throw new IllegalStateException("No NGRAM_TOTAL available in job config");
+ }
+ }
+
+ public LLRReducer() {
+ this.ll = new ConcreteLLCallback();
+ }
+
+ /**
+ * plug in an alternate LL implementation, used for testing
+ *
+ * @param ll
+ * the LL to use.
+ */
+ LLRReducer(LLCallback ll) {
+ this.ll = ll;
+ }
+
+ /**
+ * provide interface so the input to the llr calculation can be captured for validation in unit testing
+ */
+ public interface LLCallback {
+ double logLikelihoodRatio(long k11, long k12, long k21, long k22);
+ }
+
+ /** concrete implementation delegates to LogLikelihood class */
+ public static final class ConcreteLLCallback implements LLCallback {
+ @Override
+ public double logLikelihoodRatio(long k11, long k12, long k21, long k22) {
+ return LogLikelihood.logLikelihoodRatio(k11, k12, k21, k22);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java
new file mode 100644
index 0000000..a8eacc3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.common;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Merges partial vectors in to a full sparse vector
+ */
+public class PartialVectorMergeReducer extends
+ Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ private double normPower;
+
+ private int dimension;
+
+ private boolean sequentialAccess;
+
+ private boolean namedVector;
+
+ private boolean logNormalize;
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context) throws IOException,
+ InterruptedException {
+
+ Vector vector = new RandomAccessSparseVector(dimension, 10);
+ for (VectorWritable value : values) {
+ vector.assign(value.get(), Functions.PLUS);
+ }
+ if (normPower != PartialVectorMerger.NO_NORMALIZING) {
+ if (logNormalize) {
+ vector = vector.logNormalize(normPower);
+ } else {
+ vector = vector.normalize(normPower);
+ }
+ }
+ if (sequentialAccess) {
+ vector = new SequentialAccessSparseVector(vector);
+ }
+
+ if (namedVector) {
+ vector = new NamedVector(vector, key.toString());
+ }
+
+ // drop empty vectors.
+ if (vector.getNumNondefaultElements() > 0) {
+ VectorWritable vectorWritable = new VectorWritable(vector);
+ context.write(key, vectorWritable);
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ normPower = conf.getFloat(PartialVectorMerger.NORMALIZATION_POWER, PartialVectorMerger.NO_NORMALIZING);
+ dimension = conf.getInt(PartialVectorMerger.DIMENSION, Integer.MAX_VALUE);
+ sequentialAccess = conf.getBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, false);
+ namedVector = conf.getBoolean(PartialVectorMerger.NAMED_VECTOR, false);
+ logNormalize = conf.getBoolean(PartialVectorMerger.LOG_NORMALIZE, false);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMerger.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMerger.java b/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMerger.java
new file mode 100644
index 0000000..287a813
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMerger.java
@@ -0,0 +1,144 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.common;
+
+import java.io.IOException;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * This class groups a set of input vectors. The Sequence file input should have a
+ * {@link org.apache.hadoop.io.WritableComparable}
+ * key containing document id and a {@link VectorWritable} value containing the term frequency vector. This
+ * class also does normalization of the vector.
+ *
+ */
+public final class PartialVectorMerger {
+
+ public static final float NO_NORMALIZING = -1.0f;
+
+ public static final String NORMALIZATION_POWER = "normalization.power";
+
+ public static final String DIMENSION = "vector.dimension";
+
+ public static final String SEQUENTIAL_ACCESS = "vector.sequentialAccess";
+
+ public static final String NAMED_VECTOR = "vector.named";
+
+ public static final String LOG_NORMALIZE = "vector.lognormalize";
+
+ /**
+ * Cannot be initialized. Use the static functions
+ */
+ private PartialVectorMerger() {
+
+ }
+
+ /**
+ * Merge all the partial {@link org.apache.mahout.math.RandomAccessSparseVector}s into the complete Document
+ * {@link org.apache.mahout.math.RandomAccessSparseVector}
+ *
+ * @param partialVectorPaths
+ * input directory of the vectors in {@link org.apache.hadoop.io.SequenceFile} format
+ * @param output
+ * output directory were the partial vectors have to be created
+ * @param baseConf
+ * job configuration
+ * @param normPower
+ * The normalization value. Must be greater than or equal to 0 or equal to {@link #NO_NORMALIZING}
+ * @param dimension cardinality of the vectors
+ * @param sequentialAccess
+ * output vectors should be optimized for sequential access
+ * @param namedVector
+ * output vectors should be named, retaining key (doc id) as a label
+ * @param numReducers
+ * The number of reducers to spawn
+ */
+ public static void mergePartialVectors(Iterable<Path> partialVectorPaths,
+ Path output,
+ Configuration baseConf,
+ float normPower,
+ boolean logNormalize,
+ int dimension,
+ boolean sequentialAccess,
+ boolean namedVector,
+ int numReducers)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ Preconditions.checkArgument(normPower == NO_NORMALIZING || normPower >= 0,
+ "If specified normPower must be nonnegative", normPower);
+ Preconditions.checkArgument(normPower == NO_NORMALIZING
+ || (normPower > 1 && !Double.isInfinite(normPower))
+ || !logNormalize,
+ "normPower must be > 1 and not infinite if log normalization is chosen", normPower);
+
+ Configuration conf = new Configuration(baseConf);
+ // this conf parameter needs to be set enable serialisation of conf values
+ conf.set("io.serializations", "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ conf.setBoolean(SEQUENTIAL_ACCESS, sequentialAccess);
+ conf.setBoolean(NAMED_VECTOR, namedVector);
+ conf.setInt(DIMENSION, dimension);
+ conf.setFloat(NORMALIZATION_POWER, normPower);
+ conf.setBoolean(LOG_NORMALIZE, logNormalize);
+
+ Job job = new Job(conf);
+ job.setJobName("PartialVectorMerger::MergePartialVectors");
+ job.setJarByClass(PartialVectorMerger.class);
+
+ job.setOutputKeyClass(Text.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ FileInputFormat.setInputPaths(job, getCommaSeparatedPaths(partialVectorPaths));
+
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setMapperClass(Mapper.class);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setReducerClass(PartialVectorMergeReducer.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ job.setNumReduceTasks(numReducers);
+
+ HadoopUtil.delete(conf, output);
+
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+ }
+
+ private static String getCommaSeparatedPaths(Iterable<Path> paths) {
+ StringBuilder commaSeparatedPaths = new StringBuilder(100);
+ String sep = "";
+ for (Path path : paths) {
+ commaSeparatedPaths.append(sep).append(path.toString());
+ sep = ",";
+ }
+ return commaSeparatedPaths.toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/document/SequenceFileTokenizerMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/document/SequenceFileTokenizerMapper.java b/mr/src/main/java/org/apache/mahout/vectorizer/document/SequenceFileTokenizerMapper.java
new file mode 100644
index 0000000..690e0e5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/document/SequenceFileTokenizerMapper.java
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.document;
+
+import java.io.IOException;
+import java.io.StringReader;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.lucene.AnalyzerUtils;
+import org.apache.mahout.vectorizer.DocumentProcessor;
+
+/**
+ * Tokenizes a text document and outputs tokens in a StringTuple
+ */
+public class SequenceFileTokenizerMapper extends Mapper<Text, Text, Text, StringTuple> {
+
+ private Analyzer analyzer;
+
+ @Override
+ protected void map(Text key, Text value, Context context) throws IOException, InterruptedException {
+ TokenStream stream = analyzer.tokenStream(key.toString(), new StringReader(value.toString()));
+ CharTermAttribute termAtt = stream.addAttribute(CharTermAttribute.class);
+ stream.reset();
+ StringTuple document = new StringTuple();
+ while (stream.incrementToken()) {
+ if (termAtt.length() > 0) {
+ document.add(new String(termAtt.buffer(), 0, termAtt.length()));
+ }
+ }
+ stream.end();
+ Closeables.close(stream, true);
+ context.write(key, document);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+
+ String analyzerClassName = context.getConfiguration().get(DocumentProcessor.ANALYZER_CLASS,
+ StandardAnalyzer.class.getName());
+ try {
+ analyzer = AnalyzerUtils.createAnalyzer(analyzerClassName);
+ } catch (ClassNotFoundException e) {
+ throw new IOException("Unable to create analyzer: " + analyzerClassName, e);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java
new file mode 100644
index 0000000..04b718e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/AdaptiveWordValueEncoder.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Encodes words into vectors much as does WordValueEncoder while maintaining
+ * an adaptive dictionary of values seen so far. This allows weighting of terms
+ * without a pre-scan of all of the data.
+ */
+public class AdaptiveWordValueEncoder extends WordValueEncoder {
+
+ private final Multiset<String> dictionary;
+
+ public AdaptiveWordValueEncoder(String name) {
+ super(name);
+ dictionary = HashMultiset.create();
+ }
+
+ /**
+ * Adds a value to a vector.
+ *
+ * @param originalForm The original form of the value as a string.
+ * @param data The vector to which the value should be added.
+ */
+ @Override
+ public void addToVector(String originalForm, double weight, Vector data) {
+ dictionary.add(originalForm);
+ super.addToVector(originalForm, weight, data);
+ }
+
+ @Override
+ protected double getWeight(byte[] originalForm, double w) {
+ return w * weight(originalForm);
+ }
+
+ @Override
+ protected double weight(byte[] originalForm) {
+ // the counts here are adjusted so that every observed value has an extra 0.5 count
+ // as does a hypothetical unobserved value. This smooths our estimates a bit and
+ // allows the first word seen to have a non-zero weight of -log(1.5 / 2)
+ double thisWord = dictionary.count(new String(originalForm, Charsets.UTF_8)) + 0.5;
+ double allWords = dictionary.size() + dictionary.elementSet().size() * 0.5 + 0.5;
+ return -Math.log(thisWord / allWords);
+ }
+
+ public Multiset<String> getDictionary() {
+ return dictionary;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java
new file mode 100644
index 0000000..0b350c6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingContinuousValueEncoder.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import java.util.Arrays;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+
+public class CachingContinuousValueEncoder extends ContinuousValueEncoder {
+ private final int dataSize;
+ private OpenIntIntHashMap[] caches;
+
+ public CachingContinuousValueEncoder(String name, int dataSize) {
+ super(name);
+ this.dataSize = dataSize;
+ initCaches();
+ }
+
+ private void initCaches() {
+ this.caches = new OpenIntIntHashMap[getProbes()];
+ for (int probe = 0; probe < getProbes(); probe++) {
+ caches[probe] = new OpenIntIntHashMap();
+ }
+ }
+
+ OpenIntIntHashMap[] getCaches() {
+ return caches;
+ }
+
+ @Override
+ public void setProbes(int probes) {
+ super.setProbes(probes);
+ initCaches();
+ }
+
+ @Override
+ protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ Preconditions.checkArgument(dataSize == this.dataSize,
+ "dataSize argument [" + dataSize + "] does not match expected dataSize [" + this.dataSize + ']');
+ int originalHashcode = Arrays.hashCode(originalForm);
+ if (caches[probe].containsKey(originalHashcode)) {
+ return caches[probe].get(originalHashcode);
+ }
+ int hash = super.hashForProbe(originalForm, dataSize, name, probe);
+ caches[probe].put(originalHashcode, hash);
+ return hash;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java
new file mode 100644
index 0000000..258ff84
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingStaticWordValueEncoder.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import java.util.Arrays;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+
+public class CachingStaticWordValueEncoder extends StaticWordValueEncoder {
+
+ private final int dataSize;
+ private OpenIntIntHashMap[] caches;
+
+ public CachingStaticWordValueEncoder(String name, int dataSize) {
+ super(name);
+ this.dataSize = dataSize;
+ initCaches();
+ }
+
+ private void initCaches() {
+ caches = new OpenIntIntHashMap[getProbes()];
+ for (int probe = 0; probe < getProbes(); probe++) {
+ caches[probe] = new OpenIntIntHashMap();
+ }
+ }
+
+ OpenIntIntHashMap[] getCaches() {
+ return caches;
+ }
+
+ @Override
+ public void setProbes(int probes) {
+ super.setProbes(probes);
+ initCaches();
+ }
+
+ @Override
+ protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ Preconditions.checkArgument(dataSize == this.dataSize,
+ "dataSize argument [" + dataSize + "] does not match expected dataSize [" + this.dataSize + ']');
+ int originalHashcode = Arrays.hashCode(originalForm);
+ if (caches[probe].containsKey(originalHashcode)) {
+ return caches[probe].get(originalHashcode);
+ }
+ int hash = super.hashForProbe(originalForm, dataSize, name, probe);
+ caches[probe].put(originalHashcode, hash);
+ return hash;
+ }
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java
new file mode 100644
index 0000000..b109818
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingTextValueEncoder.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+public class CachingTextValueEncoder extends TextValueEncoder {
+ public CachingTextValueEncoder(String name, int dataSize) {
+ super(name);
+ setWordEncoder(new CachingStaticWordValueEncoder(name, dataSize));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java
new file mode 100644
index 0000000..08d3d3e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/CachingValueEncoder.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import org.apache.mahout.math.MurmurHash;
+
+/**
+ * Provides basic hashing semantics for encoders where the probe locations
+ * depend only on the name of the variable.
+ */
+public abstract class CachingValueEncoder extends FeatureVectorEncoder {
+ private int[] cachedProbes;
+
+ protected CachingValueEncoder(String name, int seed) {
+ super(name);
+ cacheProbeLocations(seed);
+ }
+
+ /**
+ * Sets the number of locations in the feature vector that a value should be in.
+ * This causes the cached probe locations to be recomputed.
+ *
+ * @param probes Number of locations to increment.
+ */
+ @Override
+ public void setProbes(int probes) {
+ super.setProbes(probes);
+ cacheProbeLocations(getSeed());
+ }
+
+ protected abstract int getSeed();
+
+ private void cacheProbeLocations(int seed) {
+ cachedProbes = new int[getProbes()];
+ for (int i = 0; i < getProbes(); i++) {
+ // note that the modulo operation is deferred
+ cachedProbes[i] = (int) MurmurHash.hash64A(bytesForString(getName()), seed + i);
+ }
+ }
+
+ @Override
+ protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ int h = cachedProbes[probe] % dataSize;
+ if (h < 0) {
+ h += dataSize;
+ }
+ return h;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java
new file mode 100644
index 0000000..d7dd9f6
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ConstantValueEncoder.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * An encoder that does the standard thing for a virtual bias term.
+ */
+public class ConstantValueEncoder extends CachingValueEncoder {
+ public ConstantValueEncoder(String name) {
+ super(name, 0);
+ }
+
+ @Override
+ public void addToVector(byte[] originalForm, double weight, Vector data) {
+ int probes = getProbes();
+ String name = getName();
+ for (int i = 0; i < probes; i++) {
+ int n = hashForProbe(originalForm, data.size(), name, i);
+ if (isTraceEnabled()) {
+ trace((String) null, n);
+ }
+ data.set(n, data.get(n) + getWeight(originalForm,weight));
+ }
+ }
+
+ @Override
+ protected double getWeight(byte[] originalForm, double w) {
+ return w;
+ }
+
+ @Override
+ public String asString(String originalForm) {
+ return getName();
+ }
+
+ @Override
+ protected int getSeed() {
+ return 0;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java
new file mode 100644
index 0000000..14382a5
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/ContinuousValueEncoder.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import com.google.common.base.Charsets;
+import org.apache.mahout.math.Vector;
+
+/**
+ * Continuous values are stored in fixed randomized location in the feature vector.
+ */
+public class ContinuousValueEncoder extends CachingValueEncoder {
+
+ public ContinuousValueEncoder(String name) {
+ super(name, CONTINUOUS_VALUE_HASH_SEED);
+ }
+
+ /**
+ * Adds a value to a vector.
+ *
+ * @param originalForm The original form of the value as a string.
+ * @param data The vector to which the value should be added.
+ */
+ @Override
+ public void addToVector(byte[] originalForm, double weight, Vector data) {
+ int probes = getProbes();
+ String name = getName();
+ for (int i = 0; i < probes; i++) {
+ int n = hashForProbe(originalForm, data.size(), name, i);
+ if (isTraceEnabled()) {
+ trace((String) null, n);
+ }
+ data.set(n, data.get(n) + getWeight(originalForm,weight));
+ }
+ }
+
+ @Override
+ protected double getWeight(byte[] originalForm, double w) {
+ if (originalForm == null) {
+ return w;
+ }
+ return w * Double.parseDouble(new String(originalForm, Charsets.UTF_8));
+ }
+
+ /**
+ * Converts a value into a form that would help a human understand the internals of how the value
+ * is being interpreted. For text-like things, this is likely to be a list of the terms found with
+ * associated weights (if any).
+ *
+ * @param originalForm The original form of the value as a string.
+ * @return A string that a human can read.
+ */
+ @Override
+ public String asString(String originalForm) {
+ return getName() + ':' + originalForm;
+ }
+
+ @Override
+ protected int getSeed() {
+ return CONTINUOUS_VALUE_HASH_SEED;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java
new file mode 100644
index 0000000..2ea9b1b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/Dictionary.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import java.util.List;
+import java.util.Map;
+
+/**
+* Assigns integer codes to strings as they appear.
+*/
+public class Dictionary {
+ private final Map<String, Integer> dict = Maps.newLinkedHashMap();
+
+ public int intern(String s) {
+ if (!dict.containsKey(s)) {
+ dict.put(s, dict.size());
+ }
+ return dict.get(s);
+ }
+
+ public List<String> values() {
+ // order of keySet is guaranteed to be insertion order
+ return Lists.newArrayList(dict.keySet());
+ }
+
+ public int size() {
+ return dict.size();
+ }
+
+ public static Dictionary fromList(Iterable<String> values) {
+ Dictionary dict = new Dictionary();
+ for (String value : values) {
+ dict.intern(value);
+ }
+ return dict;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java
new file mode 100644
index 0000000..96498d7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/FeatureVectorEncoder.java
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import com.google.common.base.Charsets;
+import com.google.common.collect.Sets;
+import org.apache.mahout.math.MurmurHash;
+import org.apache.mahout.math.Vector;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * General interface for objects that record features into a feature vector.
+ * <p/>
+ * By convention, sub-classes should provide a constructor that accepts just a field name as well as
+ * setters to customize properties of the conversion such as adding tokenizers or a weight
+ * dictionary.
+ */
+public abstract class FeatureVectorEncoder {
+ protected static final int CONTINUOUS_VALUE_HASH_SEED = 1;
+ protected static final int WORD_LIKE_VALUE_HASH_SEED = 100;
+
+ private static final byte[] EMPTY_ARRAY = new byte[0];
+
+ private final String name;
+ private int probes;
+
+ private Map<String, Set<Integer>> traceDictionary;
+
+ protected FeatureVectorEncoder(String name) {
+ this(name, 1);
+ }
+
+ protected FeatureVectorEncoder(String name, int probes) {
+ this.name = name;
+ this.probes = probes;
+ }
+
+ /**
+ * Adds a value expressed in string form to a vector.
+ *
+ * @param originalForm The original form of the value as a string.
+ * @param data The vector to which the value should be added.
+ */
+ public void addToVector(String originalForm, Vector data) {
+ addToVector(originalForm, 1.0, data);
+ }
+
+ /**
+ * Adds a value expressed in byte array form to a vector.
+ *
+ * @param originalForm The original form of the value as a byte array.
+ * @param data The vector to which the value should be added.
+ */
+ public void addToVector(byte[] originalForm, Vector data) {
+ addToVector(originalForm, 1.0, data);
+ }
+
+ /**
+ * Adds a weighted value expressed in string form to a vector. In some cases it is convenient to
+ * use this method to encode continuous values using the weight as the value. In such cases, the
+ * string value should typically be set to null.
+ *
+ * @param originalForm The original form of the value as a string.
+ * @param weight The weight to be applied to this feature.
+ * @param data The vector to which the value should be added.
+ */
+ public void addToVector(String originalForm, double weight, Vector data) {
+ addToVector(bytesForString(originalForm), weight, data);
+ }
+
+ public abstract void addToVector(byte[] originalForm, double weight, Vector data);
+
+ /**
+ * Provides the unique hash for a particular probe. For all encoders except text, this
+ * is all that is needed and the default implementation of hashesForProbe will do the right
+ * thing. For text and similar values, hashesForProbe should be over-ridden and this method
+ * should not be used.
+ *
+ * @param originalForm The original byte array value
+ * @param dataSize The length of the vector being encoded
+ * @param name The name of the variable being encoded
+ * @param probe The probe number
+ * @return The hash of the current probe
+ */
+ protected abstract int hashForProbe(byte[] originalForm, int dataSize, String name, int probe);
+
+ /**
+ * Returns all of the hashes for this probe. For most encoders, this is a singleton, but
+ * for text, many hashes are returned, one for each word (unique or not). Most implementations
+ * should only implement hashForProbe for simplicity.
+ *
+ * @param originalForm The original byte array value.
+ * @param dataSize The length of the vector being encoded
+ * @param name The name of the variable being encoded
+ * @param probe The probe number
+ * @return an Iterable of the hashes
+ */
+ protected Iterable<Integer> hashesForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ return Collections.singletonList(hashForProbe(originalForm, dataSize, name, probe));
+ }
+
+ protected double getWeight(byte[] originalForm, double w) {
+ return 1.0;
+ }
+
+ // ******* Utility functions used by most implementations
+
+ /**
+ * Hash a string and an integer into the range [0..numFeatures-1].
+ *
+ * @param term The string.
+ * @param probe An integer that modifies the resulting hash.
+ * @param numFeatures The range into which the resulting hash must fit.
+ * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
+ * term and probe.
+ */
+ protected int hash(String term, int probe, int numFeatures) {
+ long r = MurmurHash.hash64A(bytesForString(term), probe) % numFeatures;
+ if (r < 0) {
+ r += numFeatures;
+ }
+ return (int) r;
+ }
+
+ /**
+ * Hash a byte array and an integer into the range [0..numFeatures-1].
+ *
+ * @param term The bytes.
+ * @param probe An integer that modifies the resulting hash.
+ * @param numFeatures The range into which the resulting hash must fit.
+ * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
+ * term and probe.
+ */
+ protected static int hash(byte[] term, int probe, int numFeatures) {
+ long r = MurmurHash.hash64A(term, probe) % numFeatures;
+ if (r < 0) {
+ r += numFeatures;
+ }
+ return (int) r;
+ }
+
+ /**
+ * Hash two strings and an integer into the range [0..numFeatures-1].
+ *
+ * @param term1 The first string.
+ * @param term2 The second string.
+ * @param probe An integer that modifies the resulting hash.
+ * @param numFeatures The range into which the resulting hash must fit.
+ * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
+ * term and probe.
+ */
+ protected static int hash(String term1, String term2, int probe, int numFeatures) {
+ long r = MurmurHash.hash64A(bytesForString(term1), probe);
+ r = MurmurHash.hash64A(bytesForString(term2), (int) r) % numFeatures;
+ if (r < 0) {
+ r += numFeatures;
+ }
+ return (int) r;
+ }
+
+ /**
+ * Hash two byte arrays and an integer into the range [0..numFeatures-1].
+ *
+ * @param term1 The first string.
+ * @param term2 The second string.
+ * @param probe An integer that modifies the resulting hash.
+ * @param numFeatures The range into which the resulting hash must fit.
+ * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
+ * term and probe.
+ */
+ protected int hash(byte[] term1, byte[] term2, int probe, int numFeatures) {
+ long r = MurmurHash.hash64A(term1, probe);
+ r = MurmurHash.hash64A(term2, (int) r) % numFeatures;
+ if (r < 0) {
+ r += numFeatures;
+ }
+ return (int) r;
+ }
+
+ /**
+ * Hash four strings and an integer into the range [0..numFeatures-1].
+ *
+ * @param term1 The first string.
+ * @param term2 The second string.
+ * @param term3 The third string
+ * @param term4 And the fourth.
+ * @param probe An integer that modifies the resulting hash.
+ * @param numFeatures The range into which the resulting hash must fit.
+ * @return An integer in the range [0..numFeatures-1] that has good spread for small changes in
+ * term and probe.
+ */
+ protected int hash(String term1, String term2, String term3, String term4, int probe, int numFeatures) {
+ long r = MurmurHash.hash64A(bytesForString(term1), probe);
+ r = MurmurHash.hash64A(bytesForString(term2), (int) r) % numFeatures;
+ r = MurmurHash.hash64A(bytesForString(term3), (int) r) % numFeatures;
+ r = MurmurHash.hash64A(bytesForString(term4), (int) r) % numFeatures;
+ if (r < 0) {
+ r += numFeatures;
+ }
+ return (int) r;
+ }
+
+ /**
+ * Converts a value into a form that would help a human understand the internals of how the value
+ * is being interpreted. For text-like things, this is likely to be a list of the terms found
+ * with associated weights (if any).
+ *
+ * @param originalForm The original form of the value as a string.
+ * @return A string that a human can read.
+ */
+ public abstract String asString(String originalForm);
+
+ public int getProbes() {
+ return probes;
+ }
+
+ /**
+ * Sets the number of locations in the feature vector that a value should be in.
+ *
+ * @param probes Number of locations to increment.
+ */
+ public void setProbes(int probes) {
+ this.probes = probes;
+ }
+
+ public String getName() {
+ return name;
+ }
+
+ protected boolean isTraceEnabled() {
+ return traceDictionary != null;
+ }
+
+ protected void trace(String subName, int n) {
+ if (traceDictionary != null) {
+ String key = name;
+ if (subName != null) {
+ key = name + '=' + subName;
+ }
+ Set<Integer> trace = traceDictionary.get(key);
+ if (trace == null) {
+ trace = Sets.newHashSet(n);
+ traceDictionary.put(key, trace);
+ } else {
+ trace.add(n);
+ }
+ }
+ }
+
+ protected void trace(byte[] subName, int n) {
+ trace(new String(subName, Charsets.UTF_8), n);
+ }
+
+ public void setTraceDictionary(Map<String, Set<Integer>> traceDictionary) {
+ this.traceDictionary = traceDictionary;
+ }
+
+ protected static byte[] bytesForString(String x) {
+ return x == null ? EMPTY_ARRAY : x.getBytes(Charsets.UTF_8);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/InteractionValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/InteractionValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/InteractionValueEncoder.java
new file mode 100644
index 0000000..0be8823
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/InteractionValueEncoder.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import java.util.Locale;
+
+import org.apache.mahout.math.Vector;
+
+import com.google.common.base.Charsets;
+
+public class InteractionValueEncoder extends FeatureVectorEncoder {
+ private final FeatureVectorEncoder firstEncoder;
+ private final FeatureVectorEncoder secondEncoder;
+
+ public InteractionValueEncoder(String name, FeatureVectorEncoder encoderOne, FeatureVectorEncoder encoderTwo) {
+ super(name, 2);
+ firstEncoder = encoderOne;
+ secondEncoder = encoderTwo;
+ }
+
+ /**
+ * Adds a value to a vector.
+ *
+ * @param originalForm The original form of the first value as a string.
+ * @param data The vector to which the value should be added.
+ */
+ @Override
+ public void addToVector(String originalForm, double w, Vector data) {
+ throw new UnsupportedOperationException("addToVector is not supported for InteractionVectorEncoder");
+ }
+
+ /**
+ * Adds a value to a vector. (Unsupported)
+ *
+ * @param originalForm The original form of the first value as a byte array.
+ * @param data The vector to which the value should be added.
+ */
+ @Override
+ public void addToVector(byte[] originalForm, double w, Vector data) {
+ throw new UnsupportedOperationException("addToVector is not supported for InteractionVectorEncoder");
+ }
+
+ /**
+ * Adds a value to a vector.
+ *
+ * @param original1 The original form of the first value as a string.
+ * @param original2 The original form of the second value as a string.
+ * @param weight How much to weight this interaction
+ * @param data The vector to which the value should be added.
+ */
+ public void addInteractionToVector(String original1, String original2, double weight, Vector data) {
+ byte[] originalForm1 = bytesForString(original1);
+ byte[] originalForm2 = bytesForString(original2);
+ addInteractionToVector(originalForm1, originalForm2, weight, data);
+ }
+
+ /**
+ * Adds a value to a vector.
+ *
+ * @param originalForm1 The original form of the first value as a byte array.
+ * @param originalForm2 The original form of the second value as a byte array.
+ * @param weight How much to weight this interaction
+ * @param data The vector to which the value should be added.
+ */
+ public void addInteractionToVector(byte[] originalForm1, byte[] originalForm2, double weight, Vector data) {
+ String name = getName();
+ double w = getWeight(originalForm1, originalForm2, weight);
+ for (int i = 0; i < probes(); i++) {
+ Iterable<Integer> jValues =
+ secondEncoder.hashesForProbe(originalForm2, data.size(), name, i % secondEncoder.getProbes());
+ for (Integer k : firstEncoder.hashesForProbe(originalForm1, data.size(), name, i % firstEncoder.getProbes())) {
+ for (Integer j : jValues) {
+ int n = (k + j) % data.size();
+ if (isTraceEnabled()) {
+ trace(String.format("%s:%s", new String(originalForm1, Charsets.UTF_8), new String(originalForm2,
+ Charsets.UTF_8)), n);
+ }
+ data.set(n, data.get(n) + w);
+ }
+ }
+ }
+ }
+
+ private int probes() {
+ return getProbes();
+ }
+
+ protected double getWeight(byte[] originalForm1, byte[] originalForm2, double w) {
+ return firstEncoder.getWeight(originalForm1, 1.0) * secondEncoder.getWeight(originalForm2, 1.0) * w;
+ }
+
+ /**
+ * Converts a value into a form that would help a human understand the internals of how the value
+ * is being interpreted. For text-like things, this is likely to be a list of the terms found with
+ * associated weights (if any).
+ *
+ * @param originalForm The original form of the value as a string.
+ * @return A string that a human can read.
+ */
+ @Override
+ public String asString(String originalForm) {
+ return String.format(Locale.ENGLISH, "%s:%s", getName(), originalForm);
+ }
+
+ @Override
+ protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ return hash(name, probe, dataSize);
+ }
+}
+
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/LuceneTextValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/LuceneTextValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/LuceneTextValueEncoder.java
new file mode 100644
index 0000000..3bae26e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/LuceneTextValueEncoder.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.mahout.common.lucene.TokenStreamIterator;
+
+import java.io.IOException;
+import java.io.Reader;
+import java.nio.CharBuffer;
+import java.util.Iterator;
+
+/**
+ * Encodes text using a lucene style tokenizer.
+ *
+ * @see TextValueEncoder
+ */
+public class LuceneTextValueEncoder extends TextValueEncoder {
+ private Analyzer analyzer;
+
+ public LuceneTextValueEncoder(String name) {
+ super(name);
+ }
+
+ public void setAnalyzer(Analyzer analyzer) {
+ this.analyzer = analyzer;
+ }
+
+ /**
+ * Tokenizes a string using the simplest method. This should be over-ridden for more subtle
+ * tokenization.
+ */
+ @Override
+ protected Iterable<String> tokenize(CharSequence originalForm) {
+ try {
+ TokenStream ts = analyzer.tokenStream(getName(), new CharSequenceReader(originalForm));
+ ts.addAttribute(CharTermAttribute.class);
+ return new LuceneTokenIterable(ts, false);
+ } catch (IOException ex) {
+ throw new IllegalStateException(ex);
+ }
+ }
+
+ private static final class CharSequenceReader extends Reader {
+ private final CharBuffer buf;
+
+ /**
+ * Creates a new character-stream reader whose critical sections will synchronize on the reader
+ * itself.
+ */
+ private CharSequenceReader(CharSequence input) {
+ int n = input.length();
+ buf = CharBuffer.allocate(n);
+ for (int i = 0; i < n; i++) {
+ buf.put(input.charAt(i));
+ }
+ buf.rewind();
+ }
+
+ /**
+ * Reads characters into a portion of an array. This method will block until some input is
+ * available, an I/O error occurs, or the end of the stream is reached.
+ *
+ * @param cbuf Destination buffer
+ * @param off Offset at which to start storing characters
+ * @param len Maximum number of characters to read
+ * @return The number of characters read, or -1 if the end of the stream has been reached
+ */
+ @Override
+ public int read(char[] cbuf, int off, int len) {
+ int toRead = Math.min(len, buf.remaining());
+ if (toRead > 0) {
+ buf.get(cbuf, off, toRead);
+ return toRead;
+ } else {
+ return -1;
+ }
+ }
+
+ @Override
+ public void close() {
+ // do nothing
+ }
+ }
+
+ private static final class LuceneTokenIterable implements Iterable<String> {
+ private boolean firstTime = true;
+ private final TokenStream tokenStream;
+
+ private LuceneTokenIterable(TokenStream ts, boolean firstTime) {
+ this.tokenStream = ts;
+ this.firstTime = firstTime;
+ }
+
+ /**
+ * Returns an iterator over a set of elements of type T.
+ *
+ * @return an Iterator.
+ */
+ @Override
+ public Iterator<String> iterator() {
+ if (firstTime) {
+ firstTime = false;
+ } else {
+ try {
+ tokenStream.reset();
+ } catch (IOException e) {
+ throw new IllegalStateException("This token stream can't be reset");
+ }
+ }
+
+ return new TokenStreamIterator(tokenStream);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/StaticWordValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/StaticWordValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/StaticWordValueEncoder.java
new file mode 100644
index 0000000..6f67ef4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/StaticWordValueEncoder.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import com.google.common.base.Charsets;
+
+import java.util.Collections;
+import java.util.Map;
+
+/**
+ * Encodes a categorical values with an unbounded vocabulary. Values are encoding by incrementing a
+ * few locations in the output vector with a weight that is either defaulted to 1 or that is looked
+ * up in a weight dictionary. By default, only one probe is used which should be fine but could
+ * cause a decrease in the speed of learning because more features will be non-zero. If a large
+ * feature vector is used so that the probability of feature collisions is suitably small, then this
+ * can be decreased to 1. If a very small feature vector is used, the number of probes should
+ * probably be increased to 3.
+ */
+public class StaticWordValueEncoder extends WordValueEncoder {
+ private Map<String, Double> dictionary;
+ private double missingValueWeight = 1;
+ private final byte[] nameBytes;
+
+ public StaticWordValueEncoder(String name) {
+ super(name);
+ nameBytes = bytesForString(name);
+ }
+
+ @Override
+ protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ return hash(nameBytes, originalForm, WORD_LIKE_VALUE_HASH_SEED + probe, dataSize);
+ }
+
+ /**
+ * Sets the weighting dictionary to be used by this encoder. Also sets the missing value weight
+ * to be half the smallest weight in the dictionary.
+ *
+ * @param dictionary The dictionary to use to look up weights.
+ */
+ public void setDictionary(Map<String, Double> dictionary) {
+ this.dictionary = dictionary;
+ setMissingValueWeight(Collections.min(dictionary.values()) / 2);
+ }
+
+ /**
+ * Sets the weight that is to be used for values that do not appear in the dictionary.
+ *
+ * @param missingValueWeight The default weight for missing values.
+ */
+ public void setMissingValueWeight(double missingValueWeight) {
+ this.missingValueWeight = missingValueWeight;
+ }
+
+ @Override
+ protected double weight(byte[] originalForm) {
+ double weight = missingValueWeight;
+ if (dictionary != null) {
+ String s = new String(originalForm, Charsets.UTF_8);
+ if (dictionary.containsKey(s)) {
+ weight = dictionary.get(s);
+ }
+ }
+ return weight;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/TextValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/TextValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/TextValueEncoder.java
new file mode 100644
index 0000000..87de095
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/TextValueEncoder.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.math.Vector;
+
+import java.util.Collection;
+import java.util.regex.Pattern;
+
+/**
+ * Encodes text that is tokenized on non-alphanum separators. Each word is encoded using a
+ * settable encoder which is by default an StaticWordValueEncoder which gives all
+ * words the same weight.
+ * @see LuceneTextValueEncoder
+ */
+public class TextValueEncoder extends FeatureVectorEncoder {
+
+ private static final double LOG_2 = Math.log(2.0);
+
+ private static final Splitter ON_NON_WORD = Splitter.on(Pattern.compile("\\W+")).omitEmptyStrings();
+
+ private FeatureVectorEncoder wordEncoder;
+ private final Multiset<String> counts;
+
+ public TextValueEncoder(String name) {
+ super(name, 2);
+ wordEncoder = new StaticWordValueEncoder(name);
+ counts = HashMultiset.create();
+ }
+
+ /**
+ * Adds a value to a vector after tokenizing it by splitting on non-alphanum characters.
+ *
+ * @param originalForm The original form of the value as a string.
+ * @param data The vector to which the value should be added.
+ */
+ @Override
+ public void addToVector(byte[] originalForm, double weight, Vector data) {
+ addText(originalForm);
+ flush(weight, data);
+ }
+
+ /**
+ * Adds text to the internal word counter, but delays converting it to vector
+ * form until flush is called.
+ * @param originalForm The original text encoded as UTF-8
+ */
+ public void addText(byte[] originalForm) {
+ addText(new String(originalForm, Charsets.UTF_8));
+ }
+
+ /**
+ * Adds text to the internal word counter, but delays converting it to vector
+ * form until flush is called.
+ * @param text The original text encoded as UTF-8
+ */
+ public void addText(CharSequence text) {
+ for (String word : tokenize(text)) {
+ counts.add(word);
+ }
+ }
+
+ /**
+ * Adds all of the tokens that we counted up to a vector.
+ */
+ public void flush(double weight, Vector data) {
+ for (String word : counts.elementSet()) {
+ // weight words by log_2(tf) times whatever other weight we are given
+ wordEncoder.addToVector(word, weight * Math.log1p(counts.count(word)) / LOG_2, data);
+ }
+ counts.clear();
+ }
+
+ @Override
+ protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ return 0;
+ }
+
+ @Override
+ protected Iterable<Integer> hashesForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ Collection<Integer> hashes = Lists.newArrayList();
+ for (String word : tokenize(new String(originalForm, Charsets.UTF_8))) {
+ hashes.add(hashForProbe(bytesForString(word), dataSize, name, probe));
+ }
+ return hashes;
+ }
+
+ /**
+ * Tokenizes a string using the simplest method. This should be over-ridden for more subtle
+ * tokenization.
+ * @see LuceneTextValueEncoder
+ */
+ protected Iterable<String> tokenize(CharSequence originalForm) {
+ return ON_NON_WORD.split(originalForm);
+ }
+
+ /**
+ * Converts a value into a form that would help a human understand the internals of how the value
+ * is being interpreted. For text-like things, this is likely to be a list of the terms found with
+ * associated weights (if any).
+ *
+ * @param originalForm The original form of the value as a string.
+ * @return A string that a human can read.
+ */
+ @Override
+ public String asString(String originalForm) {
+ StringBuilder r = new StringBuilder();
+ r.append('[');
+ for (String word : tokenize(originalForm)) {
+ if (r.length() > 1) {
+ r.append(", ");
+ }
+ r.append(wordEncoder.asString(word));
+ }
+ r.append(']');
+ return r.toString();
+ }
+
+ public final void setWordEncoder(FeatureVectorEncoder wordEncoder) {
+ this.wordEncoder = wordEncoder;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/encoders/WordValueEncoder.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/encoders/WordValueEncoder.java b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/WordValueEncoder.java
new file mode 100644
index 0000000..2b9dc23
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/encoders/WordValueEncoder.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.encoders;
+
+import org.apache.mahout.math.Vector;
+
+import java.util.Locale;
+
+/**
+ * Encodes words as sparse vector updates to a Vector. Weighting is defined by a
+ * sub-class.
+ */
+public abstract class WordValueEncoder extends FeatureVectorEncoder {
+ private final byte[] nameBytes;
+
+ protected WordValueEncoder(String name) {
+ super(name, 2);
+ nameBytes = bytesForString(name);
+ }
+
+ /**
+ * Adds a value to a vector.
+ *
+ * @param originalForm The original form of the value as a string.
+ * @param data The vector to which the value should be added.
+ */
+ @Override
+ public void addToVector(byte[] originalForm, double w, Vector data) {
+ int probes = getProbes();
+ String name = getName();
+ double weight = getWeight(originalForm, w);
+ for (int i = 0; i < probes; i++) {
+ int n = hashForProbe(originalForm, data.size(), name, i);
+ if (isTraceEnabled()) {
+ trace(originalForm, n);
+ }
+ data.set(n, data.get(n) + weight);
+ }
+ }
+
+
+ @Override
+ protected double getWeight(byte[] originalForm, double w) {
+ return w * weight(originalForm);
+ }
+
+ @Override
+ protected int hashForProbe(byte[] originalForm, int dataSize, String name, int probe) {
+ return hash(nameBytes, originalForm, WORD_LIKE_VALUE_HASH_SEED + probe, dataSize);
+ }
+
+ /**
+ * Converts a value into a form that would help a human understand the internals of how the value
+ * is being interpreted. For text-like things, this is likely to be a list of the terms found with
+ * associated weights (if any).
+ *
+ * @param originalForm The original form of the value as a string.
+ * @return A string that a human can read.
+ */
+ @Override
+ public String asString(String originalForm) {
+ return String.format(Locale.ENGLISH, "%s:%s:%.4f", getName(), originalForm, weight(bytesForString(originalForm)));
+ }
+
+ protected abstract double weight(byte[] originalForm);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/pruner/PrunedPartialVectorMergeReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/pruner/PrunedPartialVectorMergeReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/pruner/PrunedPartialVectorMergeReducer.java
new file mode 100644
index 0000000..9f14249
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/pruner/PrunedPartialVectorMergeReducer.java
@@ -0,0 +1,65 @@
+package org.apache.mahout.vectorizer.pruner;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.vectorizer.common.PartialVectorMerger;
+
+import java.io.IOException;
+
+public class PrunedPartialVectorMergeReducer extends
+ Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ private double normPower;
+
+ private boolean logNormalize;
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context) throws IOException,
+ InterruptedException {
+
+ Vector vector = null;
+ for (VectorWritable value : values) {
+ if (vector == null) {
+ vector = value.get().clone();
+ continue;
+ }
+ //value.get().addTo(vector);
+ vector.assign(value.get(), Functions.PLUS);
+ }
+
+ if (vector != null && normPower != PartialVectorMerger.NO_NORMALIZING) {
+ vector = logNormalize ? vector.logNormalize(normPower) : vector.normalize(normPower);
+ }
+
+ VectorWritable vectorWritable = new VectorWritable(vector);
+ context.write(key, vectorWritable);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ normPower = conf.getFloat(PartialVectorMerger.NORMALIZATION_POWER, PartialVectorMerger.NO_NORMALIZING);
+ logNormalize = conf.getBoolean(PartialVectorMerger.LOG_NORMALIZE, false);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/pruner/WordsPrunerReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/pruner/WordsPrunerReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/pruner/WordsPrunerReducer.java
new file mode 100644
index 0000000..e0da4fe
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/pruner/WordsPrunerReducer.java
@@ -0,0 +1,86 @@
+package org.apache.mahout.vectorizer.pruner;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenIntLongHashMap;
+import org.apache.mahout.vectorizer.HighDFWordsPruner;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+public class WordsPrunerReducer extends
+ Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ private final OpenIntLongHashMap dictionary = new OpenIntLongHashMap();
+ private long maxDf = Long.MAX_VALUE;
+ private long minDf = -1;
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context)
+ throws IOException, InterruptedException {
+ Iterator<VectorWritable> it = values.iterator();
+ if (!it.hasNext()) {
+ return;
+ }
+ Vector value = it.next().get();
+ Vector vector = value.clone();
+ if (maxDf != Long.MAX_VALUE || minDf > -1) {
+ for (Vector.Element e : value.nonZeroes()) {
+ if (!dictionary.containsKey(e.index())) {
+ vector.setQuick(e.index(), 0.0);
+ continue;
+ }
+ long df = dictionary.get(e.index());
+ if (df > maxDf || df < minDf) {
+ vector.setQuick(e.index(), 0.0);
+ }
+ }
+ }
+
+ VectorWritable vectorWritable = new VectorWritable(vector);
+ context.write(key, vectorWritable);
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ //Path[] localFiles = HadoopUtil.getCachedFiles(conf);
+
+ maxDf = conf.getLong(HighDFWordsPruner.MAX_DF, Long.MAX_VALUE);
+ minDf = conf.getLong(HighDFWordsPruner.MIN_DF, -1);
+
+ Path dictionaryFile = HadoopUtil.getSingleCachedFile(conf);
+
+ // key is feature, value is the document frequency
+ for (Pair<IntWritable, LongWritable> record
+ : new SequenceFileIterable<IntWritable, LongWritable>(dictionaryFile, true, conf)) {
+ dictionary.put(record.getFirst().get(), record.getSecond().get());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java b/mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java
new file mode 100644
index 0000000..1496c90
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/vectorizer/term/TFPartialVectorReducer.java
@@ -0,0 +1,139 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.term;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.lucene.analysis.shingle.ShingleFilter;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.common.lucene.IteratorTokenStream;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+import org.apache.mahout.vectorizer.DictionaryVectorizer;
+import org.apache.mahout.vectorizer.common.PartialVectorMerger;
+
+import java.io.IOException;
+import java.net.URI;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * Converts a document in to a sparse vector
+ */
+public class TFPartialVectorReducer extends Reducer<Text, StringTuple, Text, VectorWritable> {
+
+ private final OpenObjectIntHashMap<String> dictionary = new OpenObjectIntHashMap<>();
+
+ private int dimension;
+ private boolean sequentialAccess;
+ private boolean namedVector;
+ private int maxNGramSize = 1;
+
+ @Override
+ protected void reduce(Text key, Iterable<StringTuple> values, Context context)
+ throws IOException, InterruptedException {
+ Iterator<StringTuple> it = values.iterator();
+
+ if (!it.hasNext()) {
+ return;
+ }
+
+ List<String> value = Lists.newArrayList();
+
+ while (it.hasNext()) {
+ value.addAll(it.next().getEntries());
+ }
+
+ Vector vector = new RandomAccessSparseVector(dimension, value.size()); // guess at initial size
+
+ if (maxNGramSize >= 2) {
+ ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(value.iterator()), maxNGramSize);
+ sf.reset();
+ try {
+ do {
+ String term = sf.getAttribute(CharTermAttribute.class).toString();
+ if (!term.isEmpty() && dictionary.containsKey(term)) { // ngram
+ int termId = dictionary.get(term);
+ vector.setQuick(termId, vector.getQuick(termId) + 1);
+ }
+ } while (sf.incrementToken());
+
+ sf.end();
+ } finally {
+ Closeables.close(sf, true);
+ }
+ } else {
+ for (String term : value) {
+ if (!term.isEmpty() && dictionary.containsKey(term)) { // unigram
+ int termId = dictionary.get(term);
+ vector.setQuick(termId, vector.getQuick(termId) + 1);
+ }
+ }
+ }
+ if (sequentialAccess) {
+ vector = new SequentialAccessSparseVector(vector);
+ }
+
+ if (namedVector) {
+ vector = new NamedVector(vector, key.toString());
+ }
+
+ // if the vector has no nonZero entries (nothing in the dictionary), let's not waste space sending it to disk.
+ if (vector.getNumNondefaultElements() > 0) {
+ VectorWritable vectorWritable = new VectorWritable(vector);
+ context.write(key, vectorWritable);
+ } else {
+ context.getCounter("TFPartialVectorReducer", "emptyVectorCount").increment(1);
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+
+ dimension = conf.getInt(PartialVectorMerger.DIMENSION, Integer.MAX_VALUE);
+ sequentialAccess = conf.getBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, false);
+ namedVector = conf.getBoolean(PartialVectorMerger.NAMED_VECTOR, false);
+ maxNGramSize = conf.getInt(DictionaryVectorizer.MAX_NGRAMS, maxNGramSize);
+
+ URI[] localFiles = DistributedCache.getCacheFiles(conf);
+ Path dictionaryFile = HadoopUtil.findInCacheByPartOfFilename(DictionaryVectorizer.DICTIONARY_FILE, localFiles);
+ // key is word value is id
+ for (Pair<Writable, IntWritable> record
+ : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) {
+ dictionary.put(record.getFirst().toString(), record.getSecond().get());
+ }
+ }
+
+}
[43/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserLongPrimitiveIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserLongPrimitiveIterator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserLongPrimitiveIterator.java
new file mode 100644
index 0000000..ea4df85
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/PlusAnonymousUserLongPrimitiveIterator.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model;
+
+import org.apache.mahout.cf.taste.impl.common.AbstractLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+
+final class PlusAnonymousUserLongPrimitiveIterator extends AbstractLongPrimitiveIterator {
+
+ private final LongPrimitiveIterator delegate;
+ private final long extraDatum;
+ private boolean datumConsumed;
+
+ PlusAnonymousUserLongPrimitiveIterator(LongPrimitiveIterator delegate, long extraDatum) {
+ this.delegate = delegate;
+ this.extraDatum = extraDatum;
+ datumConsumed = false;
+ }
+
+ @Override
+ public long nextLong() {
+ if (datumConsumed) {
+ return delegate.nextLong();
+ } else {
+ if (delegate.hasNext()) {
+ long delegateNext = delegate.peek();
+ if (extraDatum <= delegateNext) {
+ datumConsumed = true;
+ return extraDatum;
+ } else {
+ return delegate.next();
+ }
+ } else {
+ datumConsumed = true;
+ return extraDatum;
+ }
+ }
+ }
+
+ @Override
+ public long peek() {
+ if (datumConsumed) {
+ return delegate.peek();
+ } else {
+ if (delegate.hasNext()) {
+ long delegateNext = delegate.peek();
+ if (extraDatum <= delegateNext) {
+ return extraDatum;
+ } else {
+ return delegateNext;
+ }
+ } else {
+ return extraDatum;
+ }
+ }
+ }
+
+ @Override
+ public boolean hasNext() {
+ return !datumConsumed || delegate.hasNext();
+ }
+
+ @Override
+ public void remove() {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void skip(int n) {
+ for (int i = 0; i < n; i++) {
+ nextLong();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModel.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModel.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModel.java
new file mode 100644
index 0000000..da6845e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileDataModel.java
@@ -0,0 +1,759 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model.file;
+
+import java.io.File;
+import java.io.FileFilter;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.model.AbstractDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericBooleanPrefDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.iterator.FileLineIterator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+
+/**
+ * <p>
+ * A {@link DataModel} backed by a delimited file. This class expects a file where each line
+ * contains a user ID, followed by item ID, followed by optional preference value, followed by
+ * optional timestamp. Commas or tabs delimit fields:
+ * </p>
+ *
+ * <p>{@code userID,itemID[,preference[,timestamp]]}</p>
+ *
+ * <p>
+ * Preference value is optional to accommodate applications that have no notion of a
+ * preference value (that is, the user simply expresses a
+ * preference for an item, but no degree of preference).
+ * </p>
+ *
+ * <p>
+ * The preference value is assumed to be parseable as a {@code double}. The user IDs and item IDs are
+ * read parsed as {@code long}s. The timestamp, if present, is assumed to be parseable as a
+ * {@code long}, though this can be overridden via {@link #readTimestampFromString(String)}.
+ * The preference value may be empty, to indicate "no preference value", but cannot be empty. That is,
+ * this is legal:
+ * </p>
+ *
+ * <p>{@code 123,456,,129050099059}</p>
+ *
+ * <p>But this isn't:</p>
+ *
+ * <p>{@code 123,456,129050099059}</p>
+ *
+ * <p>
+ * It is also acceptable for the lines to contain additional fields. Fields beyond the third will be ignored.
+ * An empty line, or one that begins with '#' will be ignored as a comment.
+ * </p>
+ *
+ * <p>
+ * This class will reload data from the data file when {@link #refresh(Collection)} is called, unless the file
+ * has been reloaded very recently already.
+ * </p>
+ *
+ * <p>
+ * This class will also look for update "delta" files in the same directory, with file names that start the
+ * same way (up to the first period). These files have the same format, and provide updated data that
+ * supersedes what is in the main data file. This is a mechanism that allows an application to push updates to
+ * {@link FileDataModel} without re-copying the entire data file.
+ * </p>
+ *
+ * <p>
+ * One small format difference exists. Update files must also be able to express deletes.
+ * This is done by ending with a blank preference value, as in "123,456,".
+ * </p>
+ *
+ * <p>
+ * Note that it's all-or-nothing -- all of the items in the file must express no preference, or the all must.
+ * These cannot be mixed. Put another way there will always be the same number of delimiters on every line of
+ * the file!
+ * </p>
+ *
+ * <p>
+ * This class is not intended for use with very large amounts of data (over, say, tens of millions of rows).
+ * For that, a JDBC-backed {@link DataModel} and a database are more appropriate.
+ * </p>
+ *
+ * <p>
+ * It is possible and likely useful to subclass this class and customize its behavior to accommodate
+ * application-specific needs and input formats. See {@link #processLine(String, FastByIDMap, FastByIDMap, boolean)} and
+ * {@link #processLineWithoutID(String, FastByIDMap, FastByIDMap)}
+ */
+public class FileDataModel extends AbstractDataModel {
+
+ private static final Logger log = LoggerFactory.getLogger(FileDataModel.class);
+
+ public static final long DEFAULT_MIN_RELOAD_INTERVAL_MS = 60 * 1000L; // 1 minute?
+ private static final char COMMENT_CHAR = '#';
+ private static final char[] DELIMIETERS = {',', '\t'};
+
+ private final File dataFile;
+ private long lastModified;
+ private long lastUpdateFileModified;
+ private final Splitter delimiterPattern;
+ private final boolean hasPrefValues;
+ private DataModel delegate;
+ private final ReentrantLock reloadLock;
+ private final boolean transpose;
+ private final long minReloadIntervalMS;
+
+ /**
+ * @param dataFile
+ * file containing preferences data. If file is compressed (and name ends in .gz or .zip
+ * accordingly) it will be decompressed as it is read)
+ * @throws FileNotFoundException
+ * if dataFile does not exist
+ * @throws IOException
+ * if file can't be read
+ */
+ public FileDataModel(File dataFile) throws IOException {
+ this(dataFile, false, DEFAULT_MIN_RELOAD_INTERVAL_MS);
+ }
+
+ /**
+ * @param delimiterRegex If your data file don't use '\t' or ',' as delimiter, you can specify
+ * a custom regex pattern.
+ */
+ public FileDataModel(File dataFile, String delimiterRegex) throws IOException {
+ this(dataFile, false, DEFAULT_MIN_RELOAD_INTERVAL_MS, delimiterRegex);
+ }
+
+ /**
+ * @param transpose
+ * transposes user IDs and item IDs -- convenient for 'flipping' the data model this way
+ * @param minReloadIntervalMS
+ * the minimum interval in milliseconds after which a full reload of the original datafile is done
+ * when refresh() is called
+ * @see #FileDataModel(File)
+ */
+ public FileDataModel(File dataFile, boolean transpose, long minReloadIntervalMS) throws IOException {
+ this(dataFile, transpose, minReloadIntervalMS, null);
+ }
+
+ /**
+ * @param delimiterRegex If your data file don't use '\t' or ',' as delimiters, you can specify
+ * user own using regex pattern.
+ * @throws IOException
+ */
+ public FileDataModel(File dataFile, boolean transpose, long minReloadIntervalMS, String delimiterRegex)
+ throws IOException {
+
+ this.dataFile = Preconditions.checkNotNull(dataFile.getAbsoluteFile());
+ if (!dataFile.exists() || dataFile.isDirectory()) {
+ throw new FileNotFoundException(dataFile.toString());
+ }
+ Preconditions.checkArgument(dataFile.length() > 0L, "dataFile is empty");
+ Preconditions.checkArgument(minReloadIntervalMS >= 0L, "minReloadIntervalMs must be non-negative");
+
+ log.info("Creating FileDataModel for file {}", dataFile);
+
+ this.lastModified = dataFile.lastModified();
+ this.lastUpdateFileModified = readLastUpdateFileModified();
+
+ FileLineIterator iterator = new FileLineIterator(dataFile, false);
+ String firstLine = iterator.peek();
+ while (firstLine.isEmpty() || firstLine.charAt(0) == COMMENT_CHAR) {
+ iterator.next();
+ firstLine = iterator.peek();
+ }
+ Closeables.close(iterator, true);
+
+ char delimiter;
+ if (delimiterRegex == null) {
+ delimiter = determineDelimiter(firstLine);
+ delimiterPattern = Splitter.on(delimiter);
+ } else {
+ delimiter = '\0';
+ delimiterPattern = Splitter.onPattern(delimiterRegex);
+ if (!delimiterPattern.split(firstLine).iterator().hasNext()) {
+ throw new IllegalArgumentException("Did not find a delimiter(pattern) in first line");
+ }
+ }
+ List<String> firstLineSplit = Lists.newArrayList();
+ for (String token : delimiterPattern.split(firstLine)) {
+ firstLineSplit.add(token);
+ }
+ // If preference value exists and isn't empty then the file is specifying pref values
+ hasPrefValues = firstLineSplit.size() >= 3 && !firstLineSplit.get(2).isEmpty();
+
+ this.reloadLock = new ReentrantLock();
+ this.transpose = transpose;
+ this.minReloadIntervalMS = minReloadIntervalMS;
+
+ reload();
+ }
+
+ public File getDataFile() {
+ return dataFile;
+ }
+
+ protected void reload() {
+ if (reloadLock.tryLock()) {
+ try {
+ delegate = buildModel();
+ } catch (IOException ioe) {
+ log.warn("Exception while reloading", ioe);
+ } finally {
+ reloadLock.unlock();
+ }
+ }
+ }
+
+ protected DataModel buildModel() throws IOException {
+
+ long newLastModified = dataFile.lastModified();
+ long newLastUpdateFileModified = readLastUpdateFileModified();
+
+ boolean loadFreshData = delegate == null || newLastModified > lastModified + minReloadIntervalMS;
+
+ long oldLastUpdateFileModifieid = lastUpdateFileModified;
+ lastModified = newLastModified;
+ lastUpdateFileModified = newLastUpdateFileModified;
+
+ FastByIDMap<FastByIDMap<Long>> timestamps = new FastByIDMap<>();
+
+ if (hasPrefValues) {
+
+ if (loadFreshData) {
+
+ FastByIDMap<Collection<Preference>> data = new FastByIDMap<>();
+ FileLineIterator iterator = new FileLineIterator(dataFile, false);
+ processFile(iterator, data, timestamps, false);
+
+ for (File updateFile : findUpdateFilesAfter(newLastModified)) {
+ processFile(new FileLineIterator(updateFile, false), data, timestamps, false);
+ }
+
+ return new GenericDataModel(GenericDataModel.toDataMap(data, true), timestamps);
+
+ } else {
+
+ FastByIDMap<PreferenceArray> rawData = ((GenericDataModel) delegate).getRawUserData();
+
+ for (File updateFile : findUpdateFilesAfter(Math.max(oldLastUpdateFileModifieid, newLastModified))) {
+ processFile(new FileLineIterator(updateFile, false), rawData, timestamps, true);
+ }
+
+ return new GenericDataModel(rawData, timestamps);
+
+ }
+
+ } else {
+
+ if (loadFreshData) {
+
+ FastByIDMap<FastIDSet> data = new FastByIDMap<>();
+ FileLineIterator iterator = new FileLineIterator(dataFile, false);
+ processFileWithoutID(iterator, data, timestamps);
+
+ for (File updateFile : findUpdateFilesAfter(newLastModified)) {
+ processFileWithoutID(new FileLineIterator(updateFile, false), data, timestamps);
+ }
+
+ return new GenericBooleanPrefDataModel(data, timestamps);
+
+ } else {
+
+ FastByIDMap<FastIDSet> rawData = ((GenericBooleanPrefDataModel) delegate).getRawUserData();
+
+ for (File updateFile : findUpdateFilesAfter(Math.max(oldLastUpdateFileModifieid, newLastModified))) {
+ processFileWithoutID(new FileLineIterator(updateFile, false), rawData, timestamps);
+ }
+
+ return new GenericBooleanPrefDataModel(rawData, timestamps);
+
+ }
+
+ }
+ }
+
+ /**
+ * Finds update delta files in the same directory as the data file. This finds any file whose name starts
+ * the same way as the data file (up to first period) but isn't the data file itself. For example, if the
+ * data file is /foo/data.txt.gz, you might place update files at /foo/data.1.txt.gz, /foo/data.2.txt.gz,
+ * etc.
+ */
+ private Iterable<File> findUpdateFilesAfter(long minimumLastModified) {
+ String dataFileName = dataFile.getName();
+ int period = dataFileName.indexOf('.');
+ String startName = period < 0 ? dataFileName : dataFileName.substring(0, period);
+ File parentDir = dataFile.getParentFile();
+ Map<Long, File> modTimeToUpdateFile = new TreeMap<>();
+ FileFilter onlyFiles = new FileFilter() {
+ @Override
+ public boolean accept(File file) {
+ return !file.isDirectory();
+ }
+ };
+ for (File updateFile : parentDir.listFiles(onlyFiles)) {
+ String updateFileName = updateFile.getName();
+ if (updateFileName.startsWith(startName)
+ && !updateFileName.equals(dataFileName)
+ && updateFile.lastModified() >= minimumLastModified) {
+ modTimeToUpdateFile.put(updateFile.lastModified(), updateFile);
+ }
+ }
+ return modTimeToUpdateFile.values();
+ }
+
+ private long readLastUpdateFileModified() {
+ long mostRecentModification = Long.MIN_VALUE;
+ for (File updateFile : findUpdateFilesAfter(0L)) {
+ mostRecentModification = Math.max(mostRecentModification, updateFile.lastModified());
+ }
+ return mostRecentModification;
+ }
+
+ public static char determineDelimiter(String line) {
+ for (char possibleDelimieter : DELIMIETERS) {
+ if (line.indexOf(possibleDelimieter) >= 0) {
+ return possibleDelimieter;
+ }
+ }
+ throw new IllegalArgumentException("Did not find a delimiter in first line");
+ }
+
+ protected void processFile(FileLineIterator dataOrUpdateFileIterator,
+ FastByIDMap<?> data,
+ FastByIDMap<FastByIDMap<Long>> timestamps,
+ boolean fromPriorData) {
+ log.info("Reading file info...");
+ int count = 0;
+ while (dataOrUpdateFileIterator.hasNext()) {
+ String line = dataOrUpdateFileIterator.next();
+ if (!line.isEmpty()) {
+ processLine(line, data, timestamps, fromPriorData);
+ if (++count % 1000000 == 0) {
+ log.info("Processed {} lines", count);
+ }
+ }
+ }
+ log.info("Read lines: {}", count);
+ }
+
+ /**
+ * <p>
+ * Reads one line from the input file and adds the data to a {@link FastByIDMap} data structure which maps user IDs
+ * to preferences. This assumes that each line of the input file corresponds to one preference. After
+ * reading a line and determining which user and item the preference pertains to, the method should look to
+ * see if the data contains a mapping for the user ID already, and if not, add an empty data structure of preferences
+ * as appropriate to the data.
+ * </p>
+ *
+ * <p>
+ * Note that if the line is empty or begins with '#' it will be ignored as a comment.
+ * </p>
+ *
+ * @param line
+ * line from input data file
+ * @param data
+ * all data read so far, as a mapping from user IDs to preferences
+ * @param fromPriorData an implementation detail -- if true, data will map IDs to
+ * {@link PreferenceArray} since the framework is attempting to read and update raw
+ * data that is already in memory. Otherwise it maps to {@link Collection}s of
+ * {@link Preference}s, since it's reading fresh data. Subclasses must be prepared
+ * to handle this wrinkle.
+ */
+ protected void processLine(String line,
+ FastByIDMap<?> data,
+ FastByIDMap<FastByIDMap<Long>> timestamps,
+ boolean fromPriorData) {
+
+ // Ignore empty lines and comments
+ if (line.isEmpty() || line.charAt(0) == COMMENT_CHAR) {
+ return;
+ }
+
+ Iterator<String> tokens = delimiterPattern.split(line).iterator();
+ String userIDString = tokens.next();
+ String itemIDString = tokens.next();
+ String preferenceValueString = tokens.next();
+ boolean hasTimestamp = tokens.hasNext();
+ String timestampString = hasTimestamp ? tokens.next() : null;
+
+ long userID = readUserIDFromString(userIDString);
+ long itemID = readItemIDFromString(itemIDString);
+
+ if (transpose) {
+ long tmp = userID;
+ userID = itemID;
+ itemID = tmp;
+ }
+
+ // This is kind of gross but need to handle two types of storage
+ Object maybePrefs = data.get(userID);
+ if (fromPriorData) {
+ // Data are PreferenceArray
+
+ PreferenceArray prefs = (PreferenceArray) maybePrefs;
+ if (!hasTimestamp && preferenceValueString.isEmpty()) {
+ // Then line is of form "userID,itemID,", meaning remove
+ if (prefs != null) {
+ boolean exists = false;
+ int length = prefs.length();
+ for (int i = 0; i < length; i++) {
+ if (prefs.getItemID(i) == itemID) {
+ exists = true;
+ break;
+ }
+ }
+ if (exists) {
+ if (length == 1) {
+ data.remove(userID);
+ } else {
+ PreferenceArray newPrefs = new GenericUserPreferenceArray(length - 1);
+ for (int i = 0, j = 0; i < length; i++, j++) {
+ if (prefs.getItemID(i) == itemID) {
+ j--;
+ } else {
+ newPrefs.set(j, prefs.get(i));
+ }
+ }
+ ((FastByIDMap<PreferenceArray>) data).put(userID, newPrefs);
+ }
+ }
+ }
+
+ removeTimestamp(userID, itemID, timestamps);
+
+ } else {
+
+ float preferenceValue = Float.parseFloat(preferenceValueString);
+
+ boolean exists = false;
+ if (prefs != null) {
+ for (int i = 0; i < prefs.length(); i++) {
+ if (prefs.getItemID(i) == itemID) {
+ exists = true;
+ prefs.setValue(i, preferenceValue);
+ break;
+ }
+ }
+ }
+
+ if (!exists) {
+ if (prefs == null) {
+ prefs = new GenericUserPreferenceArray(1);
+ } else {
+ PreferenceArray newPrefs = new GenericUserPreferenceArray(prefs.length() + 1);
+ for (int i = 0, j = 1; i < prefs.length(); i++, j++) {
+ newPrefs.set(j, prefs.get(i));
+ }
+ prefs = newPrefs;
+ }
+ prefs.setUserID(0, userID);
+ prefs.setItemID(0, itemID);
+ prefs.setValue(0, preferenceValue);
+ ((FastByIDMap<PreferenceArray>) data).put(userID, prefs);
+ }
+ }
+
+ addTimestamp(userID, itemID, timestampString, timestamps);
+
+ } else {
+ // Data are Collection<Preference>
+
+ Collection<Preference> prefs = (Collection<Preference>) maybePrefs;
+
+ if (!hasTimestamp && preferenceValueString.isEmpty()) {
+ // Then line is of form "userID,itemID,", meaning remove
+ if (prefs != null) {
+ // remove pref
+ Iterator<Preference> prefsIterator = prefs.iterator();
+ while (prefsIterator.hasNext()) {
+ Preference pref = prefsIterator.next();
+ if (pref.getItemID() == itemID) {
+ prefsIterator.remove();
+ break;
+ }
+ }
+ }
+
+ removeTimestamp(userID, itemID, timestamps);
+
+ } else {
+
+ float preferenceValue = Float.parseFloat(preferenceValueString);
+
+ boolean exists = false;
+ if (prefs != null) {
+ for (Preference pref : prefs) {
+ if (pref.getItemID() == itemID) {
+ exists = true;
+ pref.setValue(preferenceValue);
+ break;
+ }
+ }
+ }
+
+ if (!exists) {
+ if (prefs == null) {
+ prefs = Lists.newArrayListWithCapacity(2);
+ ((FastByIDMap<Collection<Preference>>) data).put(userID, prefs);
+ }
+ prefs.add(new GenericPreference(userID, itemID, preferenceValue));
+ }
+
+ addTimestamp(userID, itemID, timestampString, timestamps);
+
+ }
+
+ }
+ }
+
+ protected void processFileWithoutID(FileLineIterator dataOrUpdateFileIterator,
+ FastByIDMap<FastIDSet> data,
+ FastByIDMap<FastByIDMap<Long>> timestamps) {
+ log.info("Reading file info...");
+ int count = 0;
+ while (dataOrUpdateFileIterator.hasNext()) {
+ String line = dataOrUpdateFileIterator.next();
+ if (!line.isEmpty()) {
+ processLineWithoutID(line, data, timestamps);
+ if (++count % 100000 == 0) {
+ log.info("Processed {} lines", count);
+ }
+ }
+ }
+ log.info("Read lines: {}", count);
+ }
+
+ protected void processLineWithoutID(String line,
+ FastByIDMap<FastIDSet> data,
+ FastByIDMap<FastByIDMap<Long>> timestamps) {
+
+ if (line.isEmpty() || line.charAt(0) == COMMENT_CHAR) {
+ return;
+ }
+
+ Iterator<String> tokens = delimiterPattern.split(line).iterator();
+ String userIDString = tokens.next();
+ String itemIDString = tokens.next();
+ boolean hasPreference = tokens.hasNext();
+ String preferenceValueString = hasPreference ? tokens.next() : "";
+ boolean hasTimestamp = tokens.hasNext();
+ String timestampString = hasTimestamp ? tokens.next() : null;
+
+ long userID = readUserIDFromString(userIDString);
+ long itemID = readItemIDFromString(itemIDString);
+
+ if (transpose) {
+ long tmp = userID;
+ userID = itemID;
+ itemID = tmp;
+ }
+
+ if (hasPreference && !hasTimestamp && preferenceValueString.isEmpty()) {
+ // Then line is of form "userID,itemID,", meaning remove
+
+ FastIDSet itemIDs = data.get(userID);
+ if (itemIDs != null) {
+ itemIDs.remove(itemID);
+ }
+
+ removeTimestamp(userID, itemID, timestamps);
+
+ } else {
+
+ FastIDSet itemIDs = data.get(userID);
+ if (itemIDs == null) {
+ itemIDs = new FastIDSet(2);
+ data.put(userID, itemIDs);
+ }
+ itemIDs.add(itemID);
+
+ addTimestamp(userID, itemID, timestampString, timestamps);
+
+ }
+ }
+
+ private void addTimestamp(long userID,
+ long itemID,
+ String timestampString,
+ FastByIDMap<FastByIDMap<Long>> timestamps) {
+ if (timestampString != null) {
+ FastByIDMap<Long> itemTimestamps = timestamps.get(userID);
+ if (itemTimestamps == null) {
+ itemTimestamps = new FastByIDMap<>();
+ timestamps.put(userID, itemTimestamps);
+ }
+ long timestamp = readTimestampFromString(timestampString);
+ itemTimestamps.put(itemID, timestamp);
+ }
+ }
+
+ private static void removeTimestamp(long userID,
+ long itemID,
+ FastByIDMap<FastByIDMap<Long>> timestamps) {
+ FastByIDMap<Long> itemTimestamps = timestamps.get(userID);
+ if (itemTimestamps != null) {
+ itemTimestamps.remove(itemID);
+ }
+ }
+
+ /**
+ * Subclasses may wish to override this if ID values in the file are not numeric. This provides a hook by
+ * which subclasses can inject an {@link org.apache.mahout.cf.taste.model.IDMigrator} to perform
+ * translation.
+ */
+ protected long readUserIDFromString(String value) {
+ return Long.parseLong(value);
+ }
+
+ /**
+ * Subclasses may wish to override this if ID values in the file are not numeric. This provides a hook by
+ * which subclasses can inject an {@link org.apache.mahout.cf.taste.model.IDMigrator} to perform
+ * translation.
+ */
+ protected long readItemIDFromString(String value) {
+ return Long.parseLong(value);
+ }
+
+ /**
+ * Subclasses may wish to override this to change how time values in the input file are parsed.
+ * By default they are expected to be numeric, expressing a time as milliseconds since the epoch.
+ */
+ protected long readTimestampFromString(String value) {
+ return Long.parseLong(value);
+ }
+
+ @Override
+ public LongPrimitiveIterator getUserIDs() throws TasteException {
+ return delegate.getUserIDs();
+ }
+
+ @Override
+ public PreferenceArray getPreferencesFromUser(long userID) throws TasteException {
+ return delegate.getPreferencesFromUser(userID);
+ }
+
+ @Override
+ public FastIDSet getItemIDsFromUser(long userID) throws TasteException {
+ return delegate.getItemIDsFromUser(userID);
+ }
+
+ @Override
+ public LongPrimitiveIterator getItemIDs() throws TasteException {
+ return delegate.getItemIDs();
+ }
+
+ @Override
+ public PreferenceArray getPreferencesForItem(long itemID) throws TasteException {
+ return delegate.getPreferencesForItem(itemID);
+ }
+
+ @Override
+ public Float getPreferenceValue(long userID, long itemID) throws TasteException {
+ return delegate.getPreferenceValue(userID, itemID);
+ }
+
+ @Override
+ public Long getPreferenceTime(long userID, long itemID) throws TasteException {
+ return delegate.getPreferenceTime(userID, itemID);
+ }
+
+ @Override
+ public int getNumItems() throws TasteException {
+ return delegate.getNumItems();
+ }
+
+ @Override
+ public int getNumUsers() throws TasteException {
+ return delegate.getNumUsers();
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID) throws TasteException {
+ return delegate.getNumUsersWithPreferenceFor(itemID);
+ }
+
+ @Override
+ public int getNumUsersWithPreferenceFor(long itemID1, long itemID2) throws TasteException {
+ return delegate.getNumUsersWithPreferenceFor(itemID1, itemID2);
+ }
+
+ /**
+ * Note that this method only updates the in-memory preference data that this {@link FileDataModel}
+ * maintains; it does not modify any data on disk. Therefore any updates from this method are only
+ * temporary, and lost when data is reloaded from a file. This method should also be considered relatively
+ * slow.
+ */
+ @Override
+ public void setPreference(long userID, long itemID, float value) throws TasteException {
+ delegate.setPreference(userID, itemID, value);
+ }
+
+ /** See the warning at {@link #setPreference(long, long, float)}. */
+ @Override
+ public void removePreference(long userID, long itemID) throws TasteException {
+ delegate.removePreference(userID, itemID);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ if (dataFile.lastModified() > lastModified + minReloadIntervalMS
+ || readLastUpdateFileModified() > lastUpdateFileModified + minReloadIntervalMS) {
+ log.debug("File has changed; reloading...");
+ reload();
+ }
+ }
+
+ @Override
+ public boolean hasPreferenceValues() {
+ return delegate.hasPreferenceValues();
+ }
+
+ @Override
+ public float getMaxPreference() {
+ return delegate.getMaxPreference();
+ }
+
+ @Override
+ public float getMinPreference() {
+ return delegate.getMinPreference();
+ }
+
+ @Override
+ public String toString() {
+ return "FileDataModel[dataFile:" + dataFile + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigrator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigrator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigrator.java
new file mode 100644
index 0000000..1bcb4ef
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/model/file/FileIDMigrator.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.model.file;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.concurrent.locks.ReentrantLock;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.model.AbstractIDMigrator;
+import org.apache.mahout.common.iterator.FileLineIterable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * An {@link org.apache.mahout.cf.taste.model.IDMigrator} backed by a file.
+ * This class typically expects a file where each line
+ * contains a single stringID to be stored in this migrator.
+ * </p>
+ *
+ * <p>
+ * This class will reload data from the data file when {@link #refresh(Collection)} is called, unless the file
+ * has been reloaded very recently already.
+ * </p>
+ */
+public class FileIDMigrator extends AbstractIDMigrator {
+
+ public static final long DEFAULT_MIN_RELOAD_INTERVAL_MS = 60 * 1000L; // 1 minute?
+
+ private final File dataFile;
+ private FastByIDMap<String> longToString;
+ private final ReentrantLock reloadLock;
+
+ private long lastModified;
+ private final long minReloadIntervalMS;
+
+ private static final Logger log = LoggerFactory.getLogger(FileIDMigrator.class);
+
+ public FileIDMigrator(File dataFile) throws FileNotFoundException {
+ this(dataFile, DEFAULT_MIN_RELOAD_INTERVAL_MS);
+ }
+
+ public FileIDMigrator(File dataFile, long minReloadIntervalMS) throws FileNotFoundException {
+ longToString = new FastByIDMap<>(100);
+ this.dataFile = Preconditions.checkNotNull(dataFile);
+ if (!dataFile.exists() || dataFile.isDirectory()) {
+ throw new FileNotFoundException(dataFile.toString());
+ }
+
+ log.info("Creating FileReadonlyIDMigrator for file {}", dataFile);
+
+ this.reloadLock = new ReentrantLock();
+ this.lastModified = dataFile.lastModified();
+ this.minReloadIntervalMS = minReloadIntervalMS;
+
+ reload();
+ }
+
+ @Override
+ public String toStringID(long longID) {
+ return longToString.get(longID);
+ }
+
+ private void reload() {
+ if (reloadLock.tryLock()) {
+ try {
+ longToString = buildMapping();
+ } catch (IOException ioe) {
+ throw new IllegalStateException(ioe);
+ } finally {
+ reloadLock.unlock();
+ }
+ }
+ }
+
+ private FastByIDMap<String> buildMapping() throws IOException {
+ FastByIDMap<String> mapping = new FastByIDMap<>();
+ for (String line : new FileLineIterable(dataFile)) {
+ mapping.put(toLongID(line), line);
+ }
+ lastModified = dataFile.lastModified();
+ return mapping;
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ if (dataFile.lastModified() > lastModified + minReloadIntervalMS) {
+ log.debug("File has changed; reloading...");
+ reload();
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "FileIDMigrator[dataFile:" + dataFile + ']';
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/AbstractUserNeighborhood.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/AbstractUserNeighborhood.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/AbstractUserNeighborhood.java
new file mode 100644
index 0000000..8d33f60
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/AbstractUserNeighborhood.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * Contains methods and resources useful to all classes in this package.
+ * </p>
+ */
+abstract class AbstractUserNeighborhood implements UserNeighborhood {
+
+ private final UserSimilarity userSimilarity;
+ private final DataModel dataModel;
+ private final double samplingRate;
+ private final RefreshHelper refreshHelper;
+
+ AbstractUserNeighborhood(UserSimilarity userSimilarity, DataModel dataModel, double samplingRate) {
+ Preconditions.checkArgument(userSimilarity != null, "userSimilarity is null");
+ Preconditions.checkArgument(dataModel != null, "dataModel is null");
+ Preconditions.checkArgument(samplingRate > 0.0 && samplingRate <= 1.0, "samplingRate must be in (0,1]");
+ this.userSimilarity = userSimilarity;
+ this.dataModel = dataModel;
+ this.samplingRate = samplingRate;
+ this.refreshHelper = new RefreshHelper(null);
+ this.refreshHelper.addDependency(this.dataModel);
+ this.refreshHelper.addDependency(this.userSimilarity);
+ }
+
+ final UserSimilarity getUserSimilarity() {
+ return userSimilarity;
+ }
+
+ final DataModel getDataModel() {
+ return dataModel;
+ }
+
+ final double getSamplingRate() {
+ return samplingRate;
+ }
+
+ @Override
+ public final void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java
new file mode 100644
index 0000000..998e476
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/CachingUserNeighborhood.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import java.util.Collection;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.Cache;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.common.Retriever;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+
+import com.google.common.base.Preconditions;
+
+/** A caching wrapper around an underlying {@link UserNeighborhood} implementation. */
+public final class CachingUserNeighborhood implements UserNeighborhood {
+
+ private final UserNeighborhood neighborhood;
+ private final Cache<Long,long[]> neighborhoodCache;
+
+ public CachingUserNeighborhood(UserNeighborhood neighborhood, DataModel dataModel) throws TasteException {
+ Preconditions.checkArgument(neighborhood != null, "neighborhood is null");
+ this.neighborhood = neighborhood;
+ int maxCacheSize = dataModel.getNumUsers(); // just a dumb heuristic for sizing
+ this.neighborhoodCache = new Cache<>(new NeighborhoodRetriever(neighborhood), maxCacheSize);
+ }
+
+ @Override
+ public long[] getUserNeighborhood(long userID) throws TasteException {
+ return neighborhoodCache.get(userID);
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ neighborhoodCache.clear();
+ Collection<Refreshable> refreshed = RefreshHelper.buildRefreshed(alreadyRefreshed);
+ RefreshHelper.maybeRefresh(refreshed, neighborhood);
+ }
+
+ private static final class NeighborhoodRetriever implements Retriever<Long,long[]> {
+ private final UserNeighborhood neighborhood;
+
+ private NeighborhoodRetriever(UserNeighborhood neighborhood) {
+ this.neighborhood = neighborhood;
+ }
+
+ @Override
+ public long[] get(Long key) throws TasteException {
+ return neighborhood.getUserNeighborhood(key);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java
new file mode 100644
index 0000000..7f3a98a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/NearestNUserNeighborhood.java
@@ -0,0 +1,122 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.recommender.TopItems;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * Computes a neighborhood consisting of the nearest n users to a given user. "Nearest" is defined by the
+ * given {@link UserSimilarity}.
+ * </p>
+ */
+public final class NearestNUserNeighborhood extends AbstractUserNeighborhood {
+
+ private final int n;
+ private final double minSimilarity;
+
+ /**
+ * @param n neighborhood size; capped at the number of users in the data model
+ * @throws IllegalArgumentException
+ * if {@code n < 1}, or userSimilarity or dataModel are {@code null}
+ */
+ public NearestNUserNeighborhood(int n, UserSimilarity userSimilarity, DataModel dataModel) throws TasteException {
+ this(n, Double.NEGATIVE_INFINITY, userSimilarity, dataModel, 1.0);
+ }
+
+ /**
+ * @param n neighborhood size; capped at the number of users in the data model
+ * @param minSimilarity minimal similarity required for neighbors
+ * @throws IllegalArgumentException
+ * if {@code n < 1}, or userSimilarity or dataModel are {@code null}
+ */
+ public NearestNUserNeighborhood(int n,
+ double minSimilarity,
+ UserSimilarity userSimilarity,
+ DataModel dataModel) throws TasteException {
+ this(n, minSimilarity, userSimilarity, dataModel, 1.0);
+ }
+
+ /**
+ * @param n neighborhood size; capped at the number of users in the data model
+ * @param minSimilarity minimal similarity required for neighbors
+ * @param samplingRate percentage of users to consider when building neighborhood -- decrease to trade quality for
+ * performance
+ * @throws IllegalArgumentException
+ * if {@code n < 1} or samplingRate is NaN or not in (0,1], or userSimilarity or dataModel are
+ * {@code null}
+ */
+ public NearestNUserNeighborhood(int n,
+ double minSimilarity,
+ UserSimilarity userSimilarity,
+ DataModel dataModel,
+ double samplingRate) throws TasteException {
+ super(userSimilarity, dataModel, samplingRate);
+ Preconditions.checkArgument(n >= 1, "n must be at least 1");
+ int numUsers = dataModel.getNumUsers();
+ this.n = n > numUsers ? numUsers : n;
+ this.minSimilarity = minSimilarity;
+ }
+
+ @Override
+ public long[] getUserNeighborhood(long userID) throws TasteException {
+
+ DataModel dataModel = getDataModel();
+ UserSimilarity userSimilarityImpl = getUserSimilarity();
+
+ TopItems.Estimator<Long> estimator = new Estimator(userSimilarityImpl, userID, minSimilarity);
+
+ LongPrimitiveIterator userIDs = SamplingLongPrimitiveIterator.maybeWrapIterator(dataModel.getUserIDs(),
+ getSamplingRate());
+
+ return TopItems.getTopUsers(n, userIDs, null, estimator);
+ }
+
+ @Override
+ public String toString() {
+ return "NearestNUserNeighborhood";
+ }
+
+ private static final class Estimator implements TopItems.Estimator<Long> {
+ private final UserSimilarity userSimilarityImpl;
+ private final long theUserID;
+ private final double minSim;
+
+ private Estimator(UserSimilarity userSimilarityImpl, long theUserID, double minSim) {
+ this.userSimilarityImpl = userSimilarityImpl;
+ this.theUserID = theUserID;
+ this.minSim = minSim;
+ }
+
+ @Override
+ public double estimate(Long userID) throws TasteException {
+ if (userID == theUserID) {
+ return Double.NaN;
+ }
+ double sim = userSimilarityImpl.userSimilarity(theUserID, userID);
+ return sim >= minSim ? sim : Double.NaN;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java
new file mode 100644
index 0000000..d5246e4
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/neighborhood/ThresholdUserNeighborhood.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.neighborhood;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.SamplingLongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * Computes a neigbhorhood consisting of all users whose similarity to the given user meets or exceeds a
+ * certain threshold. Similarity is defined by the given {@link UserSimilarity}.
+ * </p>
+ */
+public final class ThresholdUserNeighborhood extends AbstractUserNeighborhood {
+
+ private final double threshold;
+
+ /**
+ * @param threshold
+ * similarity threshold
+ * @param userSimilarity
+ * similarity metric
+ * @param dataModel
+ * data model
+ * @throws IllegalArgumentException
+ * if threshold is {@link Double#NaN}, or if samplingRate is not positive and less than or equal
+ * to 1.0, or if userSimilarity or dataModel are {@code null}
+ */
+ public ThresholdUserNeighborhood(double threshold, UserSimilarity userSimilarity, DataModel dataModel) {
+ this(threshold, userSimilarity, dataModel, 1.0);
+ }
+
+ /**
+ * @param threshold
+ * similarity threshold
+ * @param userSimilarity
+ * similarity metric
+ * @param dataModel
+ * data model
+ * @param samplingRate
+ * percentage of users to consider when building neighborhood -- decrease to trade quality for
+ * performance
+ * @throws IllegalArgumentException
+ * if threshold or samplingRate is {@link Double#NaN}, or if samplingRate is not positive and less
+ * than or equal to 1.0, or if userSimilarity or dataModel are {@code null}
+ */
+ public ThresholdUserNeighborhood(double threshold,
+ UserSimilarity userSimilarity,
+ DataModel dataModel,
+ double samplingRate) {
+ super(userSimilarity, dataModel, samplingRate);
+ Preconditions.checkArgument(!Double.isNaN(threshold), "threshold must not be NaN");
+ this.threshold = threshold;
+ }
+
+ @Override
+ public long[] getUserNeighborhood(long userID) throws TasteException {
+
+ DataModel dataModel = getDataModel();
+ FastIDSet neighborhood = new FastIDSet();
+ LongPrimitiveIterator usersIterable = SamplingLongPrimitiveIterator.maybeWrapIterator(dataModel
+ .getUserIDs(), getSamplingRate());
+ UserSimilarity userSimilarityImpl = getUserSimilarity();
+
+ while (usersIterable.hasNext()) {
+ long otherUserID = usersIterable.next();
+ if (userID != otherUserID) {
+ double theSimilarity = userSimilarityImpl.userSimilarity(userID, otherUserID);
+ if (!Double.isNaN(theSimilarity) && theSimilarity >= threshold) {
+ neighborhood.add(otherUserID);
+ }
+ }
+ }
+
+ return neighborhood.toArray();
+ }
+
+ @Override
+ public String toString() {
+ return "ThresholdUserNeighborhood";
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractCandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractCandidateItemsStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractCandidateItemsStrategy.java
new file mode 100644
index 0000000..d24ea6a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractCandidateItemsStrategy.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.MostSimilarItemsCandidateItemsStrategy;
+
+import java.util.Collection;
+
+/**
+ * Abstract base implementation for retrieving candidate items to recommend
+ */
+public abstract class AbstractCandidateItemsStrategy implements CandidateItemsStrategy,
+ MostSimilarItemsCandidateItemsStrategy {
+
+ protected FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel) throws TasteException{
+ return doGetCandidateItems(preferredItemIDs, dataModel, false);
+ }
+
+ @Override
+ public FastIDSet getCandidateItems(long userID, PreferenceArray preferencesFromUser, DataModel dataModel,
+ boolean includeKnownItems) throws TasteException {
+ return doGetCandidateItems(preferencesFromUser.getIDs(), dataModel, includeKnownItems);
+ }
+
+ @Override
+ public FastIDSet getCandidateItems(long[] itemIDs, DataModel dataModel)
+ throws TasteException {
+ return doGetCandidateItems(itemIDs, dataModel, false);
+ }
+
+ protected abstract FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel,
+ boolean includeKnownItems) throws TasteException;
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {}
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractRecommender.java
new file mode 100644
index 0000000..3a62b08
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AbstractRecommender.java
@@ -0,0 +1,140 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+
+import java.util.List;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+public abstract class AbstractRecommender implements Recommender {
+
+ private static final Logger log = LoggerFactory.getLogger(AbstractRecommender.class);
+
+ private final DataModel dataModel;
+ private final CandidateItemsStrategy candidateItemsStrategy;
+
+ protected AbstractRecommender(DataModel dataModel, CandidateItemsStrategy candidateItemsStrategy) {
+ this.dataModel = Preconditions.checkNotNull(dataModel);
+ this.candidateItemsStrategy = Preconditions.checkNotNull(candidateItemsStrategy);
+ }
+
+ protected AbstractRecommender(DataModel dataModel) {
+ this(dataModel, getDefaultCandidateItemsStrategy());
+ }
+
+ protected static CandidateItemsStrategy getDefaultCandidateItemsStrategy() {
+ return new PreferredItemsNeighborhoodCandidateItemsStrategy();
+ }
+
+
+ /**
+ * <p>
+ * Default implementation which just calls
+ * {@link Recommender#recommend(long, int, org.apache.mahout.cf.taste.recommender.IDRescorer)}, with a
+ * {@link org.apache.mahout.cf.taste.recommender.Rescorer} that does nothing.
+ * </p>
+ */
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+ return recommend(userID, howMany, null, false);
+ }
+
+ /**
+ * <p>
+ * Default implementation which just calls
+ * {@link Recommender#recommend(long, int, org.apache.mahout.cf.taste.recommender.IDRescorer)}, with a
+ * {@link org.apache.mahout.cf.taste.recommender.Rescorer} that does nothing.
+ * </p>
+ */
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
+ return recommend(userID, howMany, null, includeKnownItems);
+ }
+
+ /**
+ * <p> Delegates to {@link Recommender#recommend(long, int, IDRescorer, boolean)}
+ */
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) throws TasteException{
+ return recommend(userID, howMany,rescorer, false);
+ }
+
+ /**
+ * <p>
+ * Default implementation which just calls {@link DataModel#setPreference(long, long, float)}.
+ * </p>
+ *
+ * @throws IllegalArgumentException
+ * if userID or itemID is {@code null}, or if value is {@link Double#NaN}
+ */
+ @Override
+ public void setPreference(long userID, long itemID, float value) throws TasteException {
+ Preconditions.checkArgument(!Float.isNaN(value), "NaN value");
+ log.debug("Setting preference for user {}, item {}", userID, itemID);
+ dataModel.setPreference(userID, itemID, value);
+ }
+
+ /**
+ * <p>
+ * Default implementation which just calls {@link DataModel#removePreference(long, long)} (Object, Object)}.
+ * </p>
+ *
+ * @throws IllegalArgumentException
+ * if userID or itemID is {@code null}
+ */
+ @Override
+ public void removePreference(long userID, long itemID) throws TasteException {
+ log.debug("Remove preference for user '{}', item '{}'", userID, itemID);
+ dataModel.removePreference(userID, itemID);
+ }
+
+ @Override
+ public DataModel getDataModel() {
+ return dataModel;
+ }
+
+ /**
+ * @param userID
+ * ID of user being evaluated
+ * @param preferencesFromUser
+ * the preferences from the user
+ * @param includeKnownItems
+ * whether to include items already known by the user in recommendations
+ * @return all items in the {@link DataModel} for which the user has not expressed a preference and could
+ * possibly be recommended to the user
+ * @throws TasteException
+ * if an error occurs while listing items
+ */
+ protected FastIDSet getAllOtherItems(long userID, PreferenceArray preferencesFromUser, boolean includeKnownItems)
+ throws TasteException {
+ return candidateItemsStrategy.getCandidateItems(userID, preferencesFromUser, dataModel, includeKnownItems);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllSimilarItemsCandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllSimilarItemsCandidateItemsStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllSimilarItemsCandidateItemsStrategy.java
new file mode 100644
index 0000000..37389a7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllSimilarItemsCandidateItemsStrategy.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import com.google.common.base.Preconditions;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+
+/**
+ * returns the result of {@link ItemSimilarity#allSimilarItemIDs(long)} as candidate items
+ */
+public class AllSimilarItemsCandidateItemsStrategy extends AbstractCandidateItemsStrategy {
+
+ private final ItemSimilarity similarity;
+
+ public AllSimilarItemsCandidateItemsStrategy(ItemSimilarity similarity) {
+ Preconditions.checkArgument(similarity != null, "similarity is null");
+ this.similarity = similarity;
+ }
+
+ @Override
+ protected FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel, boolean includeKnownItems)
+ throws TasteException {
+ FastIDSet candidateItemIDs = new FastIDSet();
+ for (long itemID : preferredItemIDs) {
+ candidateItemIDs.addAll(similarity.allSimilarItemIDs(itemID));
+ }
+ if (!includeKnownItems) {
+ candidateItemIDs.removeAll(preferredItemIDs);
+ }
+ return candidateItemIDs;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategy.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategy.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategy.java
new file mode 100644
index 0000000..929eddd
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/AllUnknownItemsCandidateItemsStrategy.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.model.DataModel;
+
+public final class AllUnknownItemsCandidateItemsStrategy extends AbstractCandidateItemsStrategy {
+
+ /** return all items the user has not yet seen */
+ @Override
+ protected FastIDSet doGetCandidateItems(long[] preferredItemIDs, DataModel dataModel, boolean includeKnownItems)
+ throws TasteException {
+ FastIDSet possibleItemIDs = new FastIDSet(dataModel.getNumItems());
+ LongPrimitiveIterator allItemIDs = dataModel.getItemIDs();
+ while (allItemIDs.hasNext()) {
+ possibleItemIDs.add(allItemIDs.nextLong());
+ }
+ if (!includeKnownItems) {
+ possibleItemIDs.removeAll(preferredItemIDs);
+ }
+ return possibleItemIDs;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByRescoreComparator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByRescoreComparator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByRescoreComparator.java
new file mode 100644
index 0000000..1677ea8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByRescoreComparator.java
@@ -0,0 +1,65 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.io.Serializable;
+import java.util.Comparator;
+
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+/**
+ * <p>
+ * Defines ordering on {@link RecommendedItem} by the rescored value of the recommendations' estimated
+ * preference value, from high to low.
+ * </p>
+ */
+final class ByRescoreComparator implements Comparator<RecommendedItem>, Serializable {
+
+ private final IDRescorer rescorer;
+
+ ByRescoreComparator(IDRescorer rescorer) {
+ this.rescorer = rescorer;
+ }
+
+ @Override
+ public int compare(RecommendedItem o1, RecommendedItem o2) {
+ double rescored1;
+ double rescored2;
+ if (rescorer == null) {
+ rescored1 = o1.getValue();
+ rescored2 = o2.getValue();
+ } else {
+ rescored1 = rescorer.rescore(o1.getItemID(), o1.getValue());
+ rescored2 = rescorer.rescore(o2.getItemID(), o2.getValue());
+ }
+ if (rescored1 < rescored2) {
+ return 1;
+ } else if (rescored1 > rescored2) {
+ return -1;
+ } else {
+ return 0;
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "ByRescoreComparator[rescorer:" + rescorer + ']';
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByValueRecommendedItemComparator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByValueRecommendedItemComparator.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByValueRecommendedItemComparator.java
new file mode 100644
index 0000000..57c5f3d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/ByValueRecommendedItemComparator.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.io.Serializable;
+import java.util.Comparator;
+
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+
+/**
+ * Defines a natural ordering from most-preferred item (highest value) to least-preferred.
+ */
+public final class ByValueRecommendedItemComparator implements Comparator<RecommendedItem>, Serializable {
+
+ private static final Comparator<RecommendedItem> INSTANCE = new ByValueRecommendedItemComparator();
+
+ public static Comparator<RecommendedItem> getInstance() {
+ return INSTANCE;
+ }
+
+ @Override
+ public int compare(RecommendedItem o1, RecommendedItem o2) {
+ float value1 = o1.getValue();
+ float value2 = o2.getValue();
+ return value1 > value2 ? -1 : value1 < value2 ? 1 : 0;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommender.java
new file mode 100644
index 0000000..7ed8cc3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/CachingRecommender.java
@@ -0,0 +1,251 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.common.Cache;
+import org.apache.mahout.cf.taste.impl.common.RefreshHelper;
+import org.apache.mahout.cf.taste.impl.common.Retriever;
+import org.apache.mahout.cf.taste.impl.model.PlusAnonymousUserDataModel;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.common.LongPair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * <p>
+ * A {@link Recommender} which caches the results from another {@link Recommender} in memory.
+ *
+ * TODO: Should be checked for thread safety
+ * </p>
+ */
+public final class CachingRecommender implements Recommender {
+
+ private static final Logger log = LoggerFactory.getLogger(CachingRecommender.class);
+
+ private final Recommender recommender;
+ private final int[] maxHowMany;
+ private final Retriever<Long,Recommendations> recommendationsRetriever;
+ private final Cache<Long,Recommendations> recommendationCache;
+ private final Cache<LongPair,Float> estimatedPrefCache;
+ private final RefreshHelper refreshHelper;
+ private IDRescorer currentRescorer;
+ private boolean currentlyIncludeKnownItems;
+
+ public CachingRecommender(Recommender recommender) throws TasteException {
+ Preconditions.checkArgument(recommender != null, "recommender is null");
+ this.recommender = recommender;
+ maxHowMany = new int[]{1};
+ // Use "num users" as an upper limit on cache size. Rough guess.
+ int numUsers = recommender.getDataModel().getNumUsers();
+ recommendationsRetriever = new RecommendationRetriever();
+ recommendationCache = new Cache<>(recommendationsRetriever, numUsers);
+ estimatedPrefCache = new Cache<>(new EstimatedPrefRetriever(), numUsers);
+ refreshHelper = new RefreshHelper(new Callable<Object>() {
+ @Override
+ public Object call() {
+ clear();
+ return null;
+ }
+ });
+ refreshHelper.addDependency(recommender);
+ }
+
+ private void setCurrentRescorer(IDRescorer rescorer) {
+ if (rescorer == null) {
+ if (currentRescorer != null) {
+ currentRescorer = null;
+ clear();
+ }
+ } else {
+ if (!rescorer.equals(currentRescorer)) {
+ currentRescorer = rescorer;
+ clear();
+ }
+ }
+ }
+
+ public void setCurrentlyIncludeKnownItems(boolean currentlyIncludeKnownItems) {
+ this.currentlyIncludeKnownItems = currentlyIncludeKnownItems;
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany) throws TasteException {
+ return recommend(userID, howMany, null, false);
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) throws TasteException {
+ return recommend(userID, howMany, null, includeKnownItems);
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany,IDRescorer rescorer) throws TasteException {
+ return recommend(userID, howMany, rescorer, false);
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany,IDRescorer rescorer, boolean includeKnownItems)
+ throws TasteException {
+ Preconditions.checkArgument(howMany >= 1, "howMany must be at least 1");
+ synchronized (maxHowMany) {
+ if (howMany > maxHowMany[0]) {
+ maxHowMany[0] = howMany;
+ }
+ }
+
+ // Special case, avoid caching an anonymous user
+ if (userID == PlusAnonymousUserDataModel.TEMP_USER_ID) {
+ return recommendationsRetriever.get(PlusAnonymousUserDataModel.TEMP_USER_ID).getItems();
+ }
+
+ setCurrentRescorer(rescorer);
+ setCurrentlyIncludeKnownItems(includeKnownItems);
+
+ Recommendations recommendations = recommendationCache.get(userID);
+ if (recommendations.getItems().size() < howMany && !recommendations.isNoMoreRecommendableItems()) {
+ clear(userID);
+ recommendations = recommendationCache.get(userID);
+ if (recommendations.getItems().size() < howMany) {
+ recommendations.setNoMoreRecommendableItems(true);
+ }
+ }
+
+ List<RecommendedItem> recommendedItems = recommendations.getItems();
+ return recommendedItems.size() > howMany ? recommendedItems.subList(0, howMany) : recommendedItems;
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) throws TasteException {
+ return estimatedPrefCache.get(new LongPair(userID, itemID));
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) throws TasteException {
+ recommender.setPreference(userID, itemID, value);
+ clear(userID);
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) throws TasteException {
+ recommender.removePreference(userID, itemID);
+ clear(userID);
+ }
+
+ @Override
+ public DataModel getDataModel() {
+ return recommender.getDataModel();
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {
+ refreshHelper.refresh(alreadyRefreshed);
+ }
+
+ /**
+ * <p>
+ * Clears cached recommendations for the given user.
+ * </p>
+ *
+ * @param userID
+ * clear cached data associated with this user ID
+ */
+ public void clear(final long userID) {
+ log.debug("Clearing recommendations for user ID '{}'", userID);
+ recommendationCache.remove(userID);
+ estimatedPrefCache.removeKeysMatching(new Cache.MatchPredicate<LongPair>() {
+ @Override
+ public boolean matches(LongPair userItemPair) {
+ return userItemPair.getFirst() == userID;
+ }
+ });
+ }
+
+ /**
+ * <p>
+ * Clears all cached recommendations.
+ * </p>
+ */
+ public void clear() {
+ log.debug("Clearing all recommendations...");
+ recommendationCache.clear();
+ estimatedPrefCache.clear();
+ }
+
+ @Override
+ public String toString() {
+ return "CachingRecommender[recommender:" + recommender + ']';
+ }
+
+ private final class RecommendationRetriever implements Retriever<Long,Recommendations> {
+ @Override
+ public Recommendations get(Long key) throws TasteException {
+ log.debug("Retrieving new recommendations for user ID '{}'", key);
+ int howMany = maxHowMany[0];
+ IDRescorer rescorer = currentRescorer;
+ List<RecommendedItem> recommendations =
+ rescorer == null ? recommender.recommend(key, howMany, null, currentlyIncludeKnownItems) :
+ recommender.recommend(key, howMany, rescorer, currentlyIncludeKnownItems);
+ return new Recommendations(Collections.unmodifiableList(recommendations));
+ }
+ }
+
+ private final class EstimatedPrefRetriever implements Retriever<LongPair,Float> {
+ @Override
+ public Float get(LongPair key) throws TasteException {
+ long userID = key.getFirst();
+ long itemID = key.getSecond();
+ log.debug("Retrieving estimated preference for user ID '{}' and item ID '{}'", userID, itemID);
+ return recommender.estimatePreference(userID, itemID);
+ }
+ }
+
+ private static final class Recommendations {
+
+ private final List<RecommendedItem> items;
+ private boolean noMoreRecommendableItems;
+
+ private Recommendations(List<RecommendedItem> items) {
+ this.items = items;
+ }
+
+ List<RecommendedItem> getItems() {
+ return items;
+ }
+
+ boolean isNoMoreRecommendableItems() {
+ return noMoreRecommendableItems;
+ }
+
+ void setNoMoreRecommendableItems(boolean noMoreRecommendableItems) {
+ this.noMoreRecommendableItems = noMoreRecommendableItems;
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/EstimatedPreferenceCapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/EstimatedPreferenceCapper.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/EstimatedPreferenceCapper.java
new file mode 100644
index 0000000..f0f389f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/EstimatedPreferenceCapper.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.model.DataModel;
+
+/**
+ * Simple class which encapsulates restricting a preference value
+ * to a predefined range. The simple logic is wrapped up here for
+ * performance reasons.
+ */
+public final class EstimatedPreferenceCapper {
+
+ private final float min;
+ private final float max;
+
+ public EstimatedPreferenceCapper(DataModel model) {
+ min = model.getMinPreference();
+ max = model.getMaxPreference();
+ }
+
+ public float capEstimate(float estimate) {
+ if (estimate > max) {
+ estimate = max;
+ } else if (estimate < min) {
+ estimate = min;
+ }
+ return estimate;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefItemBasedRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefItemBasedRecommender.java b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefItemBasedRecommender.java
new file mode 100644
index 0000000..40e21a3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/cf/taste/impl/recommender/GenericBooleanPrefItemBasedRecommender.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.MostSimilarItemsCandidateItemsStrategy;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+
+/**
+ * A variant on {@link GenericItemBasedRecommender} which is appropriate for use when no notion of preference
+ * value exists in the data.
+ *
+ * @see org.apache.mahout.cf.taste.impl.recommender.GenericBooleanPrefUserBasedRecommender
+ */
+public final class GenericBooleanPrefItemBasedRecommender extends GenericItemBasedRecommender {
+
+ public GenericBooleanPrefItemBasedRecommender(DataModel dataModel, ItemSimilarity similarity) {
+ super(dataModel, similarity);
+ }
+
+ public GenericBooleanPrefItemBasedRecommender(DataModel dataModel, ItemSimilarity similarity,
+ CandidateItemsStrategy candidateItemsStrategy, MostSimilarItemsCandidateItemsStrategy
+ mostSimilarItemsCandidateItemsStrategy) {
+ super(dataModel, similarity, candidateItemsStrategy, mostSimilarItemsCandidateItemsStrategy);
+ }
+
+ /**
+ * This computation is in a technical sense, wrong, since in the domain of "boolean preference users" where
+ * all preference values are 1, this method should only ever return 1.0 or NaN. This isn't terribly useful
+ * however since it means results can't be ranked by preference value (all are 1). So instead this returns a
+ * sum of similarities.
+ */
+ @Override
+ protected float doEstimatePreference(long userID, PreferenceArray preferencesFromUser, long itemID)
+ throws TasteException {
+ double[] similarities = getSimilarity().itemSimilarities(itemID, preferencesFromUser.getIDs());
+ boolean foundAPref = false;
+ double totalSimilarity = 0.0;
+ for (double theSimilarity : similarities) {
+ if (!Double.isNaN(theSimilarity)) {
+ foundAPref = true;
+ totalSimilarity += theSimilarity;
+ }
+ }
+ return foundAPref ? (float) totalSimilarity : Float.NaN;
+ }
+
+ @Override
+ public String toString() {
+ return "GenericBooleanPrefItemBasedRecommender";
+ }
+
+}
[10/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommenderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommenderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommenderTest.java
new file mode 100644
index 0000000..16cbdca
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericItemBasedRecommenderTest.java
@@ -0,0 +1,324 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.ItemBasedRecommender;
+import org.apache.mahout.cf.taste.recommender.MostSimilarItemsCandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.similarity.ItemSimilarity;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+/** <p>Tests {@link GenericItemBasedRecommender}.</p> */
+public final class GenericItemBasedRecommenderTest extends TasteTestCase {
+
+ @Test
+ public void testRecommender() throws Exception {
+ Recommender recommender = buildRecommender();
+ List<RecommendedItem> recommended = recommender.recommend(1, 1);
+ assertNotNull(recommended);
+ assertEquals(1, recommended.size());
+ RecommendedItem firstRecommended = recommended.get(0);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.1f, firstRecommended.getValue(), EPSILON);
+ recommender.refresh(null);
+ recommended = recommender.recommend(1, 1);
+ firstRecommended = recommended.get(0);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.1f, firstRecommended.getValue(), EPSILON);
+ }
+
+ @Test
+ public void testHowMany() throws Exception {
+
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3, 4, 5},
+ new Double[][] {
+ {0.1, 0.2},
+ {0.2, 0.3, 0.3, 0.6},
+ {0.4, 0.4, 0.5, 0.9},
+ {0.1, 0.4, 0.5, 0.8, 0.9, 1.0},
+ {0.2, 0.3, 0.6, 0.7, 0.1, 0.2},
+ });
+
+ Collection<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList();
+ for (int i = 0; i < 6; i++) {
+ for (int j = i + 1; j < 6; j++) {
+ similarities.add(
+ new GenericItemSimilarity.ItemItemSimilarity(i, j, 1.0 / (1.0 + i + j)));
+ }
+ }
+ ItemSimilarity similarity = new GenericItemSimilarity(similarities);
+ Recommender recommender = new GenericItemBasedRecommender(dataModel, similarity);
+ List<RecommendedItem> fewRecommended = recommender.recommend(1, 2);
+ List<RecommendedItem> moreRecommended = recommender.recommend(1, 4);
+ for (int i = 0; i < fewRecommended.size(); i++) {
+ assertEquals(fewRecommended.get(i).getItemID(), moreRecommended.get(i).getItemID());
+ }
+ recommender.refresh(null);
+ for (int i = 0; i < fewRecommended.size(); i++) {
+ assertEquals(fewRecommended.get(i).getItemID(), moreRecommended.get(i).getItemID());
+ }
+ }
+
+ @Test
+ public void testRescorer() throws Exception {
+
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {0.1, 0.2},
+ {0.2, 0.3, 0.3, 0.6},
+ {0.4, 0.4, 0.5, 0.9},
+ });
+
+ Collection<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList();
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 1, 1.0));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 2, 0.5));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 3, 0.2));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 2, 0.7));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 3, 0.5));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(2, 3, 0.9));
+ ItemSimilarity similarity = new GenericItemSimilarity(similarities);
+ Recommender recommender = new GenericItemBasedRecommender(dataModel, similarity);
+ List<RecommendedItem> originalRecommended = recommender.recommend(1, 2);
+ List<RecommendedItem> rescoredRecommended =
+ recommender.recommend(1, 2, new ReversingRescorer<Long>());
+ assertNotNull(originalRecommended);
+ assertNotNull(rescoredRecommended);
+ assertEquals(2, originalRecommended.size());
+ assertEquals(2, rescoredRecommended.size());
+ assertEquals(originalRecommended.get(0).getItemID(), rescoredRecommended.get(1).getItemID());
+ assertEquals(originalRecommended.get(1).getItemID(), rescoredRecommended.get(0).getItemID());
+ }
+
+ @Test
+ public void testIncludeKnownItems() throws Exception {
+
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {0.1, 0.2},
+ {0.2, 0.3, 0.3, 0.6},
+ {0.4, 0.4, 0.5, 0.9},
+ });
+
+ Collection<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList();
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 1, 0.8));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 2, 0.5));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 3, 0.2));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 2, 0.7));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 3, 0.5));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(2, 3, 0.9));
+ ItemSimilarity similarity = new GenericItemSimilarity(similarities);
+ Recommender recommender = new GenericItemBasedRecommender(dataModel, similarity);
+ List<RecommendedItem> originalRecommended = recommender.recommend(1, 4, null, true);
+ List<RecommendedItem> rescoredRecommended = recommender.recommend(1, 4, new ReversingRescorer<Long>(), true);
+ assertNotNull(originalRecommended);
+ assertNotNull(rescoredRecommended);
+ assertEquals(4, originalRecommended.size());
+ assertEquals(4, rescoredRecommended.size());
+ assertEquals(originalRecommended.get(0).getItemID(), rescoredRecommended.get(3).getItemID());
+ assertEquals(originalRecommended.get(3).getItemID(), rescoredRecommended.get(0).getItemID());
+ }
+
+ @Test
+ public void testEstimatePref() throws Exception {
+ Recommender recommender = buildRecommender();
+ assertEquals(0.1f, recommender.estimatePreference(1, 2), EPSILON);
+ }
+
+ /**
+ * Contributed test case that verifies fix for bug
+ * <a href="http://sourceforge.net/tracker/index.php?func=detail&aid=1396128&group_id=138771&atid=741665">
+ * 1396128</a>.
+ */
+ @Test
+ public void testBestRating() throws Exception {
+ Recommender recommender = buildRecommender();
+ List<RecommendedItem> recommended = recommender.recommend(1, 1);
+ assertNotNull(recommended);
+ assertEquals(1, recommended.size());
+ RecommendedItem firstRecommended = recommended.get(0);
+ // item one should be recommended because it has a greater rating/score
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.1f, firstRecommended.getValue(), EPSILON);
+ }
+
+ @Test
+ public void testMostSimilar() throws Exception {
+ ItemBasedRecommender recommender = buildRecommender();
+ List<RecommendedItem> similar = recommender.mostSimilarItems(0, 2);
+ assertNotNull(similar);
+ assertEquals(2, similar.size());
+ RecommendedItem first = similar.get(0);
+ RecommendedItem second = similar.get(1);
+ assertEquals(1, first.getItemID());
+ assertEquals(1.0f, first.getValue(), EPSILON);
+ assertEquals(2, second.getItemID());
+ assertEquals(0.5f, second.getValue(), EPSILON);
+ }
+
+ @Test
+ public void testMostSimilarToMultiple() throws Exception {
+ ItemBasedRecommender recommender = buildRecommender2();
+ List<RecommendedItem> similar = recommender.mostSimilarItems(new long[] {0, 1}, 2);
+ assertNotNull(similar);
+ assertEquals(2, similar.size());
+ RecommendedItem first = similar.get(0);
+ RecommendedItem second = similar.get(1);
+ assertEquals(2, first.getItemID());
+ assertEquals(0.85f, first.getValue(), EPSILON);
+ assertEquals(3, second.getItemID());
+ assertEquals(-0.3f, second.getValue(), EPSILON);
+ }
+
+ @Test
+ public void testMostSimilarToMultipleExcludeIfNotSimilarToAll() throws Exception {
+ ItemBasedRecommender recommender = buildRecommender2();
+ List<RecommendedItem> similar = recommender.mostSimilarItems(new long[] {3, 4}, 2);
+ assertNotNull(similar);
+ assertEquals(1, similar.size());
+ RecommendedItem first = similar.get(0);
+ assertEquals(0, first.getItemID());
+ assertEquals(0.2f, first.getValue(), EPSILON);
+ }
+
+ @Test
+ public void testMostSimilarToMultipleDontExcludeIfNotSimilarToAll() throws Exception {
+ ItemBasedRecommender recommender = buildRecommender2();
+ List<RecommendedItem> similar = recommender.mostSimilarItems(new long[] {1, 2, 4}, 10, false);
+ assertNotNull(similar);
+ assertEquals(2, similar.size());
+ RecommendedItem first = similar.get(0);
+ RecommendedItem second = similar.get(1);
+ assertEquals(0, first.getItemID());
+ assertEquals(0.933333333f, first.getValue(), EPSILON);
+ assertEquals(3, second.getItemID());
+ assertEquals(-0.2f, second.getValue(), EPSILON);
+ }
+
+
+ @Test
+ public void testRecommendedBecause() throws Exception {
+ ItemBasedRecommender recommender = buildRecommender2();
+ List<RecommendedItem> recommendedBecause = recommender.recommendedBecause(1, 4, 3);
+ assertNotNull(recommendedBecause);
+ assertEquals(3, recommendedBecause.size());
+ RecommendedItem first = recommendedBecause.get(0);
+ RecommendedItem second = recommendedBecause.get(1);
+ RecommendedItem third = recommendedBecause.get(2);
+ assertEquals(2, first.getItemID());
+ assertEquals(0.99f, first.getValue(), EPSILON);
+ assertEquals(3, second.getItemID());
+ assertEquals(0.4f, second.getValue(), EPSILON);
+ assertEquals(0, third.getItemID());
+ assertEquals(0.2f, third.getValue(), EPSILON);
+ }
+
+ private static ItemBasedRecommender buildRecommender() {
+ DataModel dataModel = getDataModel();
+ Collection<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList();
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 1, 1.0));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 2, 0.5));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 2, 0.0));
+ ItemSimilarity similarity = new GenericItemSimilarity(similarities);
+ return new GenericItemBasedRecommender(dataModel, similarity);
+ }
+
+ private static ItemBasedRecommender buildRecommender2() {
+
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3, 4},
+ new Double[][] {
+ {0.1, 0.3, 0.9, 0.8},
+ {0.2, 0.3, 0.3, 0.4},
+ {0.4, 0.3, 0.5, 0.1, 0.1},
+ {0.7, 0.3, 0.8, 0.5, 0.6},
+ });
+
+ Collection<GenericItemSimilarity.ItemItemSimilarity> similarities = Lists.newArrayList();
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 1, 1.0));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 2, 0.8));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 3, -0.6));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(0, 4, 1.0));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 2, 0.9));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 3, 0.0));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(1, 1, 1.0));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(2, 3, -0.1));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(2, 4, 0.1));
+ similarities.add(new GenericItemSimilarity.ItemItemSimilarity(3, 4, -0.5));
+ ItemSimilarity similarity = new GenericItemSimilarity(similarities);
+ return new GenericItemBasedRecommender(dataModel, similarity);
+ }
+
+
+ /**
+ * we're making sure that a user's preferences are fetched only once from the {@link DataModel} for one call to
+ * {@link GenericItemBasedRecommender#recommend(long, int)}
+ *
+ * @throws Exception
+ */
+ @Test
+ public void preferencesFetchedOnlyOnce() throws Exception {
+
+ DataModel dataModel = EasyMock.createMock(DataModel.class);
+ ItemSimilarity itemSimilarity = EasyMock.createMock(ItemSimilarity.class);
+ CandidateItemsStrategy candidateItemsStrategy = EasyMock.createMock(CandidateItemsStrategy.class);
+ MostSimilarItemsCandidateItemsStrategy mostSimilarItemsCandidateItemsStrategy =
+ EasyMock.createMock(MostSimilarItemsCandidateItemsStrategy.class);
+
+ PreferenceArray preferencesFromUser = new GenericUserPreferenceArray(
+ Arrays.asList(new GenericPreference(1L, 1L, 5.0f), new GenericPreference(1L, 2L, 4.0f)));
+
+ EasyMock.expect(dataModel.getMinPreference()).andReturn(Float.NaN);
+ EasyMock.expect(dataModel.getMaxPreference()).andReturn(Float.NaN);
+
+ EasyMock.expect(dataModel.getPreferencesFromUser(1L)).andReturn(preferencesFromUser);
+ EasyMock.expect(candidateItemsStrategy.getCandidateItems(1L, preferencesFromUser, dataModel, false))
+ .andReturn(new FastIDSet(new long[] { 3L, 4L }));
+
+ EasyMock.expect(itemSimilarity.itemSimilarities(3L, preferencesFromUser.getIDs()))
+ .andReturn(new double[] { 0.5, 0.3 });
+ EasyMock.expect(itemSimilarity.itemSimilarities(4L, preferencesFromUser.getIDs()))
+ .andReturn(new double[] { 0.4, 0.1 });
+
+ EasyMock.replay(dataModel, itemSimilarity, candidateItemsStrategy, mostSimilarItemsCandidateItemsStrategy);
+
+ Recommender recommender = new GenericItemBasedRecommender(dataModel, itemSimilarity,
+ candidateItemsStrategy, mostSimilarItemsCandidateItemsStrategy);
+
+ recommender.recommend(1L, 3);
+
+ EasyMock.verify(dataModel, itemSimilarity, candidateItemsStrategy, mostSimilarItemsCandidateItemsStrategy);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommenderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommenderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommenderTest.java
new file mode 100644
index 0000000..121cd1a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/GenericUserBasedRecommenderTest.java
@@ -0,0 +1,174 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.neighborhood.NearestNUserNeighborhood;
+import org.apache.mahout.cf.taste.impl.similarity.PearsonCorrelationSimilarity;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.neighborhood.UserNeighborhood;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.apache.mahout.cf.taste.recommender.UserBasedRecommender;
+import org.apache.mahout.cf.taste.similarity.UserSimilarity;
+import org.junit.Test;
+
+import java.util.List;
+
+/** <p>Tests {@link GenericUserBasedRecommender}.</p> */
+public final class GenericUserBasedRecommenderTest extends TasteTestCase {
+
+ @Test
+ public void testRecommender() throws Exception {
+ Recommender recommender = buildRecommender();
+ List<RecommendedItem> recommended = recommender.recommend(1, 1);
+ assertNotNull(recommended);
+ assertEquals(1, recommended.size());
+ RecommendedItem firstRecommended = recommended.get(0);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.1f, firstRecommended.getValue(), EPSILON);
+ recommender.refresh(null);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.1f, firstRecommended.getValue(), EPSILON);
+ }
+
+ @Test
+ public void testHowMany() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3, 4, 5},
+ new Double[][] {
+ {0.1, 0.2},
+ {0.2, 0.3, 0.3, 0.6},
+ {0.4, 0.4, 0.5, 0.9},
+ {0.1, 0.4, 0.5, 0.8, 0.9, 1.0},
+ {0.2, 0.3, 0.6, 0.7, 0.1, 0.2},
+ });
+ UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel);
+ UserNeighborhood neighborhood = new NearestNUserNeighborhood(2, similarity, dataModel);
+ Recommender recommender = new GenericUserBasedRecommender(dataModel, neighborhood, similarity);
+ List<RecommendedItem> fewRecommended = recommender.recommend(1, 2);
+ List<RecommendedItem> moreRecommended = recommender.recommend(1, 4);
+ for (int i = 0; i < fewRecommended.size(); i++) {
+ assertEquals(fewRecommended.get(i).getItemID(), moreRecommended.get(i).getItemID());
+ }
+ recommender.refresh(null);
+ for (int i = 0; i < fewRecommended.size(); i++) {
+ assertEquals(fewRecommended.get(i).getItemID(), moreRecommended.get(i).getItemID());
+ }
+ }
+
+ @Test
+ public void testRescorer() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {0.1, 0.2},
+ {0.2, 0.3, 0.3, 0.6},
+ {0.4, 0.5, 0.5, 0.9},
+ });
+ UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel);
+ UserNeighborhood neighborhood = new NearestNUserNeighborhood(2, similarity, dataModel);
+ Recommender recommender = new GenericUserBasedRecommender(dataModel, neighborhood, similarity);
+ List<RecommendedItem> originalRecommended = recommender.recommend(1, 2);
+ List<RecommendedItem> rescoredRecommended =
+ recommender.recommend(1, 2, new ReversingRescorer<Long>());
+ assertNotNull(originalRecommended);
+ assertNotNull(rescoredRecommended);
+ assertEquals(2, originalRecommended.size());
+ assertEquals(2, rescoredRecommended.size());
+ assertEquals(originalRecommended.get(0).getItemID(), rescoredRecommended.get(1).getItemID());
+ assertEquals(originalRecommended.get(1).getItemID(), rescoredRecommended.get(0).getItemID());
+ }
+
+ @Test
+ public void testIncludeKnownItems() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][] {
+ {0.1, 0.2},
+ {0.2, 0.3, 0.3, 0.6},
+ {0.4, 0.5, 0.5, 0.9},
+ });
+ UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel);
+ UserNeighborhood neighborhood = new NearestNUserNeighborhood(2, similarity, dataModel);
+ Recommender recommender = new GenericUserBasedRecommender(dataModel, neighborhood, similarity);
+ List<RecommendedItem> originalRecommended = recommender.recommend(1, 4, null, true);
+ List<RecommendedItem> rescoredRecommended = recommender.recommend(1, 4, new ReversingRescorer<Long>(), true);
+ assertNotNull(originalRecommended);
+ assertNotNull(rescoredRecommended);
+ assertEquals(4, originalRecommended.size());
+ assertEquals(4, rescoredRecommended.size());
+ assertEquals(originalRecommended.get(0).getItemID(), rescoredRecommended.get(3).getItemID());
+ assertEquals(originalRecommended.get(3).getItemID(), rescoredRecommended.get(0).getItemID());
+ }
+
+ @Test
+ public void testEstimatePref() throws Exception {
+ Recommender recommender = buildRecommender();
+ assertEquals(0.1f, recommender.estimatePreference(1, 2), EPSILON);
+ }
+
+ @Test
+ public void testBestRating() throws Exception {
+ Recommender recommender = buildRecommender();
+ List<RecommendedItem> recommended = recommender.recommend(1, 1);
+ assertNotNull(recommended);
+ assertEquals(1, recommended.size());
+ RecommendedItem firstRecommended = recommended.get(0);
+ // item one should be recommended because it has a greater rating/score
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.1f, firstRecommended.getValue(), EPSILON);
+ }
+
+ @Test
+ public void testMostSimilar() throws Exception {
+ UserBasedRecommender recommender = buildRecommender();
+ long[] similar = recommender.mostSimilarUserIDs(1, 2);
+ assertNotNull(similar);
+ assertEquals(2, similar.length);
+ assertEquals(2, similar[0]);
+ assertEquals(3, similar[1]);
+ }
+
+ @Test
+ public void testIsolatedUser() throws Exception {
+ DataModel dataModel = getDataModel(
+ new long[] {1, 2, 3, 4},
+ new Double[][] {
+ {0.1, 0.2},
+ {0.2, 0.3, 0.3, 0.6},
+ {0.4, 0.4, 0.5, 0.9},
+ {null, null, null, null, 1.0},
+ });
+ UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel);
+ UserNeighborhood neighborhood = new NearestNUserNeighborhood(3, similarity, dataModel);
+ UserBasedRecommender recommender = new GenericUserBasedRecommender(dataModel, neighborhood, similarity);
+ long[] mostSimilar = recommender.mostSimilarUserIDs(4, 3);
+ assertNotNull(mostSimilar);
+ assertEquals(0, mostSimilar.length);
+ }
+
+ private static UserBasedRecommender buildRecommender() throws TasteException {
+ DataModel dataModel = getDataModel();
+ UserSimilarity similarity = new PearsonCorrelationSimilarity(dataModel);
+ UserNeighborhood neighborhood = new NearestNUserNeighborhood(2, similarity, dataModel);
+ return new GenericUserBasedRecommender(dataModel, neighborhood, similarity);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommenderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommenderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommenderTest.java
new file mode 100644
index 0000000..243eaa9
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemAverageRecommenderTest.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.junit.Test;
+
+import java.util.List;
+
+public final class ItemAverageRecommenderTest extends TasteTestCase {
+
+ @Test
+ public void testRecommender() throws Exception {
+ Recommender recommender = new ItemAverageRecommender(getDataModel());
+ List<RecommendedItem> recommended = recommender.recommend(1, 1);
+ assertNotNull(recommended);
+ assertEquals(1, recommended.size());
+ RecommendedItem firstRecommended = recommended.get(0);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.53333336f, firstRecommended.getValue(), EPSILON);
+ recommender.refresh(null);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.53333336f, firstRecommended.getValue(), EPSILON);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommenderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommenderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommenderTest.java
new file mode 100644
index 0000000..f8bf1a1
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ItemUserAverageRecommenderTest.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.junit.Test;
+
+import java.util.List;
+
+public final class ItemUserAverageRecommenderTest extends TasteTestCase {
+
+ @Test
+ public void testRecommender() throws Exception {
+ Recommender recommender = new ItemUserAverageRecommender(getDataModel());
+ List<RecommendedItem> recommended = recommender.recommend(1, 1);
+ assertNotNull(recommended);
+ assertEquals(1, recommended.size());
+ RecommendedItem firstRecommended = recommended.get(0);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.35151517f, firstRecommended.getValue(), EPSILON);
+ recommender.refresh(null);
+ assertEquals(2, firstRecommended.getItemID());
+ assertEquals(0.35151517f, firstRecommended.getValue(), EPSILON);
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/MockRecommender.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/MockRecommender.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/MockRecommender.java
new file mode 100644
index 0000000..50a16cb
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/MockRecommender.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.commons.lang3.mutable.MutableInt;
+import org.apache.mahout.cf.taste.common.Refreshable;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+
+final class MockRecommender implements Recommender {
+
+ private final MutableInt recommendCount;
+
+ MockRecommender(MutableInt recommendCount) {
+ this.recommendCount = recommendCount;
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany) {
+ recommendCount.increment();
+ return Collections.<RecommendedItem>singletonList(
+ new GenericRecommendedItem(1, 1.0f));
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, boolean includeKnownItems) {
+ return recommend(userID, howMany);
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer) {
+ return recommend(userID, howMany);
+ }
+
+ @Override
+ public List<RecommendedItem> recommend(long userID, int howMany, IDRescorer rescorer, boolean includeKnownItems) {
+ return recommend(userID, howMany);
+ }
+
+ @Override
+ public float estimatePreference(long userID, long itemID) {
+ recommendCount.increment();
+ return 0.0f;
+ }
+
+ @Override
+ public void setPreference(long userID, long itemID, float value) {
+ // do nothing
+ }
+
+ @Override
+ public void removePreference(long userID, long itemID) {
+ // do nothing
+ }
+
+ @Override
+ public DataModel getDataModel() {
+ return TasteTestCase.getDataModel(
+ new long[] {1, 2, 3},
+ new Double[][]{{1.0},{2.0},{3.0}});
+ }
+
+ @Override
+ public void refresh(Collection<Refreshable> alreadyRefreshed) {}
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorerTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorerTest.java
new file mode 100644
index 0000000..97e539e
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/NullRescorerTest.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.junit.Test;
+
+/** <p>Tests {@link NullRescorer}.</p> */
+public final class NullRescorerTest extends TasteTestCase {
+
+ @Test
+ public void testItemRescorer() throws Exception {
+ IDRescorer rescorer = NullRescorer.getItemInstance();
+ assertNotNull(rescorer);
+ assertEquals(1.0, rescorer.rescore(1L, 1.0), EPSILON);
+ assertEquals(1.0, rescorer.rescore(0L, 1.0), EPSILON);
+ assertEquals(0.0, rescorer.rescore(1L, 0.0), EPSILON);
+ assertTrue(Double.isNaN(rescorer.rescore(1L, Double.NaN)));
+ }
+
+ @Test
+ public void testUserRescorer() throws Exception {
+ IDRescorer rescorer = NullRescorer.getUserInstance();
+ assertNotNull(rescorer);
+ assertEquals(1.0, rescorer.rescore(1L, 1.0), EPSILON);
+ assertEquals(1.0, rescorer.rescore(0L, 1.0), EPSILON);
+ assertEquals(0.0, rescorer.rescore(1L, 0.0), EPSILON);
+ assertTrue(Double.isNaN(rescorer.rescore(1L, Double.NaN)));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategyTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategyTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategyTest.java
new file mode 100644
index 0000000..cbf20cf
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/PreferredItemsNeighborhoodCandidateItemsStrategyTest.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import java.util.Collections;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.model.GenericItemPreferenceArray;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+/**
+ * Tests {@link PreferredItemsNeighborhoodCandidateItemsStrategy}
+ */
+public final class PreferredItemsNeighborhoodCandidateItemsStrategyTest extends TasteTestCase {
+
+ @Test
+ public void testStrategy() throws TasteException {
+ FastIDSet itemIDsFromUser123 = new FastIDSet();
+ itemIDsFromUser123.add(1L);
+
+ FastIDSet itemIDsFromUser456 = new FastIDSet();
+ itemIDsFromUser456.add(1L);
+ itemIDsFromUser456.add(2L);
+
+ List<Preference> prefs = Lists.newArrayList();
+ prefs.add(new GenericPreference(123L, 1L, 1.0f));
+ prefs.add(new GenericPreference(456L, 1L, 1.0f));
+ PreferenceArray preferencesForItem1 = new GenericItemPreferenceArray(prefs);
+
+ DataModel dataModel = EasyMock.createMock(DataModel.class);
+ EasyMock.expect(dataModel.getPreferencesForItem(1L)).andReturn(preferencesForItem1);
+ EasyMock.expect(dataModel.getItemIDsFromUser(123L)).andReturn(itemIDsFromUser123);
+ EasyMock.expect(dataModel.getItemIDsFromUser(456L)).andReturn(itemIDsFromUser456);
+
+ PreferenceArray prefArrayOfUser123 =
+ new GenericUserPreferenceArray(Collections.singletonList(new GenericPreference(123L, 1L, 1.0f)));
+
+ CandidateItemsStrategy strategy = new PreferredItemsNeighborhoodCandidateItemsStrategy();
+
+ EasyMock.replay(dataModel);
+
+ FastIDSet candidateItems = strategy.getCandidateItems(123L, prefArrayOfUser123, dataModel, false);
+ assertEquals(1, candidateItems.size());
+ assertTrue(candidateItems.contains(2L));
+
+ EasyMock.verify(dataModel);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommenderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommenderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommenderTest.java
new file mode 100644
index 0000000..f57d389
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/RandomRecommenderTest.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.cf.taste.recommender.Recommender;
+import org.junit.Test;
+
+import java.util.List;
+
+public final class RandomRecommenderTest extends TasteTestCase {
+
+ @Test
+ public void testRecommender() throws Exception {
+ Recommender recommender = new RandomRecommender(getDataModel());
+ List<RecommendedItem> recommended = recommender.recommend(1, 1);
+ assertNotNull(recommended);
+ assertEquals(1, recommended.size());
+ RecommendedItem firstRecommended = recommended.get(0);
+ assertEquals(2, firstRecommended.getItemID());
+ recommender.refresh(null);
+ assertEquals(2, firstRecommended.getItemID());
+ }
+
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ReversingRescorer.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ReversingRescorer.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ReversingRescorer.java
new file mode 100644
index 0000000..3c4f7fc
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/ReversingRescorer.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import org.apache.mahout.cf.taste.recommender.IDRescorer;
+import org.apache.mahout.cf.taste.recommender.Rescorer;
+
+/** <p>Simple {@link Rescorer} which negates the given score, thus reversing order of rankings.</p> */
+public final class ReversingRescorer<T> implements Rescorer<T>, IDRescorer {
+
+ @Override
+ public double rescore(T thing, double originalScore) {
+ return -originalScore;
+ }
+
+ @Override
+ public boolean isFiltered(T thing) {
+ return false;
+ }
+
+ @Override
+ public double rescore(long ID, double originalScore) {
+ return -originalScore;
+ }
+
+ @Override
+ public boolean isFiltered(long ID) {
+ return false;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategyTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategyTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategyTest.java
new file mode 100644
index 0000000..687b2b1
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/SamplingCandidateItemsStrategyTest.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.junit.Test;
+
+import java.util.List;
+
+/**
+ * Tests {@link SamplingCandidateItemsStrategy}
+ */
+public final class SamplingCandidateItemsStrategyTest extends TasteTestCase {
+
+ @Test
+ public void testStrategy() throws TasteException {
+ List<Preference> prefsOfUser123 = Lists.newArrayList();
+ prefsOfUser123.add(new GenericPreference(123L, 1L, 1.0f));
+
+ List<Preference> prefsOfUser456 = Lists.newArrayList();
+ prefsOfUser456.add(new GenericPreference(456L, 1L, 1.0f));
+ prefsOfUser456.add(new GenericPreference(456L, 2L, 1.0f));
+
+ List<Preference> prefsOfUser789 = Lists.newArrayList();
+ prefsOfUser789.add(new GenericPreference(789L, 1L, 0.5f));
+ prefsOfUser789.add(new GenericPreference(789L, 3L, 1.0f));
+
+ PreferenceArray prefArrayOfUser123 = new GenericUserPreferenceArray(prefsOfUser123);
+
+ FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+ userData.put(123L, prefArrayOfUser123);
+ userData.put(456L, new GenericUserPreferenceArray(prefsOfUser456));
+ userData.put(789L, new GenericUserPreferenceArray(prefsOfUser789));
+
+ DataModel dataModel = new GenericDataModel(userData);
+
+ CandidateItemsStrategy strategy =
+ new SamplingCandidateItemsStrategy(1, 1, 1, dataModel.getNumUsers(), dataModel.getNumItems());
+
+ FastIDSet candidateItems = strategy.getCandidateItems(123L, prefArrayOfUser123, dataModel, false);
+ /* result can be either item2 or item3 or empty */
+ assertTrue(candidateItems.size() <= 1);
+ assertFalse(candidateItems.contains(1L));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/TopItemsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/TopItemsTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/TopItemsTest.java
new file mode 100644
index 0000000..1d8b862
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/TopItemsTest.java
@@ -0,0 +1,158 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender;
+
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveArrayIterator;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.similarity.GenericItemSimilarity;
+import org.apache.mahout.cf.taste.impl.similarity.GenericUserSimilarity;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+import java.util.List;
+import java.util.Random;
+
+/**
+ * Tests for {@link TopItems}.
+ */
+public final class TopItemsTest extends TasteTestCase {
+
+ @Test
+ public void testTopItems() throws Exception {
+ long[] ids = new long[100];
+ for (int i = 0; i < 100; i++) {
+ ids[i] = i;
+ }
+ LongPrimitiveIterator possibleItemIds = new LongPrimitiveArrayIterator(ids);
+ TopItems.Estimator<Long> estimator = new TopItems.Estimator<Long>() {
+ @Override
+ public double estimate(Long thing) {
+ return thing;
+ }
+ };
+ List<RecommendedItem> topItems = TopItems.getTopItems(10, possibleItemIds, null, estimator);
+ int gold = 99;
+ for (RecommendedItem topItem : topItems) {
+ assertEquals(gold, topItem.getItemID());
+ assertEquals(gold--, topItem.getValue(), 0.01);
+ }
+ }
+
+ @Test
+ public void testTopItemsRandom() throws Exception {
+ long[] ids = new long[100];
+ for (int i = 0; i < 100; i++) {
+ ids[i] = i;
+ }
+ LongPrimitiveIterator possibleItemIds = new LongPrimitiveArrayIterator(ids);
+ final Random random = RandomUtils.getRandom();
+ TopItems.Estimator<Long> estimator = new TopItems.Estimator<Long>() {
+ @Override
+ public double estimate(Long thing) {
+ return random.nextDouble();
+ }
+ };
+ List<RecommendedItem> topItems = TopItems.getTopItems(10, possibleItemIds, null, estimator);
+ assertEquals(10, topItems.size());
+ double last = 2.0;
+ for (RecommendedItem topItem : topItems) {
+ assertTrue(topItem.getValue() <= last);
+ last = topItem.getItemID();
+ }
+ }
+
+ @Test
+ public void testTopUsers() throws Exception {
+ long[] ids = new long[100];
+ for (int i = 0; i < 100; i++) {
+ ids[i] = i;
+ }
+ LongPrimitiveIterator possibleItemIds = new LongPrimitiveArrayIterator(ids);
+ TopItems.Estimator<Long> estimator = new TopItems.Estimator<Long>() {
+ @Override
+ public double estimate(Long thing) {
+ return thing;
+ }
+ };
+ long[] topItems = TopItems.getTopUsers(10, possibleItemIds, null, estimator);
+ int gold = 99;
+ for (long topItem : topItems) {
+ assertEquals(gold--, topItem);
+ }
+ }
+
+ @Test
+ public void testTopItemItem() throws Exception {
+ List<GenericItemSimilarity.ItemItemSimilarity> sims = Lists.newArrayList();
+ for (int i = 0; i < 99; i++) {
+ sims.add(new GenericItemSimilarity.ItemItemSimilarity(i, i + 1, i / 99.0));
+ }
+
+ List<GenericItemSimilarity.ItemItemSimilarity> res = TopItems.getTopItemItemSimilarities(10, sims.iterator());
+ int gold = 99;
+ for (GenericItemSimilarity.ItemItemSimilarity re : res) {
+ assertEquals(gold--, re.getItemID2()); //the second id should be equal to 99 to start
+ }
+ }
+
+ @Test
+ public void testTopItemItemAlt() throws Exception {
+ List<GenericItemSimilarity.ItemItemSimilarity> sims = Lists.newArrayList();
+ for (int i = 0; i < 99; i++) {
+ sims.add(new GenericItemSimilarity.ItemItemSimilarity(i, i + 1, 1 - (i / 99.0)));
+ }
+
+ List<GenericItemSimilarity.ItemItemSimilarity> res = TopItems.getTopItemItemSimilarities(10, sims.iterator());
+ int gold = 0;
+ for (GenericItemSimilarity.ItemItemSimilarity re : res) {
+ assertEquals(gold++, re.getItemID1()); //the second id should be equal to 99 to start
+ }
+ }
+
+ @Test
+ public void testTopUserUser() throws Exception {
+ List<GenericUserSimilarity.UserUserSimilarity> sims = Lists.newArrayList();
+ for (int i = 0; i < 99; i++) {
+ sims.add(new GenericUserSimilarity.UserUserSimilarity(i, i + 1, i / 99.0));
+ }
+
+ List<GenericUserSimilarity.UserUserSimilarity> res = TopItems.getTopUserUserSimilarities(10, sims.iterator());
+ int gold = 99;
+ for (GenericUserSimilarity.UserUserSimilarity re : res) {
+ assertEquals(gold--, re.getUserID2()); //the second id should be equal to 99 to start
+ }
+ }
+
+ @Test
+ public void testTopUserUserAlt() throws Exception {
+ List<GenericUserSimilarity.UserUserSimilarity> sims = Lists.newArrayList();
+ for (int i = 0; i < 99; i++) {
+ sims.add(new GenericUserSimilarity.UserUserSimilarity(i, i + 1, 1 - (i / 99.0)));
+ }
+
+ List<GenericUserSimilarity.UserUserSimilarity> res = TopItems.getTopUserUserSimilarities(10, sims.iterator());
+ int gold = 0;
+ for (GenericUserSimilarity.UserUserSimilarity re : res) {
+ assertEquals(gold++, re.getUserID1()); //the first id should be equal to 0 to start
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
new file mode 100644
index 0000000..23fa38f
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ALSWRFactorizerTest.java
@@ -0,0 +1,208 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.SparseRowMatrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+public class ALSWRFactorizerTest extends TasteTestCase {
+
+ private ALSWRFactorizer factorizer;
+ private DataModel dataModel;
+
+ private static final Logger log = LoggerFactory.getLogger(ALSWRFactorizerTest.class);
+
+ /**
+ * rating-matrix
+ *
+ * burger hotdog berries icecream
+ * dog 5 5 2 -
+ * rabbit 2 - 3 5
+ * cow - 5 - 3
+ * donkey 3 - - 5
+ */
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+
+ userData.put(1L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1L, 1L, 5.0f),
+ new GenericPreference(1L, 2L, 5.0f),
+ new GenericPreference(1L, 3L, 2.0f))));
+
+ userData.put(2L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2L, 1L, 2.0f),
+ new GenericPreference(2L, 3L, 3.0f),
+ new GenericPreference(2L, 4L, 5.0f))));
+
+ userData.put(3L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(3L, 2L, 5.0f),
+ new GenericPreference(3L, 4L, 3.0f))));
+
+ userData.put(4L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(4L, 1L, 3.0f),
+ new GenericPreference(4L, 4L, 5.0f))));
+
+ dataModel = new GenericDataModel(userData);
+ factorizer = new ALSWRFactorizer(dataModel, 3, 0.065, 10);
+ }
+
+ @Test
+ public void setFeatureColumn() throws Exception {
+ ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(factorizer);
+ Vector vector = new DenseVector(new double[] { 0.5, 2.0, 1.5 });
+ int index = 1;
+
+ features.setFeatureColumnInM(index, vector);
+ double[][] matrix = features.getM();
+
+ assertEquals(vector.get(0), matrix[index][0], EPSILON);
+ assertEquals(vector.get(1), matrix[index][1], EPSILON);
+ assertEquals(vector.get(2), matrix[index][2], EPSILON);
+ }
+
+ @Test
+ public void ratingVector() throws Exception {
+ PreferenceArray prefs = dataModel.getPreferencesFromUser(1);
+
+ Vector ratingVector = ALSWRFactorizer.ratingVector(prefs);
+
+ assertEquals(prefs.length(), ratingVector.getNumNondefaultElements());
+ assertEquals(prefs.get(0).getValue(), ratingVector.get(0), EPSILON);
+ assertEquals(prefs.get(1).getValue(), ratingVector.get(1), EPSILON);
+ assertEquals(prefs.get(2).getValue(), ratingVector.get(2), EPSILON);
+ }
+
+ @Test
+ public void averageRating() throws Exception {
+ ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(factorizer);
+ assertEquals(2.5, features.averateRating(3L), EPSILON);
+ }
+
+ @Test
+ public void initializeM() throws Exception {
+ ALSWRFactorizer.Features features = new ALSWRFactorizer.Features(factorizer);
+ double[][] M = features.getM();
+
+ assertEquals(3.333333333, M[0][0], EPSILON);
+ assertEquals(5, M[1][0], EPSILON);
+ assertEquals(2.5, M[2][0], EPSILON);
+ assertEquals(4.333333333, M[3][0], EPSILON);
+
+ for (int itemIndex = 0; itemIndex < dataModel.getNumItems(); itemIndex++) {
+ for (int feature = 1; feature < 3; feature++ ) {
+ assertTrue(M[itemIndex][feature] >= 0);
+ assertTrue(M[itemIndex][feature] <= 0.1);
+ }
+ }
+ }
+
+ @ThreadLeakLingering(linger = 10)
+ @Test
+ public void toyExample() throws Exception {
+
+ SVDRecommender svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+ /* a hold out test would be better, but this is just a toy example so we only check that the
+ * factorization is close to the original matrix */
+ RunningAverage avg = new FullRunningAverage();
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
+ double rating = pref.getValue();
+ double estimate = svdRecommender.estimatePreference(userID, pref.getItemID());
+ double err = rating - estimate;
+ avg.addDatum(err * err);
+ }
+ }
+
+ double rmse = Math.sqrt(avg.getAverage());
+ assertTrue(rmse < 0.2);
+ }
+
+ @Test
+ public void toyExampleImplicit() throws Exception {
+
+ Matrix observations = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 5.0, 5.0, 2.0, 0 }),
+ new DenseVector(new double[] { 2.0, 0, 3.0, 5.0 }),
+ new DenseVector(new double[] { 0, 5.0, 0, 3.0 }),
+ new DenseVector(new double[] { 3.0, 0, 0, 5.0 }) });
+
+ Matrix preferences = new SparseRowMatrix(4, 4, new Vector[] {
+ new DenseVector(new double[] { 1.0, 1.0, 1.0, 0 }),
+ new DenseVector(new double[] { 1.0, 0, 1.0, 1.0 }),
+ new DenseVector(new double[] { 0, 1.0, 0, 1.0 }),
+ new DenseVector(new double[] { 1.0, 0, 0, 1.0 }) });
+
+ double alpha = 20;
+
+ ALSWRFactorizer factorizer = new ALSWRFactorizer(dataModel, 3, 0.065, 5, true, alpha);
+
+ SVDRecommender svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+ RunningAverage avg = new FullRunningAverage();
+ Iterator<MatrixSlice> sliceIterator = preferences.iterateAll();
+ while (sliceIterator.hasNext()) {
+ MatrixSlice slice = sliceIterator.next();
+ for (Vector.Element e : slice.vector().all()) {
+
+ long userID = slice.index() + 1;
+ long itemID = e.index() + 1;
+
+ if (!Double.isNaN(e.get())) {
+ double pref = e.get();
+ double estimate = svdRecommender.estimatePreference(userID, itemID);
+
+ double confidence = 1 + alpha * observations.getQuick(slice.index(), e.index());
+ double err = confidence * (pref - estimate) * (pref - estimate);
+ avg.addDatum(err);
+ log.info("Comparing preference of user [{}] towards item [{}], was [{}] with confidence [{}] "
+ + "estimate is [{}]", slice.index(), e.index(), pref, confidence, estimate);
+ }
+ }
+ }
+ double rmse = Math.sqrt(avg.getAverage());
+ log.info("RMSE: {}", rmse);
+
+ assertTrue(rmse < 0.4);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategyTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategyTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategyTest.java
new file mode 100644
index 0000000..eb8a356
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/FilePersistenceStrategyTest.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.junit.Test;
+
+import java.io.File;
+
+public class FilePersistenceStrategyTest extends TasteTestCase {
+
+ @Test
+ public void persistAndLoad() throws Exception {
+ FastByIDMap<Integer> userIDMapping = new FastByIDMap<Integer>();
+ FastByIDMap<Integer> itemIDMapping = new FastByIDMap<Integer>();
+
+ userIDMapping.put(123, 0);
+ userIDMapping.put(456, 1);
+
+ itemIDMapping.put(12, 0);
+ itemIDMapping.put(34, 1);
+
+ double[][] userFeatures = { { 0.1, 0.2, 0.3 }, { 0.4, 0.5, 0.6 } };
+ double[][] itemFeatures = { { 0.7, 0.8, 0.9 }, { 1.0, 1.1, 1.2 } };
+
+ Factorization original = new Factorization(userIDMapping, itemIDMapping, userFeatures, itemFeatures);
+ File storage = getTestTempFile("storage.bin");
+ PersistenceStrategy persistenceStrategy = new FilePersistenceStrategy(storage);
+
+ assertNull(persistenceStrategy.load());
+
+ persistenceStrategy.maybePersist(original);
+ Factorization clone = persistenceStrategy.load();
+
+ assertEquals(original, clone);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizerTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizerTest.java
new file mode 100644
index 0000000..8a91e7a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/ParallelSGDFactorizerTest.java
@@ -0,0 +1,355 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import java.util.Arrays;
+import java.util.List;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+import com.google.common.collect.Lists;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastByIDMap;
+import org.apache.mahout.cf.taste.impl.common.FullRunningAverage;
+import org.apache.mahout.cf.taste.impl.common.LongPrimitiveIterator;
+import org.apache.mahout.cf.taste.impl.common.RunningAverage;
+import org.apache.mahout.cf.taste.impl.model.GenericDataModel;
+import org.apache.mahout.cf.taste.impl.model.GenericPreference;
+import org.apache.mahout.cf.taste.impl.model.GenericUserPreferenceArray;
+import org.apache.mahout.cf.taste.impl.recommender.svd.ParallelSGDFactorizer.PreferenceShuffler;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.Preference;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.VectorFunction;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ParallelSGDFactorizerTest extends TasteTestCase {
+
+ protected DataModel dataModel;
+
+ protected int rank;
+ protected double lambda;
+ protected int numIterations;
+
+ private RandomWrapper random = (RandomWrapper) RandomUtils.getRandom();
+
+ protected Factorizer factorizer;
+ protected SVDRecommender svdRecommender;
+
+ private static final Logger logger = LoggerFactory.getLogger(ParallelSGDFactorizerTest.class);
+
+ private Matrix randomMatrix(int numRows, int numColumns, double range) {
+ double[][] data = new double[numRows][numColumns];
+ for (int i = 0; i < numRows; i++) {
+ for (int j = 0; j < numColumns; j++) {
+ double sqrtUniform = random.nextDouble();
+ data[i][j] = sqrtUniform * range;
+ }
+ }
+ return new DenseMatrix(data);
+ }
+
+ private void normalize(Matrix source, final double range) {
+ final double max = source.aggregateColumns(new VectorFunction() {
+ @Override
+ public double apply(Vector column) {
+ return column.maxValue();
+ }
+ }).maxValue();
+
+ final double min = source.aggregateColumns(new VectorFunction() {
+ @Override
+ public double apply(Vector column) {
+ return column.minValue();
+ }
+ }).minValue();
+
+ source.assign(new DoubleFunction() {
+ @Override
+ public double apply(double value) {
+ return (value - min) * range / (max - min);
+ }
+ });
+ }
+
+ public void setUpSyntheticData() throws Exception {
+
+ int numUsers = 2000;
+ int numItems = 1000;
+ double sparsity = 0.5;
+
+ this.rank = 20;
+ this.lambda = 0.000000001;
+ this.numIterations = 100;
+
+ Matrix users = randomMatrix(numUsers, rank, 1);
+ Matrix items = randomMatrix(rank, numItems, 1);
+ Matrix ratings = users.times(items);
+ normalize(ratings, 5);
+
+ FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+ for (int userIndex = 0; userIndex < numUsers; userIndex++) {
+ List<Preference> row= Lists.newArrayList();
+ for (int itemIndex = 0; itemIndex < numItems; itemIndex++) {
+ if (random.nextDouble() <= sparsity) {
+ row.add(new GenericPreference(userIndex, itemIndex, (float) ratings.get(userIndex, itemIndex)));
+ }
+ }
+
+ userData.put(userIndex, new GenericUserPreferenceArray(row));
+ }
+
+ dataModel = new GenericDataModel(userData);
+ }
+
+ public void setUpToyData() throws Exception {
+ this.rank = 3;
+ this.lambda = 0.01;
+ this.numIterations = 1000;
+
+ FastByIDMap<PreferenceArray> userData = new FastByIDMap<PreferenceArray>();
+
+ userData.put(1L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(1L, 1L, 5.0f),
+ new GenericPreference(1L, 2L, 5.0f),
+ new GenericPreference(1L, 3L, 2.0f))));
+
+ userData.put(2L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(2L, 1L, 2.0f),
+ new GenericPreference(2L, 3L, 3.0f),
+ new GenericPreference(2L, 4L, 5.0f))));
+
+ userData.put(3L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(3L, 2L, 5.0f),
+ new GenericPreference(3L, 4L, 3.0f))));
+
+ userData.put(4L, new GenericUserPreferenceArray(Arrays.asList(new GenericPreference(4L, 1L, 3.0f),
+ new GenericPreference(4L, 4L, 5.0f))));
+ dataModel = new GenericDataModel(userData);
+ }
+
+ @Test
+ public void testPreferenceShufflerWithSyntheticData() throws Exception {
+ setUpSyntheticData();
+
+ ParallelSGDFactorizer.PreferenceShuffler shuffler = new PreferenceShuffler(dataModel);
+ shuffler.shuffle();
+ shuffler.stage();
+
+ FastByIDMap<FastByIDMap<Boolean>> checked = new FastByIDMap<FastByIDMap<Boolean>>();
+
+ for (int i = 0; i < shuffler.size(); i++) {
+ Preference pref=shuffler.get(i);
+
+ float value = dataModel.getPreferenceValue(pref.getUserID(), pref.getItemID());
+ assertEquals(pref.getValue(), value, 0.0);
+ if (!checked.containsKey(pref.getUserID())) {
+ checked.put(pref.getUserID(), new FastByIDMap<Boolean>());
+ }
+
+ assertNull(checked.get(pref.getUserID()).get(pref.getItemID()));
+
+ checked.get(pref.getUserID()).put(pref.getItemID(), true);
+ }
+
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ int index=0;
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ PreferenceArray preferencesFromUser = dataModel.getPreferencesFromUser(userID);
+ for (Preference preference : preferencesFromUser) {
+ assertTrue(checked.get(preference.getUserID()).get(preference.getItemID()));
+ index++;
+ }
+ }
+ assertEquals(index, shuffler.size());
+ }
+
+ @ThreadLeakLingering(linger = 1000)
+ @Test
+ public void testFactorizerWithToyData() throws Exception {
+
+ setUpToyData();
+
+ long start = System.currentTimeMillis();
+
+ factorizer = new ParallelSGDFactorizer(dataModel, rank, lambda, numIterations, 0.01, 1, 0, 0);
+
+ Factorization factorization = factorizer.factorize();
+
+ long duration = System.currentTimeMillis() - start;
+
+ /* a hold out test would be better, but this is just a toy example so we only check that the
+ * factorization is close to the original matrix */
+ RunningAverage avg = new FullRunningAverage();
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ LongPrimitiveIterator itemIDs;
+
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
+ double rating = pref.getValue();
+ Vector userVector = new DenseVector(factorization.getUserFeatures(userID));
+ Vector itemVector = new DenseVector(factorization.getItemFeatures(pref.getItemID()));
+ double estimate = userVector.dot(itemVector);
+ double err = rating - estimate;
+
+ avg.addDatum(err * err);
+ }
+ }
+
+ double sum = 0.0;
+
+ userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ Vector userVector = new DenseVector(factorization.getUserFeatures(userID));
+ double regularization = userVector.dot(userVector);
+ sum += regularization;
+ }
+
+ itemIDs = dataModel.getItemIDs();
+ while (itemIDs.hasNext()) {
+ long itemID = itemIDs.nextLong();
+ Vector itemVector = new DenseVector(factorization.getUserFeatures(itemID));
+ double regularization = itemVector.dot(itemVector);
+ sum += regularization;
+ }
+
+ double rmse = Math.sqrt(avg.getAverage());
+ double loss = avg.getAverage() / 2 + lambda / 2 * sum;
+ logger.info("RMSE: " + rmse + ";\tLoss: " + loss + ";\tTime Used: " + duration);
+ assertTrue(rmse < 0.2);
+ }
+
+ @ThreadLeakLingering(linger = 1000)
+ @Test
+ public void testRecommenderWithToyData() throws Exception {
+
+ setUpToyData();
+
+ factorizer = new ParallelSGDFactorizer(dataModel, rank, lambda, numIterations, 0.01, 1, 0,0);
+ svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+ /* a hold out test would be better, but this is just a toy example so we only check that the
+ * factorization is close to the original matrix */
+ RunningAverage avg = new FullRunningAverage();
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
+ double rating = pref.getValue();
+ double estimate = svdRecommender.estimatePreference(userID, pref.getItemID());
+ double err = rating - estimate;
+ avg.addDatum(err * err);
+ }
+ }
+
+ double rmse = Math.sqrt(avg.getAverage());
+ logger.info("rmse: " + rmse);
+ assertTrue(rmse < 0.2);
+ }
+
+ @Test
+ public void testFactorizerWithWithSyntheticData() throws Exception {
+
+ setUpSyntheticData();
+
+ long start = System.currentTimeMillis();
+
+ factorizer = new ParallelSGDFactorizer(dataModel, rank, lambda, numIterations, 0.01, 1, 0, 0);
+
+ Factorization factorization = factorizer.factorize();
+
+ long duration = System.currentTimeMillis() - start;
+
+ /* a hold out test would be better, but this is just a toy example so we only check that the
+ * factorization is close to the original matrix */
+ RunningAverage avg = new FullRunningAverage();
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ LongPrimitiveIterator itemIDs;
+
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
+ double rating = pref.getValue();
+ Vector userVector = new DenseVector(factorization.getUserFeatures(userID));
+ Vector itemVector = new DenseVector(factorization.getItemFeatures(pref.getItemID()));
+ double estimate = userVector.dot(itemVector);
+ double err = rating - estimate;
+
+ avg.addDatum(err * err);
+ }
+ }
+
+ double sum = 0.0;
+
+ userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ Vector userVector = new DenseVector(factorization.getUserFeatures(userID));
+ double regularization=userVector.dot(userVector);
+ sum += regularization;
+ }
+
+ itemIDs = dataModel.getItemIDs();
+ while (itemIDs.hasNext()) {
+ long itemID = itemIDs.nextLong();
+ Vector itemVector = new DenseVector(factorization.getUserFeatures(itemID));
+ double regularization = itemVector.dot(itemVector);
+ sum += regularization;
+ }
+
+ double rmse = Math.sqrt(avg.getAverage());
+ double loss = avg.getAverage() / 2 + lambda / 2 * sum;
+ logger.info("RMSE: " + rmse + ";\tLoss: " + loss + ";\tTime Used: " + duration + "ms");
+ assertTrue(rmse < 0.2);
+ }
+
+ @Test
+ public void testRecommenderWithSyntheticData() throws Exception {
+
+ setUpSyntheticData();
+
+ factorizer= new ParallelSGDFactorizer(dataModel, rank, lambda, numIterations, 0.01, 1, 0, 0);
+ svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+ /* a hold out test would be better, but this is just a toy example so we only check that the
+ * factorization is close to the original matrix */
+ RunningAverage avg = new FullRunningAverage();
+ LongPrimitiveIterator userIDs = dataModel.getUserIDs();
+ while (userIDs.hasNext()) {
+ long userID = userIDs.nextLong();
+ for (Preference pref : dataModel.getPreferencesFromUser(userID)) {
+ double rating = pref.getValue();
+ double estimate = svdRecommender.estimatePreference(userID, pref.getItemID());
+ double err = rating - estimate;
+ avg.addDatum(err * err);
+ }
+ }
+
+ double rmse = Math.sqrt(avg.getAverage());
+ logger.info("rmse: " + rmse);
+ assertTrue(rmse < 0.2);
+ }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
new file mode 100644
index 0000000..aebd324
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/recommender/svd/SVDRecommenderTest.java
@@ -0,0 +1,86 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.recommender.svd;
+
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.impl.common.FastIDSet;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.model.PreferenceArray;
+import org.apache.mahout.cf.taste.recommender.CandidateItemsStrategy;
+import org.apache.mahout.cf.taste.recommender.RecommendedItem;
+import org.easymock.EasyMock;
+import org.junit.Test;
+
+import java.util.List;
+
+public class SVDRecommenderTest extends TasteTestCase {
+
+ @Test
+ public void estimatePreference() throws Exception {
+ DataModel dataModel = EasyMock.createMock(DataModel.class);
+ Factorizer factorizer = EasyMock.createMock(Factorizer.class);
+ Factorization factorization = EasyMock.createMock(Factorization.class);
+
+ EasyMock.expect(factorizer.factorize()).andReturn(factorization);
+ EasyMock.expect(factorization.getUserFeatures(1L)).andReturn(new double[] { 0.4, 2 });
+ EasyMock.expect(factorization.getItemFeatures(5L)).andReturn(new double[] { 1, 0.3 });
+ EasyMock.replay(dataModel, factorizer, factorization);
+
+ SVDRecommender svdRecommender = new SVDRecommender(dataModel, factorizer);
+
+ float estimate = svdRecommender.estimatePreference(1L, 5L);
+ assertEquals(1, estimate, EPSILON);
+
+ EasyMock.verify(dataModel, factorizer, factorization);
+ }
+
+ @Test
+ public void recommend() throws Exception {
+ DataModel dataModel = EasyMock.createMock(DataModel.class);
+ PreferenceArray preferencesFromUser = EasyMock.createMock(PreferenceArray.class);
+ CandidateItemsStrategy candidateItemsStrategy = EasyMock.createMock(CandidateItemsStrategy.class);
+ Factorizer factorizer = EasyMock.createMock(Factorizer.class);
+ Factorization factorization = EasyMock.createMock(Factorization.class);
+
+ FastIDSet candidateItems = new FastIDSet();
+ candidateItems.add(5L);
+ candidateItems.add(3L);
+
+ EasyMock.expect(factorizer.factorize()).andReturn(factorization);
+ EasyMock.expect(dataModel.getPreferencesFromUser(1L)).andReturn(preferencesFromUser);
+ EasyMock.expect(candidateItemsStrategy.getCandidateItems(1L, preferencesFromUser, dataModel, false))
+ .andReturn(candidateItems);
+ EasyMock.expect(factorization.getUserFeatures(1L)).andReturn(new double[] { 0.4, 2 });
+ EasyMock.expect(factorization.getItemFeatures(5L)).andReturn(new double[] { 1, 0.3 });
+ EasyMock.expect(factorization.getUserFeatures(1L)).andReturn(new double[] { 0.4, 2 });
+ EasyMock.expect(factorization.getItemFeatures(3L)).andReturn(new double[] { 2, 0.6 });
+
+ EasyMock.replay(dataModel, candidateItemsStrategy, factorizer, factorization);
+
+ SVDRecommender svdRecommender = new SVDRecommender(dataModel, factorizer, candidateItemsStrategy);
+
+ List<RecommendedItem> recommendedItems = svdRecommender.recommend(1L, 5);
+ assertEquals(2, recommendedItems.size());
+ assertEquals(3L, recommendedItems.get(0).getItemID());
+ assertEquals(2.0f, recommendedItems.get(0).getValue(), EPSILON);
+ assertEquals(5L, recommendedItems.get(1).getItemID());
+ assertEquals(1.0f, recommendedItems.get(1).getValue(), EPSILON);
+
+ EasyMock.verify(dataModel, candidateItemsStrategy, factorizer, factorization);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrerTest.java b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrerTest.java
new file mode 100644
index 0000000..d8242e3
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/cf/taste/impl/similarity/AveragingPreferenceInferrerTest.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.cf.taste.impl.similarity;
+
+import org.apache.mahout.cf.taste.common.TasteException;
+import org.apache.mahout.cf.taste.impl.TasteTestCase;
+import org.apache.mahout.cf.taste.model.DataModel;
+import org.apache.mahout.cf.taste.similarity.PreferenceInferrer;
+import org.junit.Test;
+
+/** <p>Tests {@link AveragingPreferenceInferrer}.</p> */
+public final class AveragingPreferenceInferrerTest extends TasteTestCase {
+
+ @Test
+ public void testInferrer() throws TasteException {
+ DataModel model = getDataModel(new long[] {1}, new Double[][] {{3.0,-2.0,5.0}});
+ PreferenceInferrer inferrer = new AveragingPreferenceInferrer(model);
+ double inferred = inferrer.inferPreference(1, 3);
+ assertEquals(2.0, inferred, EPSILON);
+ }
+
+}
[19/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java b/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java
new file mode 100644
index 0000000..dd38971
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/solver/DistributedConjugateGradientSolver.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.solver;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.util.Tool;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.DistributedRowMatrix;
+import org.apache.mahout.math.solver.ConjugateGradientSolver;
+import org.apache.mahout.math.solver.Preconditioner;
+
+/**
+ * Distributed implementation of the conjugate gradient solver. More or less, this is just the standard solver
+ * but wrapped with some methods that make it easy to run it on a DistributedRowMatrix.
+ */
+public class DistributedConjugateGradientSolver extends ConjugateGradientSolver implements Tool {
+
+ private Configuration conf;
+ private Map<String, List<String>> parsedArgs;
+
+ /**
+ *
+ * Runs the distributed conjugate gradient solver programmatically to solve the system (A + lambda*I)x = b.
+ *
+ * @param inputPath Path to the matrix A
+ * @param tempPath Path to scratch output path, deleted after the solver completes
+ * @param numRows Number of rows in A
+ * @param numCols Number of columns in A
+ * @param b Vector b
+ * @param preconditioner Optional preconditioner for the system
+ * @param maxIterations Maximum number of iterations to run, defaults to numCols
+ * @param maxError Maximum error tolerated in the result. If the norm of the residual falls below this,
+ * then the algorithm stops and returns.
+ * @return The vector that solves the system.
+ */
+ public Vector runJob(Path inputPath,
+ Path tempPath,
+ int numRows,
+ int numCols,
+ Vector b,
+ Preconditioner preconditioner,
+ int maxIterations,
+ double maxError) {
+ DistributedRowMatrix matrix = new DistributedRowMatrix(inputPath, tempPath, numRows, numCols);
+ matrix.setConf(conf);
+
+ return solve(matrix, b, preconditioner, maxIterations, maxError);
+ }
+
+ @Override
+ public Configuration getConf() {
+ return conf;
+ }
+
+ @Override
+ public void setConf(Configuration conf) {
+ this.conf = conf;
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ Path inputPath = new Path(AbstractJob.getOption(parsedArgs, "--input"));
+ Path outputPath = new Path(AbstractJob.getOption(parsedArgs, "--output"));
+ Path tempPath = new Path(AbstractJob.getOption(parsedArgs, "--tempDir"));
+ Path vectorPath = new Path(AbstractJob.getOption(parsedArgs, "--vector"));
+ int numRows = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numRows"));
+ int numCols = Integer.parseInt(AbstractJob.getOption(parsedArgs, "--numCols"));
+ int maxIterations = parsedArgs.containsKey("--maxIter")
+ ? Integer.parseInt(AbstractJob.getOption(parsedArgs, "--maxIter"))
+ : numCols + 2;
+ double maxError = parsedArgs.containsKey("--maxError")
+ ? Double.parseDouble(AbstractJob.getOption(parsedArgs, "--maxError"))
+ : ConjugateGradientSolver.DEFAULT_MAX_ERROR;
+
+ Vector b = loadInputVector(vectorPath);
+ Vector x = runJob(inputPath, tempPath, numRows, numCols, b, null, maxIterations, maxError);
+ saveOutputVector(outputPath, x);
+ tempPath.getFileSystem(conf).delete(tempPath, true);
+
+ return 0;
+ }
+
+ public DistributedConjugateGradientSolverJob job() {
+ return new DistributedConjugateGradientSolverJob();
+ }
+
+ private Vector loadInputVector(Path path) throws IOException {
+ FileSystem fs = path.getFileSystem(conf);
+ try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, conf)) {
+ VectorWritable value = new VectorWritable();
+ if (!reader.next(new IntWritable(), value)) {
+ throw new IOException("Input vector file is empty.");
+ }
+ return value.get();
+ }
+ }
+
+ private void saveOutputVector(Path path, Vector v) throws IOException {
+ FileSystem fs = path.getFileSystem(conf);
+ try (SequenceFile.Writer writer =
+ new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class)) {
+ writer.append(new IntWritable(0), new VectorWritable(v));
+ }
+ }
+
+ public class DistributedConjugateGradientSolverJob extends AbstractJob {
+ @Override
+ public void setConf(Configuration conf) {
+ DistributedConjugateGradientSolver.this.setConf(conf);
+ }
+
+ @Override
+ public Configuration getConf() {
+ return DistributedConjugateGradientSolver.this.getConf();
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+ addInputOption();
+ addOutputOption();
+ addOption("numRows", "nr", "Number of rows in the input matrix", true);
+ addOption("numCols", "nc", "Number of columns in the input matrix", true);
+ addOption("vector", "b", "Vector to solve against", true);
+ addOption("lambda", "l", "Scalar in A + lambda * I [default = 0]", "0.0");
+ addOption("symmetric", "sym", "Is the input matrix square and symmetric?", "true");
+ addOption("maxIter", "x", "Maximum number of iterations to run");
+ addOption("maxError", "err", "Maximum residual error to allow before stopping");
+
+ DistributedConjugateGradientSolver.this.parsedArgs = parseArguments(args);
+ if (DistributedConjugateGradientSolver.this.parsedArgs == null) {
+ return -1;
+ } else {
+ Configuration conf = getConf();
+ if (conf == null) {
+ conf = new Configuration();
+ }
+ DistributedConjugateGradientSolver.this.setConf(conf);
+ return DistributedConjugateGradientSolver.this.run(args);
+ }
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new DistributedConjugateGradientSolver().job(), args);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java
new file mode 100644
index 0000000..ad0baf3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/BasicStats.java
@@ -0,0 +1,148 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stats;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable;
+
+import java.io.IOException;
+
+/**
+ * Methods for calculating basic stats (mean, variance, stdDev, etc.) in map/reduce
+ */
+public final class BasicStats {
+
+ private BasicStats() {
+ }
+
+ /**
+ * Calculate the variance of values stored as
+ *
+ * @param input The input file containing the key and the count
+ * @param output The output to store the intermediate values
+ * @param baseConf
+ * @return The variance (based on sample estimation)
+ */
+ public static double variance(Path input, Path output,
+ Configuration baseConf)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ VarianceTotals varianceTotals = computeVarianceTotals(input, output, baseConf);
+ return varianceTotals.computeVariance();
+ }
+
+ /**
+ * Calculate the variance by a predefined mean of values stored as
+ *
+ * @param input The input file containing the key and the count
+ * @param output The output to store the intermediate values
+ * @param mean The mean based on which to compute the variance
+ * @param baseConf
+ * @return The variance (based on sample estimation)
+ */
+ public static double varianceForGivenMean(Path input, Path output, double mean,
+ Configuration baseConf)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ VarianceTotals varianceTotals = computeVarianceTotals(input, output, baseConf);
+ return varianceTotals.computeVarianceForGivenMean(mean);
+ }
+
+ private static VarianceTotals computeVarianceTotals(Path input, Path output,
+ Configuration baseConf) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ Configuration conf = new Configuration(baseConf);
+ conf.set("io.serializations",
+ "org.apache.hadoop.io.serializer.JavaSerialization,"
+ + "org.apache.hadoop.io.serializer.WritableSerialization");
+ Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class,
+ StandardDeviationCalculatorMapper.class, IntWritable.class, DoubleWritable.class,
+ StandardDeviationCalculatorReducer.class, IntWritable.class, DoubleWritable.class,
+ SequenceFileOutputFormat.class, conf);
+ HadoopUtil.delete(conf, output);
+ job.setCombinerClass(StandardDeviationCalculatorReducer.class);
+ boolean succeeded = job.waitForCompletion(true);
+ if (!succeeded) {
+ throw new IllegalStateException("Job failed!");
+ }
+
+ // Now extract the computed sum
+ Path filesPattern = new Path(output, "part-*");
+ double sumOfSquares = 0;
+ double sum = 0;
+ double totalCount = 0;
+ for (Pair<Writable, Writable> record : new SequenceFileDirIterable<>(
+ filesPattern, PathType.GLOB, null, null, true, conf)) {
+
+ int key = ((IntWritable) record.getFirst()).get();
+ if (key == StandardDeviationCalculatorMapper.SUM_OF_SQUARES.get()) {
+ sumOfSquares += ((DoubleWritable) record.getSecond()).get();
+ } else if (key == StandardDeviationCalculatorMapper.TOTAL_COUNT
+ .get()) {
+ totalCount += ((DoubleWritable) record.getSecond()).get();
+ } else if (key == StandardDeviationCalculatorMapper.SUM
+ .get()) {
+ sum += ((DoubleWritable) record.getSecond()).get();
+ }
+ }
+
+ VarianceTotals varianceTotals = new VarianceTotals();
+ varianceTotals.setSum(sum);
+ varianceTotals.setSumOfSquares(sumOfSquares);
+ varianceTotals.setTotalCount(totalCount);
+
+ return varianceTotals;
+ }
+
+ /**
+ * Calculate the standard deviation
+ *
+ * @param input The input file containing the key and the count
+ * @param output The output file to write the counting results to
+ * @param baseConf The base configuration
+ * @return The standard deviation
+ */
+ public static double stdDev(Path input, Path output,
+ Configuration baseConf) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ return Math.sqrt(variance(input, output, baseConf));
+ }
+
+ /**
+ * Calculate the standard deviation given a predefined mean
+ *
+ * @param input The input file containing the key and the count
+ * @param output The output file to write the counting results to
+ * @param mean The mean based on which to compute the standard deviation
+ * @param baseConf The base configuration
+ * @return The standard deviation
+ */
+ public static double stdDevForGivenMean(Path input, Path output, double mean,
+ Configuration baseConf) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ return Math.sqrt(varianceForGivenMean(input, output, mean, baseConf));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java
new file mode 100644
index 0000000..03271da
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorMapper.java
@@ -0,0 +1,55 @@
+package org.apache.mahout.math.hadoop.stats;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+
+import java.io.IOException;
+
+public class StandardDeviationCalculatorMapper extends
+ Mapper<IntWritable, Writable, IntWritable, DoubleWritable> {
+
+ public static final IntWritable SUM_OF_SQUARES = new IntWritable(1);
+ public static final IntWritable SUM = new IntWritable(2);
+ public static final IntWritable TOTAL_COUNT = new IntWritable(-1);
+
+ @Override
+ protected void map(IntWritable key, Writable value, Context context)
+ throws IOException, InterruptedException {
+ if (key.get() == -1) {
+ return;
+ }
+ //Kind of ugly, but such is life
+ double df = Double.NaN;
+ if (value instanceof LongWritable) {
+ df = ((LongWritable)value).get();
+ } else if (value instanceof DoubleWritable) {
+ df = ((DoubleWritable)value).get();
+ }
+ if (!Double.isNaN(df)) {
+ // For calculating the sum of squares
+ context.write(SUM_OF_SQUARES, new DoubleWritable(df * df));
+ context.write(SUM, new DoubleWritable(df));
+ // For calculating the total number of entries
+ context.write(TOTAL_COUNT, new DoubleWritable(1));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java
new file mode 100644
index 0000000..0a27eec
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/StandardDeviationCalculatorReducer.java
@@ -0,0 +1,37 @@
+package org.apache.mahout.math.hadoop.stats;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+
+import java.io.IOException;
+
+public class StandardDeviationCalculatorReducer extends
+ Reducer<IntWritable, DoubleWritable, IntWritable, DoubleWritable> {
+
+ @Override
+ protected void reduce(IntWritable key, Iterable<DoubleWritable> values,
+ Context context) throws IOException, InterruptedException {
+ double sum = 0.0;
+ for (DoubleWritable value : values) {
+ sum += value.get();
+ }
+ context.write(key, new DoubleWritable(sum));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java
new file mode 100644
index 0000000..87448bc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stats/VarianceTotals.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stats;
+
+/**
+ * Holds the total values needed to compute mean and standard deviation
+ * Provides methods for their computation
+ */
+public final class VarianceTotals {
+
+ private double sumOfSquares;
+ private double sum;
+ private double totalCount;
+
+ public double getSumOfSquares() {
+ return sumOfSquares;
+ }
+
+ public void setSumOfSquares(double sumOfSquares) {
+ this.sumOfSquares = sumOfSquares;
+ }
+
+ public double getSum() {
+ return sum;
+ }
+
+ public void setSum(double sum) {
+ this.sum = sum;
+ }
+
+ public double getTotalCount() {
+ return totalCount;
+ }
+
+ public void setTotalCount(double totalCount) {
+ this.totalCount = totalCount;
+ }
+
+ public double computeMean() {
+ return sum / totalCount;
+ }
+
+ public double computeVariance() {
+ return ((totalCount * sumOfSquares) - (sum * sum))
+ / (totalCount * (totalCount - 1.0));
+ }
+
+ public double computeVarianceForGivenMean(double mean) {
+ return (sumOfSquares - totalCount * mean * mean)
+ / (totalCount - 1.0);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java
new file mode 100644
index 0000000..359b281
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtDenseOutJob.java
@@ -0,0 +1,585 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.text.NumberFormat;
+import java.util.ArrayDeque;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.regex.Matcher;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.lang3.Validate;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep;
+
+/**
+ * Computes ABt products, then first step of QR which is pushed down to the
+ * reducer.
+ */
+@SuppressWarnings("deprecation")
+public final class ABtDenseOutJob {
+
+ public static final String PROP_BT_PATH = "ssvd.Bt.path";
+ public static final String PROP_BT_BROADCAST = "ssvd.Bt.broadcast";
+ public static final String PROP_SB_PATH = "ssvdpca.sb.path";
+ public static final String PROP_SQ_PATH = "ssvdpca.sq.path";
+ public static final String PROP_XI_PATH = "ssvdpca.xi.path";
+
+ private ABtDenseOutJob() {
+ }
+
+ /**
+ * So, here, i preload A block into memory.
+ * <P>
+ *
+ * A sparse matrix seems to be ideal for that but there are two reasons why i
+ * am not using it:
+ * <UL>
+ * <LI>1) I don't know the full block height. so i may need to reallocate it
+ * from time to time. Although this probably not a showstopper.
+ * <LI>2) I found that RandomAccessSparseVectors seem to take much more memory
+ * than the SequentialAccessSparseVectors.
+ * </UL>
+ * <P>
+ *
+ */
+ public static class ABtMapper
+ extends
+ Mapper<Writable, VectorWritable, SplitPartitionedWritable, DenseBlockWritable> {
+
+ private SplitPartitionedWritable outKey;
+ private final Deque<Closeable> closeables = new ArrayDeque<>();
+ private SequenceFileDirIterator<IntWritable, VectorWritable> btInput;
+ private Vector[] aCols;
+ private double[][] yiCols;
+ private int aRowCount;
+ private int kp;
+ private int blockHeight;
+ private boolean distributedBt;
+ private Path[] btLocalPath;
+ private Configuration localFsConfig;
+ /*
+ * xi and s_q are PCA-related corrections, per MAHOUT-817
+ */
+ protected Vector xi;
+ protected Vector sq;
+
+ @Override
+ protected void map(Writable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+
+ Vector vec = value.get();
+
+ int vecSize = vec.size();
+ if (aCols == null) {
+ aCols = new Vector[vecSize];
+ } else if (aCols.length < vecSize) {
+ aCols = Arrays.copyOf(aCols, vecSize);
+ }
+
+ if (vec.isDense()) {
+ for (int i = 0; i < vecSize; i++) {
+ extendAColIfNeeded(i, aRowCount + 1);
+ aCols[i].setQuick(aRowCount, vec.getQuick(i));
+ }
+ } else if (vec.size() > 0) {
+ for (Vector.Element vecEl : vec.nonZeroes()) {
+ int i = vecEl.index();
+ extendAColIfNeeded(i, aRowCount + 1);
+ aCols[i].setQuick(aRowCount, vecEl.get());
+ }
+ }
+ aRowCount++;
+ }
+
+ private void extendAColIfNeeded(int col, int rowCount) {
+ if (aCols[col] == null) {
+ aCols[col] =
+ new SequentialAccessSparseVector(rowCount < blockHeight ? blockHeight
+ : rowCount, 1);
+ } else if (aCols[col].size() < rowCount) {
+ Vector newVec =
+ new SequentialAccessSparseVector(rowCount + blockHeight,
+ aCols[col].getNumNondefaultElements() << 1);
+ newVec.viewPart(0, aCols[col].size()).assign(aCols[col]);
+ aCols[col] = newVec;
+ }
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ try {
+
+ yiCols = new double[kp][];
+
+ for (int i = 0; i < kp; i++) {
+ yiCols[i] = new double[Math.min(aRowCount, blockHeight)];
+ }
+
+ int numPasses = (aRowCount - 1) / blockHeight + 1;
+
+ String propBtPathStr = context.getConfiguration().get(PROP_BT_PATH);
+ Validate.notNull(propBtPathStr, "Bt input is not set");
+ Path btPath = new Path(propBtPathStr);
+ DenseBlockWritable dbw = new DenseBlockWritable();
+
+ /*
+ * so it turns out that it may be much more efficient to do a few
+ * independent passes over Bt accumulating the entire block in memory
+ * than pass huge amount of blocks out to combiner. so we aim of course
+ * to fit entire s x (k+p) dense block in memory where s is the number
+ * of A rows in this split. If A is much sparser than (k+p) avg # of
+ * elements per row then the block may exceed the split size. if this
+ * happens, and if the given blockHeight is not high enough to
+ * accomodate this (because of memory constraints), then we start
+ * splitting s into several passes. since computation is cpu-bound
+ * anyway, it should be o.k. for supersparse inputs. (as ok it can be
+ * that projection is thicker than the original anyway, why would one
+ * use that many k+p then).
+ */
+ int lastRowIndex = -1;
+ for (int pass = 0; pass < numPasses; pass++) {
+
+ if (distributedBt) {
+
+ btInput =
+ new SequenceFileDirIterator<>(btLocalPath, true, localFsConfig);
+
+ } else {
+
+ btInput =
+ new SequenceFileDirIterator<>(btPath, PathType.GLOB, null, null, true, context.getConfiguration());
+ }
+ closeables.addFirst(btInput);
+ Validate.isTrue(btInput.hasNext(), "Empty B' input!");
+
+ int aRowBegin = pass * blockHeight;
+ int bh = Math.min(blockHeight, aRowCount - aRowBegin);
+
+ /*
+ * check if we need to trim block allocation
+ */
+ if (pass > 0) {
+ if (bh == blockHeight) {
+ for (int i = 0; i < kp; i++) {
+ Arrays.fill(yiCols[i], 0.0);
+ }
+ } else {
+
+ for (int i = 0; i < kp; i++) {
+ yiCols[i] = null;
+ }
+ for (int i = 0; i < kp; i++) {
+ yiCols[i] = new double[bh];
+ }
+ }
+ }
+
+ while (btInput.hasNext()) {
+ Pair<IntWritable, VectorWritable> btRec = btInput.next();
+ int btIndex = btRec.getFirst().get();
+ Vector btVec = btRec.getSecond().get();
+ Vector aCol;
+ if (btIndex > aCols.length || (aCol = aCols[btIndex]) == null
+ || aCol.size() == 0) {
+
+ /* 100% zero A column in the block, skip it as sparse */
+ continue;
+ }
+ int j = -1;
+ for (Vector.Element aEl : aCol.nonZeroes()) {
+ j = aEl.index();
+
+ /*
+ * now we compute only swathes between aRowBegin..aRowBegin+bh
+ * exclusive. it seems like a deficiency but in fact i think it
+ * will balance itself out: either A is dense and then we
+ * shouldn't have more than one pass and therefore filter
+ * conditions will never kick in. Or, the only situation where we
+ * can't fit Y_i block in memory is when A input is much sparser
+ * than k+p per row. But if this is the case, then we'd be looking
+ * at very few elements without engaging them in any operations so
+ * even then it should be ok.
+ */
+ if (j < aRowBegin) {
+ continue;
+ }
+ if (j >= aRowBegin + bh) {
+ break;
+ }
+
+ /*
+ * assume btVec is dense
+ */
+ if (xi != null) {
+ /*
+ * MAHOUT-817: PCA correction for B'. I rewrite the whole
+ * computation loop so i don't have to check if PCA correction
+ * is needed at individual element level. It looks bulkier this
+ * way but perhaps less wasteful on cpu.
+ */
+ for (int s = 0; s < kp; s++) {
+ // code defensively against shortened xi
+ double xii = xi.size() > btIndex ? xi.get(btIndex) : 0.0;
+ yiCols[s][j - aRowBegin] +=
+ aEl.get() * (btVec.getQuick(s) - xii * sq.get(s));
+ }
+ } else {
+ /*
+ * no PCA correction
+ */
+ for (int s = 0; s < kp; s++) {
+ yiCols[s][j - aRowBegin] += aEl.get() * btVec.getQuick(s);
+ }
+ }
+
+ }
+ if (lastRowIndex < j) {
+ lastRowIndex = j;
+ }
+ }
+
+ /*
+ * so now we have stuff in yi
+ */
+ dbw.setBlock(yiCols);
+ outKey.setTaskItemOrdinal(pass);
+ context.write(outKey, dbw);
+
+ closeables.remove(btInput);
+ btInput.close();
+ }
+
+ } finally {
+ IOUtils.close(closeables);
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+
+ Configuration conf = context.getConfiguration();
+ int k = Integer.parseInt(conf.get(QRFirstStep.PROP_K));
+ int p = Integer.parseInt(conf.get(QRFirstStep.PROP_P));
+ kp = k + p;
+
+ outKey = new SplitPartitionedWritable(context);
+
+ blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
+ distributedBt = conf.get(PROP_BT_BROADCAST) != null;
+ if (distributedBt) {
+ btLocalPath = HadoopUtil.getCachedFiles(conf);
+ localFsConfig = new Configuration();
+ localFsConfig.set("fs.default.name", "file:///");
+ }
+
+ /*
+ * PCA -related corrections (MAHOUT-817)
+ */
+ String xiPathStr = conf.get(PROP_XI_PATH);
+ if (xiPathStr != null) {
+ xi = SSVDHelper.loadAndSumUpVectors(new Path(xiPathStr), conf);
+ sq =
+ SSVDHelper.loadAndSumUpVectors(new Path(conf.get(PROP_SQ_PATH)), conf);
+ }
+
+ }
+ }
+
+ /**
+ * QR first step pushed down to reducer.
+ *
+ */
+ public static class QRReducer
+ extends Reducer<SplitPartitionedWritable, DenseBlockWritable, SplitPartitionedWritable, VectorWritable> {
+
+ /*
+ * HACK: partition number formats in hadoop, copied. this may stop working
+ * if it gets out of sync with newer hadoop version. But unfortunately rules
+ * of forming output file names are not sufficiently exposed so we need to
+ * hack it if we write the same split output from either mapper or reducer.
+ * alternatively, we probably can replace it by our own output file naming
+ * management completely and bypass MultipleOutputs entirely.
+ */
+
+ private static final NumberFormat NUMBER_FORMAT =
+ NumberFormat.getInstance();
+ static {
+ NUMBER_FORMAT.setMinimumIntegerDigits(5);
+ NUMBER_FORMAT.setGroupingUsed(false);
+ }
+
+ private final Deque<Closeable> closeables = Lists.newLinkedList();
+
+ protected int blockHeight;
+
+ protected int lastTaskId = -1;
+
+ protected OutputCollector<Writable, DenseBlockWritable> qhatCollector;
+ protected OutputCollector<Writable, VectorWritable> rhatCollector;
+ protected QRFirstStep qr;
+ protected Vector yiRow;
+ protected Vector sb;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ Configuration conf = context.getConfiguration();
+ blockHeight = conf.getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
+ String sbPathStr = conf.get(PROP_SB_PATH);
+
+ /*
+ * PCA -related corrections (MAHOUT-817)
+ */
+ if (sbPathStr != null) {
+ sb = SSVDHelper.loadAndSumUpVectors(new Path(sbPathStr), conf);
+ }
+ }
+
+ protected void setupBlock(Context context, SplitPartitionedWritable spw)
+ throws InterruptedException, IOException {
+ IOUtils.close(closeables);
+ qhatCollector =
+ createOutputCollector(QJob.OUTPUT_QHAT,
+ spw,
+ context,
+ DenseBlockWritable.class);
+ rhatCollector =
+ createOutputCollector(QJob.OUTPUT_RHAT,
+ spw,
+ context,
+ VectorWritable.class);
+ qr =
+ new QRFirstStep(context.getConfiguration(),
+ qhatCollector,
+ rhatCollector);
+ closeables.addFirst(qr);
+ lastTaskId = spw.getTaskId();
+
+ }
+
+ @Override
+ protected void reduce(SplitPartitionedWritable key,
+ Iterable<DenseBlockWritable> values,
+ Context context) throws IOException,
+ InterruptedException {
+
+ if (key.getTaskId() != lastTaskId) {
+ setupBlock(context, key);
+ }
+
+ Iterator<DenseBlockWritable> iter = values.iterator();
+ DenseBlockWritable dbw = iter.next();
+ double[][] yiCols = dbw.getBlock();
+ if (iter.hasNext()) {
+ throw new IOException("Unexpected extra Y_i block in reducer input.");
+ }
+
+ long blockBase = key.getTaskItemOrdinal() * blockHeight;
+ int bh = yiCols[0].length;
+ if (yiRow == null) {
+ yiRow = new DenseVector(yiCols.length);
+ }
+
+ for (int k = 0; k < bh; k++) {
+ for (int j = 0; j < yiCols.length; j++) {
+ yiRow.setQuick(j, yiCols[j][k]);
+ }
+
+ key.setTaskItemOrdinal(blockBase + k);
+
+ // pca offset correction if any
+ if (sb != null) {
+ yiRow.assign(sb, Functions.MINUS);
+ }
+
+ qr.collect(key, yiRow);
+ }
+
+ }
+
+ private Path getSplitFilePath(String name,
+ SplitPartitionedWritable spw,
+ Context context) throws InterruptedException,
+ IOException {
+ String uniqueFileName = FileOutputFormat.getUniqueFile(context, name, "");
+ uniqueFileName = uniqueFileName.replaceFirst("-r-", "-m-");
+ uniqueFileName =
+ uniqueFileName.replaceFirst("\\d+$",
+ Matcher.quoteReplacement(NUMBER_FORMAT.format(spw.getTaskId())));
+ return new Path(FileOutputFormat.getWorkOutputPath(context),
+ uniqueFileName);
+ }
+
+ /**
+ * key doesn't matter here, only value does. key always gets substituted by
+ * SPW.
+ *
+ * @param <K>
+ * bogus
+ */
+ private <K, V> OutputCollector<K, V> createOutputCollector(String name,
+ final SplitPartitionedWritable spw,
+ Context ctx,
+ Class<V> valueClass) throws IOException, InterruptedException {
+ Path outputPath = getSplitFilePath(name, spw, ctx);
+ final SequenceFile.Writer w =
+ SequenceFile.createWriter(FileSystem.get(outputPath.toUri(), ctx.getConfiguration()),
+ ctx.getConfiguration(),
+ outputPath,
+ SplitPartitionedWritable.class,
+ valueClass);
+ closeables.addFirst(w);
+ return new OutputCollector<K, V>() {
+ @Override
+ public void collect(K key, V val) throws IOException {
+ w.append(spw, val);
+ }
+ };
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+
+ IOUtils.close(closeables);
+ }
+
+ }
+
+ public static void run(Configuration conf,
+ Path[] inputAPaths,
+ Path inputBtGlob,
+ Path xiPath,
+ Path sqPath,
+ Path sbPath,
+ Path outputPath,
+ int aBlockRows,
+ int minSplitSize,
+ int k,
+ int p,
+ int outerProdBlockHeight,
+ int numReduceTasks,
+ boolean broadcastBInput)
+ throws ClassNotFoundException, InterruptedException, IOException {
+
+ JobConf oldApiJob = new JobConf(conf);
+
+ Job job = new Job(oldApiJob);
+ job.setJobName("ABt-job");
+ job.setJarByClass(ABtDenseOutJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ FileInputFormat.setInputPaths(job, inputAPaths);
+ if (minSplitSize > 0) {
+ FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+ }
+
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ SequenceFileOutputFormat.setOutputCompressionType(job,
+ CompressionType.BLOCK);
+
+ job.setMapOutputKeyClass(SplitPartitionedWritable.class);
+ job.setMapOutputValueClass(DenseBlockWritable.class);
+
+ job.setOutputKeyClass(SplitPartitionedWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.setMapperClass(ABtMapper.class);
+ job.setReducerClass(QRReducer.class);
+
+ job.getConfiguration().setInt(QJob.PROP_AROWBLOCK_SIZE, aBlockRows);
+ job.getConfiguration().setInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+ outerProdBlockHeight);
+ job.getConfiguration().setInt(QRFirstStep.PROP_K, k);
+ job.getConfiguration().setInt(QRFirstStep.PROP_P, p);
+ job.getConfiguration().set(PROP_BT_PATH, inputBtGlob.toString());
+
+ /*
+ * PCA-related options, MAHOUT-817
+ */
+ if (xiPath != null) {
+ job.getConfiguration().set(PROP_XI_PATH, xiPath.toString());
+ job.getConfiguration().set(PROP_SB_PATH, sbPath.toString());
+ job.getConfiguration().set(PROP_SQ_PATH, sqPath.toString());
+ }
+
+ job.setNumReduceTasks(numReduceTasks);
+
+ // broadcast Bt files if required.
+ if (broadcastBInput) {
+ job.getConfiguration().set(PROP_BT_BROADCAST, "y");
+
+ FileSystem fs = FileSystem.get(inputBtGlob.toUri(), conf);
+ FileStatus[] fstats = fs.globStatus(inputBtGlob);
+ if (fstats != null) {
+ for (FileStatus fstat : fstats) {
+ /*
+ * new api is not enabled yet in our dependencies at this time, still
+ * using deprecated one
+ */
+ DistributedCache.addCacheFile(fstat.getPath().toUri(),
+ job.getConfiguration());
+ }
+ }
+ }
+
+ job.submit();
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("ABt job unsuccessful.");
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java
new file mode 100644
index 0000000..afa1463
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/ABtJob.java
@@ -0,0 +1,494 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.text.NumberFormat;
+import java.util.ArrayDeque;
+import java.util.Arrays;
+import java.util.Deque;
+import java.util.regex.Matcher;
+
+import com.google.common.collect.Lists;
+import org.apache.commons.lang3.Validate;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRFirstStep;
+
+/**
+ * Computes ABt products, then first step of QR which is pushed down to the
+ * reducer.
+ *
+ */
+@SuppressWarnings("deprecation")
+public final class ABtJob {
+
+ public static final String PROP_BT_PATH = "ssvd.Bt.path";
+ public static final String PROP_BT_BROADCAST = "ssvd.Bt.broadcast";
+
+ private ABtJob() {
+ }
+
+ /**
+ * So, here, i preload A block into memory.
+ * <P>
+ *
+ * A sparse matrix seems to be ideal for that but there are two reasons why i
+ * am not using it:
+ * <UL>
+ * <LI>1) I don't know the full block height. so i may need to reallocate it
+ * from time to time. Although this probably not a showstopper.
+ * <LI>2) I found that RandomAccessSparseVectors seem to take much more memory
+ * than the SequentialAccessSparseVectors.
+ * </UL>
+ * <P>
+ *
+ */
+ public static class ABtMapper
+ extends
+ Mapper<Writable, VectorWritable, SplitPartitionedWritable, SparseRowBlockWritable> {
+
+ private SplitPartitionedWritable outKey;
+ private final Deque<Closeable> closeables = new ArrayDeque<>();
+ private SequenceFileDirIterator<IntWritable, VectorWritable> btInput;
+ private Vector[] aCols;
+ // private Vector[] yiRows;
+ // private VectorWritable outValue = new VectorWritable();
+ private int aRowCount;
+ private int kp;
+ private int blockHeight;
+ private SparseRowBlockAccumulator yiCollector;
+
+ @Override
+ protected void map(Writable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+
+ Vector vec = value.get();
+
+ int vecSize = vec.size();
+ if (aCols == null) {
+ aCols = new Vector[vecSize];
+ } else if (aCols.length < vecSize) {
+ aCols = Arrays.copyOf(aCols, vecSize);
+ }
+
+ if (vec.isDense()) {
+ for (int i = 0; i < vecSize; i++) {
+ extendAColIfNeeded(i, aRowCount + 1);
+ aCols[i].setQuick(aRowCount, vec.getQuick(i));
+ }
+ } else {
+ for (Vector.Element vecEl : vec.nonZeroes()) {
+ int i = vecEl.index();
+ extendAColIfNeeded(i, aRowCount + 1);
+ aCols[i].setQuick(aRowCount, vecEl.get());
+ }
+ }
+ aRowCount++;
+ }
+
+ private void extendAColIfNeeded(int col, int rowCount) {
+ if (aCols[col] == null) {
+ aCols[col] =
+ new SequentialAccessSparseVector(rowCount < 10000 ? 10000 : rowCount,
+ 1);
+ } else if (aCols[col].size() < rowCount) {
+ Vector newVec =
+ new SequentialAccessSparseVector(rowCount << 1,
+ aCols[col].getNumNondefaultElements() << 1);
+ newVec.viewPart(0, aCols[col].size()).assign(aCols[col]);
+ aCols[col] = newVec;
+ }
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ try {
+ // yiRows= new Vector[aRowCount];
+
+ int lastRowIndex = -1;
+
+ while (btInput.hasNext()) {
+ Pair<IntWritable, VectorWritable> btRec = btInput.next();
+ int btIndex = btRec.getFirst().get();
+ Vector btVec = btRec.getSecond().get();
+ Vector aCol;
+ if (btIndex > aCols.length || (aCol = aCols[btIndex]) == null) {
+ continue;
+ }
+ int j = -1;
+ for (Vector.Element aEl : aCol.nonZeroes()) {
+ j = aEl.index();
+
+ // outKey.setTaskItemOrdinal(j);
+ // outValue.set(btVec.times(aEl.get())); // assign might work better
+ // // with memory after all.
+ // context.write(outKey, outValue);
+ yiCollector.collect((long) j, btVec.times(aEl.get()));
+ }
+ if (lastRowIndex < j) {
+ lastRowIndex = j;
+ }
+ }
+ aCols = null;
+
+ // output empty rows if we never output partial products for them
+ // this happens in sparse matrices when last rows are all zeros
+ // and is subsequently causing shorter Q matrix row count which we
+ // probably don't want to repair there but rather here.
+ Vector yDummy = new SequentialAccessSparseVector(kp);
+ // outValue.set(yDummy);
+ for (lastRowIndex += 1; lastRowIndex < aRowCount; lastRowIndex++) {
+ // outKey.setTaskItemOrdinal(lastRowIndex);
+ // context.write(outKey, outValue);
+
+ yiCollector.collect((long) lastRowIndex, yDummy);
+ }
+
+ } finally {
+ IOUtils.close(closeables);
+ }
+ }
+
+ @Override
+ protected void setup(final Context context) throws IOException,
+ InterruptedException {
+
+ int k =
+ Integer.parseInt(context.getConfiguration().get(QRFirstStep.PROP_K));
+ int p =
+ Integer.parseInt(context.getConfiguration().get(QRFirstStep.PROP_P));
+ kp = k + p;
+
+ outKey = new SplitPartitionedWritable(context);
+ String propBtPathStr = context.getConfiguration().get(PROP_BT_PATH);
+ Validate.notNull(propBtPathStr, "Bt input is not set");
+ Path btPath = new Path(propBtPathStr);
+
+ boolean distributedBt =
+ context.getConfiguration().get(PROP_BT_BROADCAST) != null;
+
+ if (distributedBt) {
+
+ Path[] btFiles = HadoopUtil.getCachedFiles(context.getConfiguration());
+
+ // DEBUG: stdout
+ //System.out.printf("list of files: " + btFiles);
+
+ StringBuilder btLocalPath = new StringBuilder();
+ for (Path btFile : btFiles) {
+ if (btLocalPath.length() > 0) {
+ btLocalPath.append(Path.SEPARATOR_CHAR);
+ }
+ btLocalPath.append(btFile);
+ }
+
+ btInput =
+ new SequenceFileDirIterator<>(new Path(btLocalPath.toString()),
+ PathType.LIST,
+ null,
+ null,
+ true,
+ context.getConfiguration());
+
+ } else {
+
+ btInput =
+ new SequenceFileDirIterator<>(btPath, PathType.GLOB, null, null, true, context.getConfiguration());
+ }
+ // TODO: how do i release all that stuff??
+ closeables.addFirst(btInput);
+ OutputCollector<LongWritable, SparseRowBlockWritable> yiBlockCollector =
+ new OutputCollector<LongWritable, SparseRowBlockWritable>() {
+
+ @Override
+ public void collect(LongWritable blockKey,
+ SparseRowBlockWritable block) throws IOException {
+ outKey.setTaskItemOrdinal((int) blockKey.get());
+ try {
+ context.write(outKey, block);
+ } catch (InterruptedException exc) {
+ throw new IOException("Interrupted", exc);
+ }
+ }
+ };
+ blockHeight =
+ context.getConfiguration().getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+ -1);
+ yiCollector =
+ new SparseRowBlockAccumulator(blockHeight, yiBlockCollector);
+ closeables.addFirst(yiCollector);
+ }
+
+ }
+
+ /**
+ * QR first step pushed down to reducer.
+ *
+ */
+ public static class QRReducer
+ extends
+ Reducer<SplitPartitionedWritable, SparseRowBlockWritable, SplitPartitionedWritable, VectorWritable> {
+
+ // hack: partition number formats in hadoop, copied. this may stop working
+ // if it gets
+ // out of sync with newer hadoop version. But unfortunately rules of forming
+ // output file names are not sufficiently exposed so we need to hack it
+ // if we write the same split output from either mapper or reducer.
+ // alternatively, we probably can replace it by our own output file namnig
+ // management
+ // completely and bypass MultipleOutputs entirely.
+
+ private static final NumberFormat NUMBER_FORMAT =
+ NumberFormat.getInstance();
+ static {
+ NUMBER_FORMAT.setMinimumIntegerDigits(5);
+ NUMBER_FORMAT.setGroupingUsed(false);
+ }
+
+ private final Deque<Closeable> closeables = Lists.newLinkedList();
+ protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
+
+ protected int blockHeight;
+
+ protected int lastTaskId = -1;
+
+ protected OutputCollector<Writable, DenseBlockWritable> qhatCollector;
+ protected OutputCollector<Writable, VectorWritable> rhatCollector;
+ protected QRFirstStep qr;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ blockHeight =
+ context.getConfiguration().getInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+ -1);
+
+ }
+
+ protected void setupBlock(Context context, SplitPartitionedWritable spw)
+ throws InterruptedException, IOException {
+ IOUtils.close(closeables);
+ qhatCollector =
+ createOutputCollector(QJob.OUTPUT_QHAT,
+ spw,
+ context,
+ DenseBlockWritable.class);
+ rhatCollector =
+ createOutputCollector(QJob.OUTPUT_RHAT,
+ spw,
+ context,
+ VectorWritable.class);
+ qr =
+ new QRFirstStep(context.getConfiguration(),
+ qhatCollector,
+ rhatCollector);
+ closeables.addFirst(qr);
+ lastTaskId = spw.getTaskId();
+
+ }
+
+ @Override
+ protected void reduce(SplitPartitionedWritable key,
+ Iterable<SparseRowBlockWritable> values,
+ Context context) throws IOException,
+ InterruptedException {
+
+ accum.clear();
+ for (SparseRowBlockWritable bw : values) {
+ accum.plusBlock(bw);
+ }
+
+ if (key.getTaskId() != lastTaskId) {
+ setupBlock(context, key);
+ }
+
+ long blockBase = key.getTaskItemOrdinal() * blockHeight;
+ for (int k = 0; k < accum.getNumRows(); k++) {
+ Vector yiRow = accum.getRows()[k];
+ key.setTaskItemOrdinal(blockBase + accum.getRowIndices()[k]);
+ qr.collect(key, yiRow);
+ }
+
+ }
+
+ private Path getSplitFilePath(String name,
+ SplitPartitionedWritable spw,
+ Context context) throws InterruptedException,
+ IOException {
+ String uniqueFileName = FileOutputFormat.getUniqueFile(context, name, "");
+ uniqueFileName = uniqueFileName.replaceFirst("-r-", "-m-");
+ uniqueFileName =
+ uniqueFileName.replaceFirst("\\d+$",
+ Matcher.quoteReplacement(NUMBER_FORMAT.format(spw.getTaskId())));
+ return new Path(FileOutputFormat.getWorkOutputPath(context),
+ uniqueFileName);
+ }
+
+ /**
+ * key doesn't matter here, only value does. key always gets substituted by
+ * SPW.
+ */
+ private <K,V> OutputCollector<K,V> createOutputCollector(String name,
+ final SplitPartitionedWritable spw,
+ Context ctx,
+ Class<V> valueClass)
+ throws IOException, InterruptedException {
+ Path outputPath = getSplitFilePath(name, spw, ctx);
+ final SequenceFile.Writer w =
+ SequenceFile.createWriter(FileSystem.get(outputPath.toUri(), ctx.getConfiguration()),
+ ctx.getConfiguration(),
+ outputPath,
+ SplitPartitionedWritable.class,
+ valueClass);
+ closeables.addFirst(w);
+ return new OutputCollector<K, V>() {
+ @Override
+ public void collect(K key, V val) throws IOException {
+ w.append(spw, val);
+ }
+ };
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException, InterruptedException {
+ IOUtils.close(closeables);
+ }
+
+ }
+
+ public static void run(Configuration conf,
+ Path[] inputAPaths,
+ Path inputBtGlob,
+ Path outputPath,
+ int aBlockRows,
+ int minSplitSize,
+ int k,
+ int p,
+ int outerProdBlockHeight,
+ int numReduceTasks,
+ boolean broadcastBInput)
+ throws ClassNotFoundException, InterruptedException, IOException {
+
+ JobConf oldApiJob = new JobConf(conf);
+
+ // MultipleOutputs
+ // .addNamedOutput(oldApiJob,
+ // QJob.OUTPUT_QHAT,
+ // org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ // SplitPartitionedWritable.class,
+ // DenseBlockWritable.class);
+ //
+ // MultipleOutputs
+ // .addNamedOutput(oldApiJob,
+ // QJob.OUTPUT_RHAT,
+ // org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ // SplitPartitionedWritable.class,
+ // VectorWritable.class);
+
+ Job job = new Job(oldApiJob);
+ job.setJobName("ABt-job");
+ job.setJarByClass(ABtJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ FileInputFormat.setInputPaths(job, inputAPaths);
+ if (minSplitSize > 0) {
+ FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+ }
+
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ SequenceFileOutputFormat.setOutputCompressionType(job,
+ CompressionType.BLOCK);
+
+ job.setMapOutputKeyClass(SplitPartitionedWritable.class);
+ job.setMapOutputValueClass(SparseRowBlockWritable.class);
+
+ job.setOutputKeyClass(SplitPartitionedWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.setMapperClass(ABtMapper.class);
+ job.setCombinerClass(BtJob.OuterProductCombiner.class);
+ job.setReducerClass(QRReducer.class);
+
+ job.getConfiguration().setInt(QJob.PROP_AROWBLOCK_SIZE, aBlockRows);
+ job.getConfiguration().setInt(BtJob.PROP_OUTER_PROD_BLOCK_HEIGHT,
+ outerProdBlockHeight);
+ job.getConfiguration().setInt(QRFirstStep.PROP_K, k);
+ job.getConfiguration().setInt(QRFirstStep.PROP_P, p);
+ job.getConfiguration().set(PROP_BT_PATH, inputBtGlob.toString());
+
+ // number of reduce tasks doesn't matter. we don't actually
+ // send anything to reducers.
+
+ job.setNumReduceTasks(numReduceTasks);
+
+ // broadcast Bt files if required.
+ if (broadcastBInput) {
+ job.getConfiguration().set(PROP_BT_BROADCAST, "y");
+
+ FileSystem fs = FileSystem.get(inputBtGlob.toUri(), conf);
+ FileStatus[] fstats = fs.globStatus(inputBtGlob);
+ if (fstats != null) {
+ for (FileStatus fstat : fstats) {
+ /*
+ * new api is not enabled yet in our dependencies at this time, still
+ * using deprecated one
+ */
+ DistributedCache.addCacheFile(fstat.getPath().toUri(), conf);
+ }
+ }
+ }
+
+ job.submit();
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("ABt job unsuccessful.");
+ }
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
new file mode 100644
index 0000000..1277bae
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/BtJob.java
@@ -0,0 +1,628 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import org.apache.commons.lang3.Validate;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.filecache.DistributedCache;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.SequenceFile.CompressionType;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.compress.DefaultCodec;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.lib.MultipleOutputs;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.UpperTriangular;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.function.PlusMult;
+import org.apache.mahout.math.hadoop.stochasticsvd.qr.QRLastStep;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Deque;
+
+/**
+ * Bt job. For details, see working notes in MAHOUT-376.
+ * <p/>
+ * <p/>
+ * Uses hadoop deprecated API wherever new api has not been updated
+ * (MAHOUT-593), hence @SuppressWarning("deprecation").
+ * <p/>
+ * <p/>
+ * This job outputs either Bt in its standard output, or upper triangular
+ * matrices representing BBt partial sums if that's requested . If the latter
+ * mode is enabled, then we accumulate BBt outer product sums in upper
+ * triangular accumulator and output it at the end of the job, thus saving space
+ * and BBt job.
+ * <p/>
+ * <p/>
+ * This job also outputs Q and Bt and optionally BBt. Bt is output to standard
+ * job output (part-*) and Q and BBt use named multiple outputs.
+ * <p/>
+ * <p/>
+ */
+@SuppressWarnings("deprecation")
+public final class BtJob {
+
+ public static final String OUTPUT_Q = "Q";
+ public static final String OUTPUT_BT = "part";
+ public static final String OUTPUT_BBT = "bbt";
+ public static final String OUTPUT_SQ = "sq";
+ public static final String OUTPUT_SB = "sb";
+
+ public static final String PROP_QJOB_PATH = "ssvd.QJob.path";
+ public static final String PROP_OUPTUT_BBT_PRODUCTS =
+ "ssvd.BtJob.outputBBtProducts";
+ public static final String PROP_OUTER_PROD_BLOCK_HEIGHT =
+ "ssvd.outerProdBlockHeight";
+ public static final String PROP_RHAT_BROADCAST = "ssvd.rhat.broadcast";
+ public static final String PROP_XI_PATH = "ssvdpca.xi.path";
+ public static final String PROP_NV = "ssvd.nv";
+
+ private BtJob() {
+ }
+
+ public static class BtMapper extends
+ Mapper<Writable, VectorWritable, LongWritable, SparseRowBlockWritable> {
+
+ private QRLastStep qr;
+ private final Deque<Closeable> closeables = new ArrayDeque<>();
+
+ private int blockNum;
+ private MultipleOutputs outputs;
+ private final VectorWritable qRowValue = new VectorWritable();
+ private Vector btRow;
+ private SparseRowBlockAccumulator btCollector;
+ private Context mapContext;
+ private boolean nv;
+
+ // pca stuff
+ private Vector sqAccum;
+ private boolean computeSq;
+
+ /**
+ * We maintain A and QtHat inputs partitioned the same way, so we
+ * essentially are performing map-side merge here of A and QtHats except
+ * QtHat is stored not row-wise but block-wise.
+ */
+ @Override
+ protected void map(Writable key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+
+ mapContext = context;
+ // output Bt outer products
+ Vector aRow = value.get();
+
+ Vector qRow = qr.next();
+ int kp = qRow.size();
+
+ // make sure Qs are inheriting A row labels.
+ outputQRow(key, qRow, aRow);
+
+ // MAHOUT-817
+ if (computeSq) {
+ if (sqAccum == null) {
+ sqAccum = new DenseVector(kp);
+ }
+ sqAccum.assign(qRow, Functions.PLUS);
+ }
+
+ if (btRow == null) {
+ btRow = new DenseVector(kp);
+ }
+
+ if (!aRow.isDense()) {
+ for (Vector.Element el : aRow.nonZeroes()) {
+ double mul = el.get();
+ for (int j = 0; j < kp; j++) {
+ btRow.setQuick(j, mul * qRow.getQuick(j));
+ }
+ btCollector.collect((long) el.index(), btRow);
+ }
+ } else {
+ int n = aRow.size();
+ for (int i = 0; i < n; i++) {
+ double mul = aRow.getQuick(i);
+ for (int j = 0; j < kp; j++) {
+ btRow.setQuick(j, mul * qRow.getQuick(j));
+ }
+ btCollector.collect((long) i, btRow);
+ }
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+
+ Path qJobPath = new Path(conf.get(PROP_QJOB_PATH));
+
+ /*
+ * actually this is kind of dangerous because this routine thinks we need
+ * to create file name for our current job and this will use -m- so it's
+ * just serendipity we are calling it from the mapper too as the QJob did.
+ */
+ Path qInputPath =
+ new Path(qJobPath, FileOutputFormat.getUniqueFile(context,
+ QJob.OUTPUT_QHAT,
+ ""));
+ blockNum = context.getTaskAttemptID().getTaskID().getId();
+
+ SequenceFileValueIterator<DenseBlockWritable> qhatInput =
+ new SequenceFileValueIterator<>(qInputPath,
+ true,
+ conf);
+ closeables.addFirst(qhatInput);
+
+ /*
+ * read all r files _in order of task ids_, i.e. partitions (aka group
+ * nums).
+ *
+ * Note: if broadcast option is used, this comes from distributed cache
+ * files rather than hdfs path.
+ */
+
+ SequenceFileDirValueIterator<VectorWritable> rhatInput;
+
+ boolean distributedRHat = conf.get(PROP_RHAT_BROADCAST) != null;
+ if (distributedRHat) {
+
+ Path[] rFiles = HadoopUtil.getCachedFiles(conf);
+
+ Validate.notNull(rFiles,
+ "no RHat files in distributed cache job definition");
+ //TODO: this probably can be replaced w/ local fs makeQualified
+ Configuration lconf = new Configuration();
+ lconf.set("fs.default.name", "file:///");
+
+ rhatInput =
+ new SequenceFileDirValueIterator<>(rFiles,
+ SSVDHelper.PARTITION_COMPARATOR,
+ true,
+ lconf);
+
+ } else {
+ Path rPath = new Path(qJobPath, QJob.OUTPUT_RHAT + "-*");
+ rhatInput =
+ new SequenceFileDirValueIterator<>(rPath,
+ PathType.GLOB,
+ null,
+ SSVDHelper.PARTITION_COMPARATOR,
+ true,
+ conf);
+ }
+
+ Validate.isTrue(rhatInput.hasNext(), "Empty R-hat input!");
+
+ closeables.addFirst(rhatInput);
+ outputs = new MultipleOutputs(new JobConf(conf));
+ closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(outputs));
+
+ qr = new QRLastStep(qhatInput, rhatInput, blockNum);
+ closeables.addFirst(qr);
+ /*
+ * it's so happens that current QRLastStep's implementation preloads R
+ * sequence into memory in the constructor so it's ok to close rhat input
+ * now.
+ */
+ if (!rhatInput.hasNext()) {
+ closeables.remove(rhatInput);
+ rhatInput.close();
+ }
+
+ OutputCollector<LongWritable, SparseRowBlockWritable> btBlockCollector =
+ new OutputCollector<LongWritable, SparseRowBlockWritable>() {
+
+ @Override
+ public void collect(LongWritable blockKey,
+ SparseRowBlockWritable block) throws IOException {
+ try {
+ mapContext.write(blockKey, block);
+ } catch (InterruptedException exc) {
+ throw new IOException("Interrupted.", exc);
+ }
+ }
+ };
+
+ btCollector =
+ new SparseRowBlockAccumulator(conf.getInt(PROP_OUTER_PROD_BLOCK_HEIGHT,
+ -1), btBlockCollector);
+ closeables.addFirst(btCollector);
+
+ // MAHOUT-817
+ computeSq = conf.get(PROP_XI_PATH) != null;
+
+ // MAHOUT-1067
+ nv = conf.getBoolean(PROP_NV, false);
+
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+ try {
+ if (sqAccum != null) {
+ /*
+ * hack: we will output sq partial sums with index -1 for summation.
+ */
+ SparseRowBlockWritable sbrw = new SparseRowBlockWritable(1);
+ sbrw.plusRow(0, sqAccum);
+ LongWritable lw = new LongWritable(-1);
+ context.write(lw, sbrw);
+ }
+ } finally {
+ IOUtils.close(closeables);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ private void outputQRow(Writable key, Vector qRow, Vector aRow) throws IOException {
+ if (nv && (aRow instanceof NamedVector)) {
+ qRowValue.set(new NamedVector(qRow, ((NamedVector) aRow).getName()));
+ } else {
+ qRowValue.set(qRow);
+ }
+ outputs.getCollector(OUTPUT_Q, null).collect(key, qRowValue);
+ }
+ }
+
+ public static class OuterProductCombiner
+ extends
+ Reducer<Writable, SparseRowBlockWritable, Writable, SparseRowBlockWritable> {
+
+ protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
+ protected final Deque<Closeable> closeables = new ArrayDeque<>();
+ protected int blockHeight;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+ blockHeight =
+ context.getConfiguration().getInt(PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
+ }
+
+ @Override
+ protected void reduce(Writable key,
+ Iterable<SparseRowBlockWritable> values,
+ Context context) throws IOException,
+ InterruptedException {
+ for (SparseRowBlockWritable bw : values) {
+ accum.plusBlock(bw);
+ }
+ context.write(key, accum);
+ accum.clear();
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+
+ IOUtils.close(closeables);
+ }
+ }
+
+ public static class OuterProductReducer
+ extends
+ Reducer<LongWritable, SparseRowBlockWritable, IntWritable, VectorWritable> {
+
+ protected final SparseRowBlockWritable accum = new SparseRowBlockWritable();
+ protected final Deque<Closeable> closeables = new ArrayDeque<>();
+
+ protected int blockHeight;
+ private boolean outputBBt;
+ private UpperTriangular mBBt;
+ private MultipleOutputs outputs;
+ private final IntWritable btKey = new IntWritable();
+ private final VectorWritable btValue = new VectorWritable();
+
+ // MAHOUT-817
+ private Vector xi;
+ private final PlusMult pmult = new PlusMult(0);
+ private Vector sbAccum;
+
+ @Override
+ protected void setup(Context context) throws IOException,
+ InterruptedException {
+
+ Configuration conf = context.getConfiguration();
+ blockHeight = conf.getInt(PROP_OUTER_PROD_BLOCK_HEIGHT, -1);
+
+ outputBBt = conf.getBoolean(PROP_OUPTUT_BBT_PRODUCTS, false);
+
+ if (outputBBt) {
+ int k = conf.getInt(QJob.PROP_K, -1);
+ int p = conf.getInt(QJob.PROP_P, -1);
+
+ Validate.isTrue(k > 0, "invalid k parameter");
+ Validate.isTrue(p >= 0, "invalid p parameter");
+ mBBt = new UpperTriangular(k + p);
+
+ }
+
+ String xiPathStr = conf.get(PROP_XI_PATH);
+ if (xiPathStr != null) {
+ xi = SSVDHelper.loadAndSumUpVectors(new Path(xiPathStr), conf);
+ if (xi == null) {
+ throw new IOException(String.format("unable to load mean path xi from %s.",
+ xiPathStr));
+ }
+ }
+
+ if (outputBBt || xi != null) {
+ outputs = new MultipleOutputs(new JobConf(conf));
+ closeables.addFirst(new IOUtils.MultipleOutputsCloseableAdapter(outputs));
+ }
+
+ }
+
+ @Override
+ protected void reduce(LongWritable key,
+ Iterable<SparseRowBlockWritable> values,
+ Context context) throws IOException,
+ InterruptedException {
+
+ accum.clear();
+ for (SparseRowBlockWritable bw : values) {
+ accum.plusBlock(bw);
+ }
+
+ // MAHOUT-817:
+ if (key.get() == -1L) {
+
+ Vector sq = accum.getRows()[0];
+
+ @SuppressWarnings("unchecked")
+ OutputCollector<IntWritable, VectorWritable> sqOut =
+ outputs.getCollector(OUTPUT_SQ, null);
+
+ sqOut.collect(new IntWritable(0), new VectorWritable(sq));
+ return;
+ }
+
+ /*
+ * at this point, sum of rows should be in accum, so we just generate
+ * outer self product of it and add to BBt accumulator.
+ */
+
+ for (int k = 0; k < accum.getNumRows(); k++) {
+ Vector btRow = accum.getRows()[k];
+ btKey.set((int) (key.get() * blockHeight + accum.getRowIndices()[k]));
+ btValue.set(btRow);
+ context.write(btKey, btValue);
+
+ if (outputBBt) {
+ int kp = mBBt.numRows();
+ // accumulate partial BBt sum
+ for (int i = 0; i < kp; i++) {
+ double vi = btRow.get(i);
+ if (vi != 0.0) {
+ for (int j = i; j < kp; j++) {
+ double vj = btRow.get(j);
+ if (vj != 0.0) {
+ mBBt.setQuick(i, j, mBBt.getQuick(i, j) + vi * vj);
+ }
+ }
+ }
+ }
+ }
+
+ // MAHOUT-817
+ if (xi != null) {
+ // code defensively against shortened xi
+ int btIndex = btKey.get();
+ double xii = xi.size() > btIndex ? xi.getQuick(btIndex) : 0.0;
+ // compute s_b
+ pmult.setMultiplicator(xii);
+ if (sbAccum == null) {
+ sbAccum = new DenseVector(btRow.size());
+ }
+ sbAccum.assign(btRow, pmult);
+ }
+
+ }
+ }
+
+ @Override
+ protected void cleanup(Context context) throws IOException,
+ InterruptedException {
+
+ // if we output BBt instead of Bt then we need to do it.
+ try {
+ if (outputBBt) {
+
+ @SuppressWarnings("unchecked")
+ OutputCollector<Writable, Writable> collector =
+ outputs.getCollector(OUTPUT_BBT, null);
+
+ collector.collect(new IntWritable(),
+ new VectorWritable(new DenseVector(mBBt.getData())));
+ }
+
+ // MAHOUT-817
+ if (sbAccum != null) {
+ @SuppressWarnings("unchecked")
+ OutputCollector<IntWritable, VectorWritable> collector =
+ outputs.getCollector(OUTPUT_SB, null);
+
+ collector.collect(new IntWritable(), new VectorWritable(sbAccum));
+
+ }
+ } finally {
+ IOUtils.close(closeables);
+ }
+
+ }
+ }
+
+ public static void run(Configuration conf,
+ Path[] inputPathA,
+ Path inputPathQJob,
+ Path xiPath,
+ Path outputPath,
+ int minSplitSize,
+ int k,
+ int p,
+ int btBlockHeight,
+ int numReduceTasks,
+ boolean broadcast,
+ Class<? extends Writable> labelClass,
+ boolean outputBBtProducts)
+ throws ClassNotFoundException, InterruptedException, IOException {
+
+ JobConf oldApiJob = new JobConf(conf);
+
+ MultipleOutputs.addNamedOutput(oldApiJob,
+ OUTPUT_Q,
+ org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ labelClass,
+ VectorWritable.class);
+
+ if (outputBBtProducts) {
+ MultipleOutputs.addNamedOutput(oldApiJob,
+ OUTPUT_BBT,
+ org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ IntWritable.class,
+ VectorWritable.class);
+ /*
+ * MAHOUT-1067: if we are asked to output BBT products then named vector
+ * names should be propagated to Q too so that UJob could pick them up
+ * from there.
+ */
+ oldApiJob.setBoolean(PROP_NV, true);
+ }
+ if (xiPath != null) {
+ // compute pca -related stuff as well
+ MultipleOutputs.addNamedOutput(oldApiJob,
+ OUTPUT_SQ,
+ org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ IntWritable.class,
+ VectorWritable.class);
+ MultipleOutputs.addNamedOutput(oldApiJob,
+ OUTPUT_SB,
+ org.apache.hadoop.mapred.SequenceFileOutputFormat.class,
+ IntWritable.class,
+ VectorWritable.class);
+ }
+
+ /*
+ * HACK: we use old api multiple outputs since they are not available in the
+ * new api of either 0.20.2 or 0.20.203 but wrap it into a new api job so we
+ * can use new api interfaces.
+ */
+
+ Job job = new Job(oldApiJob);
+ job.setJobName("Bt-job");
+ job.setJarByClass(BtJob.class);
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ FileInputFormat.setInputPaths(job, inputPathA);
+ if (minSplitSize > 0) {
+ FileInputFormat.setMinInputSplitSize(job, minSplitSize);
+ }
+ FileOutputFormat.setOutputPath(job, outputPath);
+
+ // WARN: tight hadoop integration here:
+ job.getConfiguration().set("mapreduce.output.basename", OUTPUT_BT);
+
+ FileOutputFormat.setOutputCompressorClass(job, DefaultCodec.class);
+ SequenceFileOutputFormat.setOutputCompressionType(job,
+ CompressionType.BLOCK);
+
+ job.setMapOutputKeyClass(LongWritable.class);
+ job.setMapOutputValueClass(SparseRowBlockWritable.class);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+
+ job.setMapperClass(BtMapper.class);
+ job.setCombinerClass(OuterProductCombiner.class);
+ job.setReducerClass(OuterProductReducer.class);
+
+ job.getConfiguration().setInt(QJob.PROP_K, k);
+ job.getConfiguration().setInt(QJob.PROP_P, p);
+ job.getConfiguration().set(PROP_QJOB_PATH, inputPathQJob.toString());
+ job.getConfiguration().setBoolean(PROP_OUPTUT_BBT_PRODUCTS,
+ outputBBtProducts);
+ job.getConfiguration().setInt(PROP_OUTER_PROD_BLOCK_HEIGHT, btBlockHeight);
+
+ job.setNumReduceTasks(numReduceTasks);
+
+ /*
+ * PCA-related options, MAHOUT-817
+ */
+ if (xiPath != null) {
+ job.getConfiguration().set(PROP_XI_PATH, xiPath.toString());
+ }
+
+ /*
+ * we can broadhast Rhat files since all of them are reuqired by each job,
+ * but not Q files which correspond to splits of A (so each split of A will
+ * require only particular Q file, each time different one).
+ */
+
+ if (broadcast) {
+ job.getConfiguration().set(PROP_RHAT_BROADCAST, "y");
+
+ FileSystem fs = FileSystem.get(inputPathQJob.toUri(), conf);
+ FileStatus[] fstats =
+ fs.globStatus(new Path(inputPathQJob, QJob.OUTPUT_RHAT + "-*"));
+ if (fstats != null) {
+ for (FileStatus fstat : fstats) {
+ /*
+ * new api is not enabled yet in our dependencies at this time, still
+ * using deprecated one
+ */
+ DistributedCache.addCacheFile(fstat.getPath().toUri(),
+ job.getConfiguration());
+ }
+ }
+ }
+
+ job.submit();
+ job.waitForCompletion(false);
+
+ if (!job.isSuccessful()) {
+ throw new IOException("Bt job unsuccessful.");
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java
new file mode 100644
index 0000000..6a9b352
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/stochasticsvd/DenseBlockWritable.java
@@ -0,0 +1,83 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math.hadoop.stochasticsvd;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Ad-hoc substitution for {@link org.apache.mahout.math.MatrixWritable}.
+ * Perhaps more useful for situations with mostly dense data (such as Q-blocks)
+ * but reduces GC by reusing the same block memory between loads and writes.
+ * <p>
+ *
+ * in case of Q blocks, it doesn't even matter if they this data is dense cause
+ * we need to unpack it into dense for fast access in computations anyway and
+ * even if it is not so dense the block compressor in sequence files will take
+ * care of it for the serialized size.
+ * <p>
+ */
+public class DenseBlockWritable implements Writable {
+ private double[][] block;
+
+ public void setBlock(double[][] block) {
+ this.block = block;
+ }
+
+ public double[][] getBlock() {
+ return block;
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ int m = in.readInt();
+ int n = in.readInt();
+ if (block == null) {
+ block = new double[m][0];
+ } else if (block.length != m) {
+ block = Arrays.copyOf(block, m);
+ }
+ for (int i = 0; i < m; i++) {
+ if (block[i] == null || block[i].length != n) {
+ block[i] = new double[n];
+ }
+ for (int j = 0; j < n; j++) {
+ block[i][j] = in.readDouble();
+ }
+
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ int m = block.length;
+ int n = block.length == 0 ? 0 : block[0].length;
+
+ out.writeInt(m);
+ out.writeInt(n);
+ for (double[] aBlock : block) {
+ for (int j = 0; j < n; j++) {
+ out.writeDouble(aBlock[j]);
+ }
+ }
+ }
+
+}
[22/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java
new file mode 100644
index 0000000..d2fdf8d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterable.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+
+/**
+ * <p>{@link Iterable} counterpart to {@link SequenceFileValueIterator}.</p>
+ */
+public final class SequenceFileValueIterable<V extends Writable> implements Iterable<V> {
+
+ private final Path path;
+ private final boolean reuseKeyValueInstances;
+ private final Configuration conf;
+
+ /**
+ * Like {@link #SequenceFileValueIterable(Path, boolean, Configuration)} but instances are not reused
+ * by default.
+ *
+ * @param path file to iterate over
+ */
+ public SequenceFileValueIterable(Path path, Configuration conf) {
+ this(path, false, conf);
+ }
+
+ /**
+ * @param path file to iterate over
+ * @param reuseKeyValueInstances if true, reuses instances of the value object instead of creating a new
+ * one for each read from the file
+ */
+ public SequenceFileValueIterable(Path path, boolean reuseKeyValueInstances, Configuration conf) {
+ this.path = path;
+ this.reuseKeyValueInstances = reuseKeyValueInstances;
+ this.conf = conf;
+ }
+
+ @Override
+ public Iterator<V> iterator() {
+ try {
+ return new SequenceFileValueIterator<>(path, reuseKeyValueInstances, conf);
+ } catch (IOException ioe) {
+ throw new IllegalStateException(path.toString(), ioe);
+ }
+ }
+
+}
+
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java
new file mode 100644
index 0000000..49d64c7
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/iterator/sequencefile/SequenceFileValueIterator.java
@@ -0,0 +1,97 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.iterator.sequencefile;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.util.ReflectionUtils;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * <p>{@link java.util.Iterator} over a {@link SequenceFile}'s values only.</p>
+ */
+public final class SequenceFileValueIterator<V extends Writable> extends AbstractIterator<V> implements Closeable {
+
+ private final SequenceFile.Reader reader;
+ private final Configuration conf;
+ private final Class<V> valueClass;
+ private final Writable key;
+ private V value;
+ private final boolean reuseKeyValueInstances;
+
+ private static final Logger log = LoggerFactory.getLogger(SequenceFileValueIterator.class);
+
+ /**
+ * @throws IOException if path can't be read, or its key or value class can't be instantiated
+ */
+
+ public SequenceFileValueIterator(Path path, boolean reuseKeyValueInstances, Configuration conf) throws IOException {
+ value = null;
+ FileSystem fs = path.getFileSystem(conf);
+ path = path.makeQualified(path.toUri(), path);
+ reader = new SequenceFile.Reader(fs, path, conf);
+ this.conf = conf;
+ Class<? extends Writable> keyClass = (Class<? extends Writable>) reader.getKeyClass();
+ key = ReflectionUtils.newInstance(keyClass, conf);
+ valueClass = (Class<V>) reader.getValueClass();
+ this.reuseKeyValueInstances = reuseKeyValueInstances;
+ }
+
+ public Class<V> getValueClass() {
+ return valueClass;
+ }
+
+ @Override
+ public void close() throws IOException {
+ value = null;
+ Closeables.close(reader, true);
+ endOfData();
+ }
+
+ @Override
+ protected V computeNext() {
+ if (!reuseKeyValueInstances || value == null) {
+ value = ReflectionUtils.newInstance(valueClass, conf);
+ }
+ try {
+ boolean available = reader.next(key, value);
+ if (!available) {
+ close();
+ return null;
+ }
+ return value;
+ } catch (IOException ioe) {
+ try {
+ close();
+ } catch (IOException e) {
+ log.error(e.getMessage(), e);
+ }
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java b/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java
new file mode 100644
index 0000000..37ca383
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/lucene/AnalyzerUtils.java
@@ -0,0 +1,61 @@
+package org.apache.mahout.common.lucene;
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.util.Version;
+import org.apache.mahout.common.ClassUtils;
+
+public final class AnalyzerUtils {
+
+ private AnalyzerUtils() {}
+
+ /**
+ * Create an Analyzer using the latest {@link org.apache.lucene.util.Version}. Note, if you need to pass in
+ * parameters to your constructor, you will need to wrap it in an implementation that does not take any arguments
+ * @param analyzerClassName - Lucene Analyzer Name
+ * @return {@link Analyzer}
+ * @throws ClassNotFoundException - {@link ClassNotFoundException}
+ */
+ public static Analyzer createAnalyzer(String analyzerClassName) throws ClassNotFoundException {
+ return createAnalyzer(analyzerClassName, Version.LUCENE_46);
+ }
+
+ public static Analyzer createAnalyzer(String analyzerClassName, Version version) throws ClassNotFoundException {
+ Class<? extends Analyzer> analyzerClass = Class.forName(analyzerClassName).asSubclass(Analyzer.class);
+ return createAnalyzer(analyzerClass, version);
+ }
+
+ /**
+ * Create an Analyzer using the latest {@link org.apache.lucene.util.Version}. Note, if you need to pass in
+ * parameters to your constructor, you will need to wrap it in an implementation that does not take any arguments
+ * @param analyzerClass The Analyzer Class to instantiate
+ * @return {@link Analyzer}
+ */
+ public static Analyzer createAnalyzer(Class<? extends Analyzer> analyzerClass) {
+ return createAnalyzer(analyzerClass, Version.LUCENE_46);
+ }
+
+ public static Analyzer createAnalyzer(Class<? extends Analyzer> analyzerClass, Version version) {
+ try {
+ return ClassUtils.instantiateAs(analyzerClass, Analyzer.class,
+ new Class<?>[] { Version.class }, new Object[] { version });
+ } catch (IllegalStateException e) {
+ return ClassUtils.instantiateAs(analyzerClass, Analyzer.class);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java b/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java
new file mode 100644
index 0000000..5facad8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/lucene/IteratorTokenStream.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.lucene;
+
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+
+import java.util.Iterator;
+
+/** Used to emit tokens from an input string array in the style of TokenStream */
+public final class IteratorTokenStream extends TokenStream {
+ private final CharTermAttribute termAtt;
+ private final Iterator<String> iterator;
+
+ public IteratorTokenStream(Iterator<String> iterator) {
+ this.iterator = iterator;
+ this.termAtt = addAttribute(CharTermAttribute.class);
+ }
+
+ @Override
+ public boolean incrementToken() {
+ if (iterator.hasNext()) {
+ clearAttributes();
+ termAtt.append(iterator.next());
+ return true;
+ } else {
+ return false;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java b/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java
new file mode 100644
index 0000000..af60d8b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/lucene/TokenStreamIterator.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.lucene;
+
+import com.google.common.collect.AbstractIterator;
+import org.apache.lucene.analysis.TokenStream;
+import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
+
+import java.io.IOException;
+
+/**
+ * Provide an Iterator for the tokens in a TokenStream.
+ *
+ * Note, it is the responsibility of the instantiating class to properly consume the
+ * {@link org.apache.lucene.analysis.TokenStream}. See the Lucene {@link org.apache.lucene.analysis.TokenStream}
+ * documentation for more information.
+ */
+//TODO: consider using the char/byte arrays instead of strings, esp. when we upgrade to Lucene 4.0
+public final class TokenStreamIterator extends AbstractIterator<String> {
+
+ private final TokenStream tokenStream;
+
+ public TokenStreamIterator(TokenStream tokenStream) {
+ this.tokenStream = tokenStream;
+ }
+
+ @Override
+ protected String computeNext() {
+ try {
+ if (tokenStream.incrementToken()) {
+ return tokenStream.getAttribute(CharTermAttribute.class).toString();
+ } else {
+ tokenStream.end();
+ tokenStream.close();
+ return endOfData();
+ }
+ } catch (IOException e) {
+ throw new IllegalStateException("IO error while tokenizing", e);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java
new file mode 100644
index 0000000..8e0385d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsCombiner.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.mapreduce;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class MergeVectorsCombiner
+ extends Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
+
+ @Override
+ public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx)
+ throws IOException, InterruptedException {
+ ctx.write(key, VectorWritable.merge(vectors.iterator()));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java
new file mode 100644
index 0000000..b8d5dea
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/MergeVectorsReducer.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.mapreduce;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class MergeVectorsReducer extends
+ Reducer<WritableComparable<?>,VectorWritable,WritableComparable<?>,VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable();
+
+ @Override
+ public void reduce(WritableComparable<?> key, Iterable<VectorWritable> vectors, Context ctx)
+ throws IOException, InterruptedException {
+ Vector merged = VectorWritable.merge(vectors.iterator()).get();
+ result.set(new SequentialAccessSparseVector(merged));
+ ctx.write(key, result);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java
new file mode 100644
index 0000000..c6c3f05
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/TransposeMapper.java
@@ -0,0 +1,49 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.mapreduce;
+
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+
+public class TransposeMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ public static final String NEW_NUM_COLS_PARAM = TransposeMapper.class.getName() + ".newNumCols";
+
+ private int newNumCols;
+
+ @Override
+ protected void setup(Context ctx) throws IOException, InterruptedException {
+ newNumCols = ctx.getConfiguration().getInt(NEW_NUM_COLS_PARAM, Integer.MAX_VALUE);
+ }
+
+ @Override
+ protected void map(IntWritable r, VectorWritable v, Context ctx) throws IOException, InterruptedException {
+ int row = r.get();
+ for (Vector.Element e : v.get().nonZeroes()) {
+ RandomAccessSparseVector tmp = new RandomAccessSparseVector(newNumCols, 1);
+ tmp.setQuick(row, e.get());
+ r.set(e.index());
+ ctx.write(r, new VectorWritable(tmp));
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java
new file mode 100644
index 0000000..1d93386
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumCombiner.java
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.mapreduce;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors;
+
+import java.io.IOException;
+
+public class VectorSumCombiner
+ extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ private final VectorWritable result = new VectorWritable();
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx)
+ throws IOException, InterruptedException {
+ result.set(Vectors.sum(values.iterator()));
+ ctx.write(key, result);
+ }
+ }
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java
new file mode 100644
index 0000000..97d3805
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/mapreduce/VectorSumReducer.java
@@ -0,0 +1,35 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.mapreduce;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.Vectors;
+
+import java.io.IOException;
+
+public class VectorSumReducer
+ extends Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+ @Override
+ protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context ctx)
+ throws IOException, InterruptedException {
+ ctx.write(key, new VectorWritable(Vectors.sum(values.iterator())));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java b/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java
new file mode 100644
index 0000000..7adadc1
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/nlp/NGrams.java
@@ -0,0 +1,94 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.nlp;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+public class NGrams {
+
+ private static final Splitter SPACE_TAB = Splitter.on(CharMatcher.anyOf(" \t"));
+
+ private final String line;
+ private final int gramSize;
+
+ public NGrams(String line, int gramSize) {
+ this.line = line;
+ this.gramSize = gramSize;
+ }
+
+ public Map<String,List<String>> generateNGrams() {
+ Map<String,List<String>> returnDocument = Maps.newHashMap();
+
+ Iterator<String> tokenizer = SPACE_TAB.split(line).iterator();
+ List<String> tokens = Lists.newArrayList();
+ String labelName = tokenizer.next();
+ List<String> previousN1Grams = Lists.newArrayList();
+ while (tokenizer.hasNext()) {
+
+ String nextToken = tokenizer.next();
+ if (previousN1Grams.size() == gramSize) {
+ previousN1Grams.remove(0);
+ }
+
+ previousN1Grams.add(nextToken);
+
+ StringBuilder gramBuilder = new StringBuilder();
+
+ for (String gram : previousN1Grams) {
+ gramBuilder.append(gram);
+ String token = gramBuilder.toString();
+ tokens.add(token);
+ gramBuilder.append(' ');
+ }
+ }
+ returnDocument.put(labelName, tokens);
+ return returnDocument;
+ }
+
+ public List<String> generateNGramsWithoutLabel() {
+
+ List<String> tokens = Lists.newArrayList();
+ List<String> previousN1Grams = Lists.newArrayList();
+ for (String nextToken : SPACE_TAB.split(line)) {
+
+ if (previousN1Grams.size() == gramSize) {
+ previousN1Grams.remove(0);
+ }
+
+ previousN1Grams.add(nextToken);
+
+ StringBuilder gramBuilder = new StringBuilder();
+
+ for (String gram : previousN1Grams) {
+ gramBuilder.append(gram);
+ String token = gramBuilder.toString();
+ tokens.add(token);
+ gramBuilder.append(' ');
+ }
+ }
+
+ return tokens;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java
new file mode 100644
index 0000000..f0a7aa8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/parameters/AbstractParameter.java
@@ -0,0 +1,120 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.parameters;
+
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.hadoop.conf.Configuration;
+
+public abstract class AbstractParameter<T> implements Parameter<T> {
+
+ private T value;
+ private final String prefix;
+ private final String name;
+ private final String description;
+ private final Class<T> type;
+ private final String defaultValue;
+
+ protected AbstractParameter(Class<T> type,
+ String prefix,
+ String name,
+ Configuration jobConf,
+ T defaultValue,
+ String description) {
+ this.type = type;
+ this.name = name;
+ this.description = description;
+
+ this.value = defaultValue;
+ this.defaultValue = getStringValue();
+
+ this.prefix = prefix;
+ String jobConfValue = jobConf.get(prefix + name);
+ if (jobConfValue != null) {
+ setStringValue(jobConfValue);
+ }
+
+ }
+
+ @Override
+ public void configure(Configuration jobConf) {
+ // nothing to do
+ }
+
+ @Override
+ public void createParameters(String prefix, Configuration jobConf) { }
+
+ @Override
+ public String getStringValue() {
+ if (value == null) {
+ return null;
+ }
+ return value.toString();
+ }
+
+ @Override
+ public Collection<Parameter<?>> getParameters() {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public String prefix() {
+ return prefix;
+ }
+
+ @Override
+ public String name() {
+ return name;
+ }
+
+ @Override
+ public String description() {
+ return description;
+ }
+
+ @Override
+ public Class<T> type() {
+ return type;
+ }
+
+ @Override
+ public String defaultValue() {
+ return defaultValue;
+ }
+
+ @Override
+ public T get() {
+ return value;
+ }
+
+ @Override
+ public void set(T value) {
+ this.value = value;
+ }
+
+ @Override
+ public String toString() {
+ if (value != null) {
+ return value.toString();
+ } else {
+ return super.toString();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java
new file mode 100644
index 0000000..1d1c0bb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/parameters/ClassParameter.java
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.parameters;
+
+import org.apache.hadoop.conf.Configuration;
+
+public class ClassParameter extends AbstractParameter<Class> {
+
+ public ClassParameter(String prefix, String name, Configuration jobConf, Class<?> defaultValue, String description) {
+ super(Class.class, prefix, name, jobConf, defaultValue, description);
+ }
+
+ @Override
+ public void setStringValue(String stringValue) {
+ try {
+ set(Class.forName(stringValue));
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ @Override
+ public String getStringValue() {
+ if (get() == null) {
+ return null;
+ }
+ return get().getName();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java
new file mode 100644
index 0000000..cb3efcf
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/parameters/DoubleParameter.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.parameters;
+
+import org.apache.hadoop.conf.Configuration;
+
+public class DoubleParameter extends AbstractParameter<Double> {
+
+ public DoubleParameter(String prefix, String name, Configuration conf, double defaultValue, String description) {
+ super(Double.class, prefix, name, conf, defaultValue, description);
+ }
+
+ @Override
+ public void setStringValue(String stringValue) {
+ set(Double.valueOf(stringValue));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java
new file mode 100644
index 0000000..292fa27
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/parameters/Parameter.java
@@ -0,0 +1,62 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.parameters;
+
+/**
+ * An accessor to a parameters in the job.
+ *
+ * This is a composite entity that can it self contain more parameters. Say the parameters describes what
+ * DistanceMeasure class to use, once set this parameters would also produce the parameters available in that
+ * DistanceMeasure implementation.
+ */
+public interface Parameter<T> extends Parametered {
+ /** @return job configuration setting key prefix, e.g. 'org.apache.mahout.util.WeightedDistanceMeasure.' */
+ String prefix();
+
+ /** @return configuration parameters name, e.g. 'weightsFile' */
+ String name();
+
+ /** @return human readable description of parameters */
+ String description();
+
+ /** @return value class type */
+ Class<T> type();
+
+ /**
+ * @param stringValue
+ * value string representation
+ */
+ void setStringValue(String stringValue);
+
+ /**
+ * @return value string representation of current value
+ */
+ String getStringValue();
+
+ /**
+ * @param value
+ * new parameters value
+ */
+ void set(T value);
+
+ /** @return current parameters value */
+ T get();
+
+ /** @return value used if not set by consumer */
+ String defaultValue();
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java b/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java
new file mode 100644
index 0000000..96c9457
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/parameters/Parametered.java
@@ -0,0 +1,206 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.parameters;
+
+import java.util.Collection;
+
+import org.apache.hadoop.conf.Configuration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Meta information and accessors for configuring a job. */
+public interface Parametered {
+
+ Logger log = LoggerFactory.getLogger(Parametered.class);
+
+ Collection<Parameter<?>> getParameters();
+
+ /**
+ * EXPERT: consumers should never have to call this method. It would be friendly visible to
+ * {@link ParameteredGeneralizations} if java supported it. Calling this method should create a new list of
+ * parameters and is called
+ *
+ * @param prefix
+ * ends with a dot if not empty.
+ * @param jobConf
+ * configuration used for retrieving values
+ * @see ParameteredGeneralizations#configureParameters(String,Parametered,Configuration)
+ * invoking method
+ * @see ParameteredGeneralizations#configureParametersRecursively(Parametered,String,Configuration)
+ * invoking method
+ */
+ void createParameters(String prefix, Configuration jobConf);
+
+ void configure(Configuration config);
+
+ /** "multiple inheritance" */
+ final class ParameteredGeneralizations {
+ private ParameteredGeneralizations() { }
+
+ public static void configureParameters(Parametered parametered, Configuration jobConf) {
+ configureParameters(parametered.getClass().getSimpleName() + '.',
+ parametered, jobConf);
+
+ }
+
+ /**
+ * Calls
+ * {@link Parametered#createParameters(String,org.apache.hadoop.conf.Configuration)}
+ * on parameter parmetered, and then recur down its composite tree to invoke
+ * {@link Parametered#createParameters(String,org.apache.hadoop.conf.Configuration)}
+ * and {@link Parametered#configure(org.apache.hadoop.conf.Configuration)} on
+ * each composite part.
+ *
+ * @param prefix
+ * ends with a dot if not empty.
+ * @param parametered
+ * instance to be configured
+ * @param jobConf
+ * configuration used for retrieving values
+ */
+ public static void configureParameters(String prefix, Parametered parametered, Configuration jobConf) {
+ parametered.createParameters(prefix, jobConf);
+ configureParametersRecursively(parametered, prefix, jobConf);
+ }
+
+ private static void configureParametersRecursively(Parametered parametered, String prefix, Configuration jobConf) {
+ for (Parameter<?> parameter : parametered.getParameters()) {
+ if (log.isDebugEnabled()) {
+ log.debug("Configuring {}{}", prefix, parameter.name());
+ }
+ String name = prefix + parameter.name() + '.';
+ parameter.createParameters(name, jobConf);
+ parameter.configure(jobConf);
+ if (!parameter.getParameters().isEmpty()) {
+ configureParametersRecursively(parameter, name, jobConf);
+ }
+ }
+ }
+
+ public static String help(Parametered parametered) {
+ return new Help(parametered).toString();
+ }
+
+ public static String conf(Parametered parametered) {
+ return new Conf(parametered).toString();
+ }
+
+ private static final class Help {
+ static final int NAME_DESC_DISTANCE = 8;
+
+ private final StringBuilder sb;
+ private int longestName;
+ private int numChars = 100; // a few extra just to be sure
+
+ private Help(Parametered parametered) {
+ recurseCount(parametered);
+ numChars += (longestName + NAME_DESC_DISTANCE) * parametered.getParameters().size();
+ sb = new StringBuilder(numChars);
+ recurseWrite(parametered);
+ }
+
+ @Override
+ public String toString() {
+ return sb.toString();
+ }
+
+ private void recurseCount(Parametered parametered) {
+ for (Parameter<?> parameter : parametered.getParameters()) {
+ int parameterNameLength = parameter.name().length();
+ if (parameterNameLength > longestName) {
+ longestName = parameterNameLength;
+ }
+ recurseCount(parameter);
+ numChars += parameter.description().length();
+ }
+ }
+
+ private void recurseWrite(Parametered parametered) {
+ for (Parameter<?> parameter : parametered.getParameters()) {
+ sb.append(parameter.prefix());
+ sb.append(parameter.name());
+ int max = longestName - parameter.name().length() - parameter.prefix().length()
+ + NAME_DESC_DISTANCE;
+ for (int i = 0; i < max; i++) {
+ sb.append(' ');
+ }
+ sb.append(parameter.description());
+ if (parameter.defaultValue() != null) {
+ sb.append(" (default value '");
+ sb.append(parameter.defaultValue());
+ sb.append("')");
+ }
+ sb.append('\n');
+ recurseWrite(parameter);
+ }
+ }
+ }
+
+ private static final class Conf {
+ private final StringBuilder sb;
+ private int longestName;
+ private int numChars = 100; // a few extra just to be sure
+
+ private Conf(Parametered parametered) {
+ recurseCount(parametered);
+ sb = new StringBuilder(numChars);
+ recurseWrite(parametered);
+ }
+
+ @Override
+ public String toString() {
+ return sb.toString();
+ }
+
+ private void recurseCount(Parametered parametered) {
+ for (Parameter<?> parameter : parametered.getParameters()) {
+ int parameterNameLength = parameter.prefix().length() + parameter.name().length();
+ if (parameterNameLength > longestName) {
+ longestName = parameterNameLength;
+ }
+
+ numChars += parameterNameLength;
+ numChars += 5; // # $0\n$1 = $2\n\n
+ numChars += parameter.description().length();
+ if (parameter.getStringValue() != null) {
+ numChars += parameter.getStringValue().length();
+ }
+
+ recurseCount(parameter);
+ }
+ }
+
+ private void recurseWrite(Parametered parametered) {
+ for (Parameter<?> parameter : parametered.getParameters()) {
+ sb.append("# ");
+ sb.append(parameter.description());
+ sb.append('\n');
+ sb.append(parameter.prefix());
+ sb.append(parameter.name());
+ sb.append(" = ");
+ if (parameter.getStringValue() != null) {
+ sb.append(parameter.getStringValue());
+ }
+ sb.append('\n');
+ sb.append('\n');
+ recurseWrite(parameter);
+ }
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java b/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java
new file mode 100644
index 0000000..a617fe3
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/common/parameters/PathParameter.java
@@ -0,0 +1,33 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.common.parameters;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+
+public class PathParameter extends AbstractParameter<Path> {
+
+ public PathParameter(String prefix, String name, Configuration jobConf, Path defaultValue, String description) {
+ super(Path.class, prefix, name, jobConf, defaultValue, description);
+ }
+
+ @Override
+ public void setStringValue(String stringValue) {
+ set(new Path(stringValue));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java b/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java
new file mode 100644
index 0000000..1fd5506
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/driver/MahoutDriver.java
@@ -0,0 +1,244 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.driver;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.util.ProgramDriver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * General-purpose driver class for Mahout programs. Utilizes org.apache.hadoop.util.ProgramDriver to run
+ * main methods of other classes, but first loads up default properties from a properties file.
+ * <p/>
+ * To run locally:
+ *
+ * <pre>$MAHOUT_HOME/bin/mahout run shortJobName [over-ride ops]</pre>
+ * <p/>
+ * Works like this: by default, the file "driver.classes.props" is loaded from the classpath, which
+ * defines a mapping between short names like "vectordump" and fully qualified class names.
+ * The format of driver.classes.props is like so:
+ * <p/>
+ *
+ * <pre>fully.qualified.class.name = shortJobName : descriptive string</pre>
+ * <p/>
+ * The default properties to be applied to the program run is pulled out of, by default, "<shortJobName>.props"
+ * (also off of the classpath).
+ * <p/>
+ * The format of the default properties files is as follows:
+ * <pre>
+ i|input = /path/to/my/input
+ o|output = /path/to/my/output
+ m|jarFile = /path/to/jarFile
+ # etc - each line is shortArg|longArg = value
+ </pre>
+ *
+ * The next argument to the Driver is supposed to be the short name of the class to be run (as defined in the
+ * driver.classes.props file).
+ * <p/>
+ * Then the class which will be run will have it's main called with
+ *
+ * <pre>main(new String[] { "--input", "/path/to/my/input", "--output", "/path/to/my/output" });</pre>
+ *
+ * After all the "default" properties are loaded from the file, any further command-line arguments are taken in,
+ * and over-ride the defaults.
+ * <p/>
+ * So if your driver.classes.props looks like so:
+ *
+ * <pre>org.apache.mahout.utils.vectors.VectorDumper = vecDump : dump vectors from a sequence file</pre>
+ *
+ * and you have a file core/src/main/resources/vecDump.props which looks like
+ * <pre>
+ o|output = /tmp/vectorOut
+ s|seqFile = /my/vector/sequenceFile
+ </pre>
+ *
+ * And you execute the command-line:
+ *
+ * <pre>$MAHOUT_HOME/bin/mahout run vecDump -s /my/otherVector/sequenceFile</pre>
+ *
+ * Then org.apache.mahout.utils.vectors.VectorDumper.main() will be called with arguments:
+ * <pre>{"--output", "/tmp/vectorOut", "-s", "/my/otherVector/sequenceFile"}</pre>
+ */
+public final class MahoutDriver {
+
+ private static final Logger log = LoggerFactory.getLogger(MahoutDriver.class);
+
+ private MahoutDriver() {
+ }
+
+ public static void main(String[] args) throws Throwable {
+
+ Properties mainClasses = loadProperties("driver.classes.props");
+ if (mainClasses == null) {
+ mainClasses = loadProperties("driver.classes.default.props");
+ }
+ if (mainClasses == null) {
+ throw new IOException("Can't load any properties file?");
+ }
+
+ boolean foundShortName = false;
+ ProgramDriver programDriver = new ProgramDriver();
+ for (Object key : mainClasses.keySet()) {
+ String keyString = (String) key;
+ if (args.length > 0 && shortName(mainClasses.getProperty(keyString)).equals(args[0])) {
+ foundShortName = true;
+ }
+ if (args.length > 0 && keyString.equalsIgnoreCase(args[0]) && isDeprecated(mainClasses, keyString)) {
+ log.error(desc(mainClasses.getProperty(keyString)));
+ return;
+ }
+ if (isDeprecated(mainClasses, keyString)) {
+ continue;
+ }
+ addClass(programDriver, keyString, mainClasses.getProperty(keyString));
+ }
+
+ if (args.length < 1 || args[0] == null || "-h".equals(args[0]) || "--help".equals(args[0])) {
+ programDriver.driver(args);
+ return;
+ }
+
+ String progName = args[0];
+ if (!foundShortName) {
+ addClass(programDriver, progName, progName);
+ }
+ shift(args);
+
+ Properties mainProps = loadProperties(progName + ".props");
+ if (mainProps == null) {
+ log.warn("No {}.props found on classpath, will use command-line arguments only", progName);
+ mainProps = new Properties();
+ }
+
+ Map<String,String[]> argMap = Maps.newHashMap();
+ int i = 0;
+ while (i < args.length && args[i] != null) {
+ List<String> argValues = Lists.newArrayList();
+ String arg = args[i];
+ i++;
+ if (arg.startsWith("-D")) { // '-Dkey=value' or '-Dkey=value1,value2,etc' case
+ String[] argSplit = arg.split("=");
+ arg = argSplit[0];
+ if (argSplit.length == 2) {
+ argValues.add(argSplit[1]);
+ }
+ } else { // '-key [values]' or '--key [values]' case.
+ while (i < args.length && args[i] != null) {
+ if (args[i].startsWith("-")) {
+ break;
+ }
+ argValues.add(args[i]);
+ i++;
+ }
+ }
+ argMap.put(arg, argValues.toArray(new String[argValues.size()]));
+ }
+
+ // Add properties from the .props file that are not overridden on the command line
+ for (String key : mainProps.stringPropertyNames()) {
+ String[] argNamePair = key.split("\\|");
+ String shortArg = '-' + argNamePair[0].trim();
+ String longArg = argNamePair.length < 2 ? null : "--" + argNamePair[1].trim();
+ if (!argMap.containsKey(shortArg) && (longArg == null || !argMap.containsKey(longArg))) {
+ argMap.put(longArg, new String[] {mainProps.getProperty(key)});
+ }
+ }
+
+ // Now add command-line args
+ List<String> argsList = Lists.newArrayList();
+ argsList.add(progName);
+ for (Map.Entry<String,String[]> entry : argMap.entrySet()) {
+ String arg = entry.getKey();
+ if (arg.startsWith("-D")) { // arg is -Dkey - if value for this !isEmpty(), then arg -> -Dkey + "=" + value
+ String[] argValues = entry.getValue();
+ if (argValues.length > 0 && !argValues[0].trim().isEmpty()) {
+ arg += '=' + argValues[0].trim();
+ }
+ argsList.add(1, arg);
+ } else {
+ argsList.add(arg);
+ for (String argValue : Arrays.asList(argMap.get(arg))) {
+ if (!argValue.isEmpty()) {
+ argsList.add(argValue);
+ }
+ }
+ }
+ }
+
+ long start = System.currentTimeMillis();
+
+ programDriver.driver(argsList.toArray(new String[argsList.size()]));
+
+ if (log.isInfoEnabled()) {
+ log.info("Program took {} ms (Minutes: {})", System.currentTimeMillis() - start,
+ (System.currentTimeMillis() - start) / 60000.0);
+ }
+ }
+
+ private static boolean isDeprecated(Properties mainClasses, String keyString) {
+ return "deprecated".equalsIgnoreCase(shortName(mainClasses.getProperty(keyString)));
+ }
+
+ private static Properties loadProperties(String resource) throws IOException {
+ InputStream propsStream = Thread.currentThread().getContextClassLoader().getResourceAsStream(resource);
+ if (propsStream != null) {
+ try {
+ Properties properties = new Properties();
+ properties.load(propsStream);
+ return properties;
+ } finally {
+ Closeables.close(propsStream, true);
+ }
+ }
+ return null;
+ }
+
+ private static String[] shift(String[] args) {
+ System.arraycopy(args, 1, args, 0, args.length - 1);
+ args[args.length - 1] = null;
+ return args;
+ }
+
+ private static String shortName(String valueString) {
+ return valueString.contains(":") ? valueString.substring(0, valueString.indexOf(':')).trim() : valueString;
+ }
+
+ private static String desc(String valueString) {
+ return valueString.contains(":") ? valueString.substring(valueString.indexOf(':')).trim() : valueString;
+ }
+
+ private static void addClass(ProgramDriver driver, String classString, String descString) {
+ try {
+ Class<?> clazz = Class.forName(classString);
+ driver.addClass(shortName(descString), clazz, desc(descString));
+ } catch (Throwable t) {
+ log.warn("Unable to add class: {}", classString, t);
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java b/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
new file mode 100644
index 0000000..b744287
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/ep/EvolutionaryProcess.java
@@ -0,0 +1,228 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+
+import java.io.Closeable;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Allows evolutionary optimization where the state function can't be easily
+ * packaged for the optimizer to execute. A good example of this is with
+ * on-line learning where optimizing the learning parameters is desirable.
+ * We would like to pass training examples to the learning algorithms, but
+ * we definitely want to do the training in multiple threads and then after
+ * several training steps, we want to do a selection and mutation step.
+ *
+ * In such a case, it is highly desirable to leave most of the control flow
+ * in the hands of our caller. As such, this class provides three functions,
+ * <ul>
+ * <li> Storage of the evolutionary state. The state variables have payloads
+ * which can be anything that implements Payload.
+ * <li> Threaded execution of a single operation on each of the members of the
+ * population being evolved. In the on-line learning example, this is used for
+ * training all of the classifiers in the population.
+ * <li> Propagating mutations of the most successful members of the population.
+ * This propagation involves copying the state and the payload and then updating
+ * the payload after mutation of the evolutionary state.
+ * </ul>
+ *
+ * The State class that we use for storing the state of each member of the
+ * population also provides parameter mapping. Check out Mapping and State
+ * for more info.
+ *
+ * @see Mapping
+ * @see Payload
+ * @see State
+ *
+ * @param <T> The payload class.
+ */
+public class EvolutionaryProcess<T extends Payload<U>, U> implements Writable, Closeable {
+ // used to execute operations on the population in thread parallel.
+ private ExecutorService pool;
+
+ // threadCount is serialized so that we can reconstruct the thread pool
+ private int threadCount;
+
+ // list of members of the population
+ private List<State<T, U>> population;
+
+ // how big should the population be. If this is changed, it will take effect
+ // the next time the population is mutated.
+
+ private int populationSize;
+
+ public EvolutionaryProcess() {
+ population = Lists.newArrayList();
+ }
+
+ /**
+ * Creates an evolutionary optimization framework with specified threadiness,
+ * population size and initial state.
+ * @param threadCount How many threads to use in parallelDo
+ * @param populationSize How large a population to use
+ * @param seed An initial population member
+ */
+ public EvolutionaryProcess(int threadCount, int populationSize, State<T, U> seed) {
+ this.populationSize = populationSize;
+ setThreadCount(threadCount);
+ initializePopulation(populationSize, seed);
+ }
+
+ private void initializePopulation(int populationSize, State<T, U> seed) {
+ population = Lists.newArrayList(seed);
+ for (int i = 0; i < populationSize; i++) {
+ population.add(seed.mutate());
+ }
+ }
+
+ public void add(State<T, U> value) {
+ population.add(value);
+ }
+
+ /**
+ * Nuke all but a few of the current population and then repopulate with
+ * variants of the survivors.
+ * @param survivors How many survivors we want to keep.
+ */
+ public void mutatePopulation(int survivors) {
+ // largest value first, oldest first in case of ties
+ Collections.sort(population);
+
+ // we copy here to avoid concurrent modification
+ List<State<T, U>> parents = Lists.newArrayList(population.subList(0, survivors));
+ population.subList(survivors, population.size()).clear();
+
+ // fill out the population with offspring from the survivors
+ int i = 0;
+ while (population.size() < populationSize) {
+ population.add(parents.get(i % survivors).mutate());
+ i++;
+ }
+ }
+
+ /**
+ * Execute an operation on all of the members of the population with many threads. The
+ * return value is taken as the current fitness of the corresponding member.
+ * @param fn What to do on each member. Gets payload and the mapped parameters as args.
+ * @return The member of the population with the best fitness.
+ * @throws InterruptedException Shouldn't happen.
+ * @throws ExecutionException If fn throws an exception, that exception will be collected
+ * and rethrown nested in an ExecutionException.
+ */
+ public State<T, U> parallelDo(final Function<Payload<U>> fn) throws InterruptedException, ExecutionException {
+ Collection<Callable<State<T, U>>> tasks = Lists.newArrayList();
+ for (final State<T, U> state : population) {
+ tasks.add(new Callable<State<T, U>>() {
+ @Override
+ public State<T, U> call() {
+ double v = fn.apply(state.getPayload(), state.getMappedParams());
+ state.setValue(v);
+ return state;
+ }
+ });
+ }
+
+ List<Future<State<T, U>>> r = pool.invokeAll(tasks);
+
+ // zip through the results and find the best one
+ double max = Double.NEGATIVE_INFINITY;
+ State<T, U> best = null;
+ for (Future<State<T, U>> future : r) {
+ State<T, U> s = future.get();
+ double value = s.getValue();
+ if (!Double.isNaN(value) && value >= max) {
+ max = value;
+ best = s;
+ }
+ }
+ if (best == null) {
+ best = r.get(0).get();
+ }
+
+ return best;
+ }
+
+ public void setThreadCount(int threadCount) {
+ this.threadCount = threadCount;
+ pool = Executors.newFixedThreadPool(threadCount);
+ }
+
+ public int getThreadCount() {
+ return threadCount;
+ }
+
+ public int getPopulationSize() {
+ return populationSize;
+ }
+
+ public List<State<T, U>> getPopulation() {
+ return population;
+ }
+
+ @Override
+ public void close() {
+ List<Runnable> remainingTasks = pool.shutdownNow();
+ try {
+ pool.awaitTermination(10, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new IllegalStateException("Had to forcefully shut down " + remainingTasks.size() + " tasks");
+ }
+ if (!remainingTasks.isEmpty()) {
+ throw new IllegalStateException("Had to forcefully shut down " + remainingTasks.size() + " tasks");
+ }
+ }
+
+ public interface Function<T> {
+ double apply(T payload, double[] params);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(threadCount);
+ out.writeInt(population.size());
+ for (State<T, U> state : population) {
+ PolymorphicWritable.write(out, state);
+ }
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ setThreadCount(input.readInt());
+ int n = input.readInt();
+ population = Lists.newArrayList();
+ for (int i = 0; i < n; i++) {
+ State<T, U> state = (State<T, U>) PolymorphicWritable.read(input, State.class);
+ population.add(state);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/Mapping.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/ep/Mapping.java b/mr/src/main/java/org/apache/mahout/ep/Mapping.java
new file mode 100644
index 0000000..41a8942
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/ep/Mapping.java
@@ -0,0 +1,206 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import com.google.common.base.Preconditions;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+import org.apache.mahout.math.function.DoubleFunction;
+
+/**
+ * Provides coordinate tranformations so that evolution can proceed on the entire space of
+ * reals but have the output limited and squished in convenient (and safe) ways.
+ */
+public abstract class Mapping extends DoubleFunction implements Writable {
+
+ private Mapping() {
+ }
+
+ public static final class SoftLimit extends Mapping {
+ private double min;
+ private double max;
+ private double scale;
+
+ public SoftLimit() {
+ }
+
+ private SoftLimit(double min, double max, double scale) {
+ this.min = min;
+ this.max = max;
+ this.scale = scale;
+ }
+
+ @Override
+ public double apply(double v) {
+ return min + (max - min) * 1 / (1 + Math.exp(-v * scale));
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(min);
+ out.writeDouble(max);
+ out.writeDouble(scale);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ min = in.readDouble();
+ max = in.readDouble();
+ scale = in.readDouble();
+ }
+ }
+
+ public static final class LogLimit extends Mapping {
+ private Mapping wrapped;
+
+ public LogLimit() {
+ }
+
+ private LogLimit(double low, double high) {
+ wrapped = softLimit(Math.log(low), Math.log(high));
+ }
+
+ @Override
+ public double apply(double v) {
+ return Math.exp(wrapped.apply(v));
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) throws IOException {
+ PolymorphicWritable.write(dataOutput, wrapped);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ wrapped = PolymorphicWritable.read(in, Mapping.class);
+ }
+ }
+
+ public static final class Exponential extends Mapping {
+ private double scale;
+
+ public Exponential() {
+ }
+
+ private Exponential(double scale) {
+ this.scale = scale;
+ }
+
+ @Override
+ public double apply(double v) {
+ return Math.exp(v * scale);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(scale);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ scale = in.readDouble();
+ }
+ }
+
+ public static final class Identity extends Mapping {
+ @Override
+ public double apply(double v) {
+ return v;
+ }
+
+ @Override
+ public void write(DataOutput dataOutput) {
+ // stateless
+ }
+
+ @Override
+ public void readFields(DataInput dataInput) {
+ // stateless
+ }
+ }
+
+ /**
+ * Maps input to the open interval (min, max) with 0 going to the mean of min and
+ * max. When scale is large, a larger proportion of values are mapped to points
+ * near the boundaries. When scale is small, a larger proportion of values are mapped to
+ * points well within the boundaries.
+ * @param min The largest lower bound on values to be returned.
+ * @param max The least upper bound on values to be returned.
+ * @param scale Defines how sharp the boundaries are.
+ * @return A mapping that satisfies the desired constraint.
+ */
+ public static Mapping softLimit(double min, double max, double scale) {
+ return new SoftLimit(min, max, scale);
+ }
+
+ /**
+ * Maps input to the open interval (min, max) with 0 going to the mean of min and
+ * max. When scale is large, a larger proportion of values are mapped to points
+ * near the boundaries.
+ * @see #softLimit(double, double, double)
+ * @param min The largest lower bound on values to be returned.
+ * @param max The least upper bound on values to be returned.
+ * @return A mapping that satisfies the desired constraint.
+ */
+ public static Mapping softLimit(double min, double max) {
+ return softLimit(min, max, 1);
+ }
+
+ /**
+ * Maps input to positive values in the open interval (min, max) with
+ * 0 going to the geometric mean. Near the geometric mean, values are
+ * distributed roughly geometrically.
+ * @param low The largest lower bound for output results. Must be >0.
+ * @param high The least upper bound for output results. Must be >0.
+ * @return A mapped value.
+ */
+ public static Mapping logLimit(double low, double high) {
+ Preconditions.checkArgument(low > 0, "Lower bound for log limit must be > 0 but was %f", low);
+ Preconditions.checkArgument(high > 0, "Upper bound for log limit must be > 0 but was %f", high);
+ return new LogLimit(low, high);
+ }
+
+ /**
+ * Maps results to positive values.
+ * @return A positive value.
+ */
+ public static Mapping exponential() {
+ return exponential(1);
+ }
+
+ /**
+ * Maps results to positive values.
+ * @param scale If large, then large values are more likely.
+ * @return A positive value.
+ */
+ public static Mapping exponential(double scale) {
+ return new Exponential(scale);
+ }
+
+ /**
+ * Maps results to themselves.
+ * @return The original value.
+ */
+ public static Mapping identity() {
+ return new Identity();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/Payload.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/ep/Payload.java b/mr/src/main/java/org/apache/mahout/ep/Payload.java
new file mode 100644
index 0000000..920237d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/ep/Payload.java
@@ -0,0 +1,36 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import org.apache.hadoop.io.Writable;
+
+/**
+ * Payloads for evolutionary state must be copyable and updatable. The copy should be a deep copy
+ * unless some aspect of the state is sharable or immutable.
+ * <p/>
+ * During mutation, a copy is first made and then after the parameters in the State structure are
+ * suitably modified, update is called with the scaled versions of the parameters.
+ *
+ * @param <T>
+ * @see State
+ */
+public interface Payload<T> extends Writable {
+ Payload<T> copy();
+
+ void update(double[] params);
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/State.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/ep/State.java b/mr/src/main/java/org/apache/mahout/ep/State.java
new file mode 100644
index 0000000..7a0fb5e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/ep/State.java
@@ -0,0 +1,302 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.ep;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+import org.apache.mahout.common.RandomUtils;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Locale;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Records evolutionary state and provides a mutation operation for recorded-step meta-mutation.
+ *
+ * You provide the payload, this class provides the mutation operations. During mutation,
+ * the payload is copied and after the state variables are changed, they are passed to the
+ * payload.
+ *
+ * Parameters are internally mutated in a state space that spans all of R^n, but parameters
+ * passed to the payload are transformed as specified by a call to setMap(). The default
+ * mapping is the identity map, but uniform-ish or exponential-ish coverage of a range are
+ * also supported.
+ *
+ * More information on the underlying algorithm can be found in the following paper
+ *
+ * http://arxiv.org/abs/0803.3838
+ *
+ * @see Mapping
+ */
+public class State<T extends Payload<U>, U> implements Comparable<State<T, U>>, Writable {
+
+ // object count is kept to break ties in comparison.
+ private static final AtomicInteger OBJECT_COUNT = new AtomicInteger();
+
+ private int id = OBJECT_COUNT.getAndIncrement();
+ private Random gen = RandomUtils.getRandom();
+ // current state
+ private double[] params;
+ // mappers to transform state
+ private Mapping[] maps;
+ // omni-directional mutation
+ private double omni;
+ // directional mutation
+ private double[] step;
+ // current fitness value
+ private double value;
+ private T payload;
+
+ public State() {
+ }
+
+ /**
+ * Invent a new state with no momentum (yet).
+ */
+ public State(double[] x0, double omni) {
+ params = Arrays.copyOf(x0, x0.length);
+ this.omni = omni;
+ step = new double[params.length];
+ maps = new Mapping[params.length];
+ }
+
+ /**
+ * Deep copies a state, useful in mutation.
+ */
+ public State<T, U> copy() {
+ State<T, U> r = new State<>();
+ r.params = Arrays.copyOf(this.params, this.params.length);
+ r.omni = this.omni;
+ r.step = Arrays.copyOf(this.step, this.step.length);
+ r.maps = Arrays.copyOf(this.maps, this.maps.length);
+ if (this.payload != null) {
+ r.payload = (T) this.payload.copy();
+ }
+ r.gen = this.gen;
+ return r;
+ }
+
+ /**
+ * Clones this state with a random change in position. Copies the payload and
+ * lets it know about the change.
+ *
+ * @return A new state.
+ */
+ public State<T, U> mutate() {
+ double sum = 0;
+ for (double v : step) {
+ sum += v * v;
+ }
+ sum = Math.sqrt(sum);
+ double lambda = 1 + gen.nextGaussian();
+
+ State<T, U> r = this.copy();
+ double magnitude = 0.9 * omni + sum / 10;
+ r.omni = magnitude * -Math.log1p(-gen.nextDouble());
+ for (int i = 0; i < step.length; i++) {
+ r.step[i] = lambda * step[i] + r.omni * gen.nextGaussian();
+ r.params[i] += r.step[i];
+ }
+ if (this.payload != null) {
+ r.payload.update(r.getMappedParams());
+ }
+ return r;
+ }
+
+ /**
+ * Defines the transformation for a parameter.
+ * @param i Which parameter's mapping to define.
+ * @param m The mapping to use.
+ * @see org.apache.mahout.ep.Mapping
+ */
+ public void setMap(int i, Mapping m) {
+ maps[i] = m;
+ }
+
+ /**
+ * Returns a transformed parameter.
+ * @param i The parameter to return.
+ * @return The value of the parameter.
+ */
+ public double get(int i) {
+ Mapping m = maps[i];
+ return m == null ? params[i] : m.apply(params[i]);
+ }
+
+ public int getId() {
+ return id;
+ }
+
+ public double[] getParams() {
+ return params;
+ }
+
+ public Mapping[] getMaps() {
+ return maps;
+ }
+
+ /**
+ * Returns all the parameters in mapped form.
+ * @return An array of parameters.
+ */
+ public double[] getMappedParams() {
+ double[] r = Arrays.copyOf(params, params.length);
+ for (int i = 0; i < params.length; i++) {
+ r[i] = get(i);
+ }
+ return r;
+ }
+
+ public double getOmni() {
+ return omni;
+ }
+
+ public double[] getStep() {
+ return step;
+ }
+
+ public T getPayload() {
+ return payload;
+ }
+
+ public double getValue() {
+ return value;
+ }
+
+ public void setOmni(double omni) {
+ this.omni = omni;
+ }
+
+ public void setId(int id) {
+ this.id = id;
+ }
+
+ public void setStep(double[] step) {
+ this.step = step;
+ }
+
+ public void setMaps(Mapping[] maps) {
+ this.maps = maps;
+ }
+
+ public void setMaps(Iterable<Mapping> maps) {
+ Collection<Mapping> list = Lists.newArrayList(maps);
+ this.maps = list.toArray(new Mapping[list.size()]);
+ }
+
+ public void setValue(double v) {
+ value = v;
+ }
+
+ public void setPayload(T payload) {
+ this.payload = payload;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (!(o instanceof State)) {
+ return false;
+ }
+ State<?,?> other = (State<?,?>) o;
+ return id == other.id && value == other.value;
+ }
+
+ @Override
+ public int hashCode() {
+ return RandomUtils.hashDouble(value) ^ id;
+ }
+
+ /**
+ * Natural order is to sort in descending order of score. Creation order is used as a
+ * tie-breaker.
+ *
+ * @param other The state to compare with.
+ * @return -1, 0, 1 if the other state is better, identical or worse than this one.
+ */
+ @Override
+ public int compareTo(State<T, U> other) {
+ int r = Double.compare(other.value, this.value);
+ if (r != 0) {
+ return r;
+ }
+ if (this.id < other.id) {
+ return -1;
+ }
+ if (this.id > other.id) {
+ return 1;
+ }
+ return 0;
+ }
+
+ @Override
+ public String toString() {
+ double sum = 0;
+ for (double v : step) {
+ sum += v * v;
+ }
+ return String.format(Locale.ENGLISH, "<S/%s %.3f %.3f>", payload, omni + Math.sqrt(sum), value);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(id);
+ out.writeInt(params.length);
+ for (double v : params) {
+ out.writeDouble(v);
+ }
+ for (Mapping map : maps) {
+ PolymorphicWritable.write(out, map);
+ }
+
+ out.writeDouble(omni);
+ for (double v : step) {
+ out.writeDouble(v);
+ }
+
+ out.writeDouble(value);
+ PolymorphicWritable.write(out, payload);
+ }
+
+ @Override
+ public void readFields(DataInput input) throws IOException {
+ id = input.readInt();
+ int n = input.readInt();
+ params = new double[n];
+ for (int i = 0; i < n; i++) {
+ params[i] = input.readDouble();
+ }
+
+ maps = new Mapping[n];
+ for (int i = 0; i < n; i++) {
+ maps[i] = PolymorphicWritable.read(input, Mapping.class);
+ }
+ omni = input.readDouble();
+ step = new double[n];
+ for (int i = 0; i < n; i++) {
+ step[i] = input.readDouble();
+ }
+ value = input.readDouble();
+ payload = (T) PolymorphicWritable.read(input, Payload.class);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/ep/package-info.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/ep/package-info.java b/mr/src/main/java/org/apache/mahout/ep/package-info.java
new file mode 100644
index 0000000..4afe677
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/ep/package-info.java
@@ -0,0 +1,26 @@
+/**
+ * <p>Provides basic evolutionary optimization using <a href="http://arxiv.org/abs/0803.3838">recorded-step</a>
+ * mutation.</p>
+ *
+ * <p>With this style of optimization, we can optimize a function {@code f: R^n -> R} by stochastic
+ * hill-climbing with some of the benefits of conjugate gradient style history encoded in the mutation function.
+ * This mutation function will adapt to allow weakly directed search rather than using the somewhat more
+ * conventional symmetric Gaussian.</p>
+ *
+ * <p>With recorded-step mutation, the meta-mutation parameters are all auto-encoded in the current state of each point.
+ * This avoids the classic problem of having more mutation rate parameters than are in the original state and then
+ * requiring even more parameters to describe the meta-mutation rate. Instead, we store the previous point and one
+ * omni-directional mutation component. Mutation is performed by first mutating along the line formed by the previous
+ * and current points and then adding a scaled symmetric Gaussian. The magnitude of the omni-directional mutation is
+ * then mutated using itself as a scale.</p>
+ *
+ * <p>Because it is convenient to not restrict the parameter space, this package also provides convenient parameter
+ * mapping methods. These mapping methods map the set of reals to a finite open interval (a,b) in such a way that
+ * {@code lim_{x->-\inf} f(x) = a} and {@code lim_{x->\inf} f(x) = b}. The linear mapping is defined so that
+ * {@code f(0) = (a+b)/2} and the exponential mapping requires that a and b are both positive and has
+ * {@code f(0) = sqrt(ab)}. The linear mapping is useful for values that must stay roughly within a range but
+ * which are roughly uniform within the center of that range. The exponential
+ * mapping is useful for values that must stay within a range but whose distribution is roughly exponential near
+ * geometric mean of the end-points. An identity mapping is also supplied.</p>
+ */
+package org.apache.mahout.ep;
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java b/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java
new file mode 100644
index 0000000..6618a1a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/DistributedRowMatrixWriter.java
@@ -0,0 +1,47 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.math;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+
+import java.io.IOException;
+
+public final class DistributedRowMatrixWriter {
+
+ private DistributedRowMatrixWriter() {
+ }
+
+ public static void write(Path outputDir, Configuration conf, Iterable<MatrixSlice> matrix) throws IOException {
+ FileSystem fs = outputDir.getFileSystem(conf);
+ SequenceFile.Writer writer = SequenceFile.createWriter(fs, conf, outputDir,
+ IntWritable.class, VectorWritable.class);
+ IntWritable topic = new IntWritable();
+ VectorWritable vector = new VectorWritable();
+ for (MatrixSlice slice : matrix) {
+ topic.set(slice.index());
+ vector.set(slice.vector());
+ writer.append(topic, vector);
+ }
+ writer.close();
+
+ }
+
+}
[51/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
MAHOUT-1655 Refactors mr-legacy into mahout-hdfs and mahout-mr, closes apache/mahout#86
Project: http://git-wip-us.apache.org/repos/asf/mahout/repo
Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/b988c493
Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/b988c493
Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/b988c493
Branch: refs/heads/master
Commit: b988c493b562ceeaa5f82027f108c67d06c1fc19
Parents: 0853c06
Author: pferrel <pa...@occamsmachete.com>
Authored: Wed Apr 1 11:06:30 2015 -0700
Committer: pferrel <pa...@occamsmachete.com>
Committed: Wed Apr 1 11:06:30 2015 -0700
----------------------------------------------------------------------
CHANGELOG | 2 +
bin/mahout | 8 +-
distribution/pom.xml | 6 +-
distribution/src/main/assembly/bin.xml | 23 +-
examples/pom.xml | 14 +-
.../classifier/df/mapreduce/TestForest.java | 2 +-
h2o/pom.xml | 4 +-
hdfs/pom.xml | 216 +
.../java/org/apache/mahout/common/IOUtils.java | 194 +
.../org/apache/mahout/math/MatrixWritable.java | 202 +
.../org/apache/mahout/math/VarIntWritable.java | 86 +
.../org/apache/mahout/math/VarLongWritable.java | 83 +
.../java/org/apache/mahout/math/Varint.java | 167 +
.../org/apache/mahout/math/VectorWritable.java | 267 +
.../apache/mahout/math/MatrixWritableTest.java | 148 +
.../java/org/apache/mahout/math/VarintTest.java | 189 +
.../apache/mahout/math/VectorWritableTest.java | 123 +
integration/pom.xml | 14 +-
.../mahout/text/PrefixAdditionFilter.java | 2 +-
.../text/ReadOnlyFileSystemDirectory.java | 4 +-
.../apache/mahout/utils/SequenceFileDumper.java | 2 +-
.../org/apache/mahout/utils/SplitInput.java | 8 +-
.../utils/clustering/JsonClusterWriter.java | 1 +
.../mahout/utils/vectors/VectorDumper.java | 2 +-
.../vectors/lucene/LuceneIterableTest.java | 6 +-
mr/pom.xml | 249 +
mr/src/main/assembly/job.xml | 61 +
mr/src/main/java/org/apache/mahout/Version.java | 41 +
.../cf/taste/common/NoSuchItemException.java | 32 +
.../cf/taste/common/NoSuchUserException.java | 32 +
.../mahout/cf/taste/common/Refreshable.java | 53 +
.../mahout/cf/taste/common/TasteException.java | 41 +
.../mahout/cf/taste/common/Weighting.java | 31 +
.../mahout/cf/taste/eval/DataModelBuilder.java | 45 +
.../mahout/cf/taste/eval/IRStatistics.java | 80 +
.../cf/taste/eval/RecommenderBuilder.java | 45 +
.../cf/taste/eval/RecommenderEvaluator.java | 105 +
.../taste/eval/RecommenderIRStatsEvaluator.java | 64 +
.../taste/eval/RelevantItemsDataSplitter.java | 62 +
.../cf/taste/hadoop/EntityEntityWritable.java | 98 +
.../cf/taste/hadoop/EntityPrefWritable.java | 89 +
.../cf/taste/hadoop/MutableRecommendedItem.java | 81 +
.../taste/hadoop/RecommendedItemsWritable.java | 96 +
.../cf/taste/hadoop/TasteHadoopUtils.java | 84 +
.../cf/taste/hadoop/ToEntityPrefsMapper.java | 78 +
.../cf/taste/hadoop/ToItemPrefsMapper.java | 46 +
.../mahout/cf/taste/hadoop/TopItemsQueue.java | 60 +
.../apache/mahout/cf/taste/hadoop/als/ALS.java | 107 +
.../cf/taste/hadoop/als/DatasetSplitter.java | 158 +
.../hadoop/als/FactorizationEvaluator.java | 172 +
.../hadoop/als/MultithreadedSharingMapper.java | 62 +
.../hadoop/als/ParallelALSFactorizationJob.java | 419 +
.../cf/taste/hadoop/als/PredictionMapper.java | 145 +
.../cf/taste/hadoop/als/RecommenderJob.java | 110 +
.../cf/taste/hadoop/als/SharingMapper.java | 59 +
.../hadoop/als/SolveExplicitFeedbackMapper.java | 61 +
.../hadoop/als/SolveImplicitFeedbackMapper.java | 58 +
.../item/AggregateAndRecommendReducer.java | 220 +
.../mahout/cf/taste/hadoop/item/IDReader.java | 250 +
.../item/ItemFilterAsVectorAndPrefsReducer.java | 62 +
.../cf/taste/hadoop/item/ItemFilterMapper.java | 47 +
.../cf/taste/hadoop/item/ItemIDIndexMapper.java | 56 +
.../taste/hadoop/item/ItemIDIndexReducer.java | 48 +
.../hadoop/item/PartialMultiplyMapper.java | 57 +
.../item/PrefAndSimilarityColumnWritable.java | 85 +
.../cf/taste/hadoop/item/RecommenderJob.java | 337 +
.../item/SimilarityMatrixRowWrapperMapper.java | 54 +
.../taste/hadoop/item/ToUserVectorsReducer.java | 84 +
.../hadoop/item/ToVectorAndPrefReducer.java | 63 +
.../hadoop/item/UserVectorSplitterMapper.java | 116 +
.../hadoop/item/VectorAndPrefsWritable.java | 92 +
.../taste/hadoop/item/VectorOrPrefWritable.java | 104 +
.../preparation/PreparePreferenceMatrixJob.java | 115 +
.../hadoop/preparation/ToItemVectorsMapper.java | 56 +
.../preparation/ToItemVectorsReducer.java | 38 +
.../similarity/item/ItemSimilarityJob.java | 233 +
.../similarity/item/TopSimilarItemsQueue.java | 60 +
.../common/AbstractLongPrimitiveIterator.java | 27 +
.../mahout/cf/taste/impl/common/BitSet.java | 93 +
.../mahout/cf/taste/impl/common/Cache.java | 178 +
.../cf/taste/impl/common/FastByIDMap.java | 661 +
.../mahout/cf/taste/impl/common/FastIDSet.java | 426 +
.../mahout/cf/taste/impl/common/FastMap.java | 729 +
.../taste/impl/common/FixedRunningAverage.java | 83 +
.../common/FixedRunningAverageAndStdDev.java | 51 +
.../taste/impl/common/FullRunningAverage.java | 109 +
.../common/FullRunningAverageAndStdDev.java | 107 +
.../impl/common/InvertedRunningAverage.java | 58 +
.../common/InvertedRunningAverageAndStdDev.java | 63 +
.../impl/common/LongPrimitiveArrayIterator.java | 93 +
.../impl/common/LongPrimitiveIterator.java | 39 +
.../cf/taste/impl/common/RefreshHelper.java | 122 +
.../mahout/cf/taste/impl/common/Retriever.java | 36 +
.../cf/taste/impl/common/RunningAverage.java | 67 +
.../impl/common/RunningAverageAndStdDev.java | 36 +
.../common/SamplingLongPrimitiveIterator.java | 111 +
.../cf/taste/impl/common/SkippingIterator.java | 35 +
.../impl/common/WeightedRunningAverage.java | 100 +
.../common/WeightedRunningAverageAndStdDev.java | 89 +
.../impl/common/jdbc/AbstractJDBCComponent.java | 88 +
.../taste/impl/common/jdbc/EachRowIterator.java | 92 +
.../impl/common/jdbc/ResultSetIterator.java | 66 +
.../AbstractDifferenceRecommenderEvaluator.java | 277 +
...eAbsoluteDifferenceRecommenderEvaluator.java | 59 +
.../GenericRecommenderIRStatsEvaluator.java | 237 +
.../eval/GenericRelevantItemsDataSplitter.java | 83 +
.../cf/taste/impl/eval/IRStatisticsImpl.java | 95 +
.../mahout/cf/taste/impl/eval/LoadCallable.java | 40 +
.../cf/taste/impl/eval/LoadEvaluator.java | 61 +
.../cf/taste/impl/eval/LoadStatistics.java | 34 +
.../eval/OrderBasedRecommenderEvaluator.java | 431 +
.../impl/eval/RMSRecommenderEvaluator.java | 56 +
.../cf/taste/impl/eval/StatsCallable.java | 64 +
.../cf/taste/impl/model/AbstractDataModel.java | 53 +
.../cf/taste/impl/model/AbstractIDMigrator.java | 67 +
.../impl/model/AbstractJDBCIDMigrator.java | 108 +
.../impl/model/BooleanItemPreferenceArray.java | 234 +
.../cf/taste/impl/model/BooleanPreference.java | 64 +
.../impl/model/BooleanUserPreferenceArray.java | 234 +
.../impl/model/GenericBooleanPrefDataModel.java | 320 +
.../cf/taste/impl/model/GenericDataModel.java | 361 +
.../impl/model/GenericItemPreferenceArray.java | 301 +
.../cf/taste/impl/model/GenericPreference.java | 70 +
.../impl/model/GenericUserPreferenceArray.java | 307 +
.../cf/taste/impl/model/MemoryIDMigrator.java | 55 +
.../taste/impl/model/MySQLJDBCIDMigrator.java | 67 +
.../PlusAnonymousConcurrentUserDataModel.java | 352 +
.../impl/model/PlusAnonymousUserDataModel.java | 320 +
.../PlusAnonymousUserLongPrimitiveIterator.java | 90 +
.../cf/taste/impl/model/file/FileDataModel.java | 759 +
.../taste/impl/model/file/FileIDMigrator.java | 117 +
.../neighborhood/AbstractUserNeighborhood.java | 71 +
.../neighborhood/CachingUserNeighborhood.java | 69 +
.../neighborhood/NearestNUserNeighborhood.java | 122 +
.../neighborhood/ThresholdUserNeighborhood.java | 104 +
.../AbstractCandidateItemsStrategy.java | 57 +
.../impl/recommender/AbstractRecommender.java | 140 +
.../AllSimilarItemsCandidateItemsStrategy.java | 50 +
.../AllUnknownItemsCandidateItemsStrategy.java | 41 +
.../impl/recommender/ByRescoreComparator.java | 65 +
.../ByValueRecommendedItemComparator.java | 43 +
.../impl/recommender/CachingRecommender.java | 251 +
.../recommender/EstimatedPreferenceCapper.java | 46 +
.../GenericBooleanPrefItemBasedRecommender.java | 71 +
.../GenericBooleanPrefUserBasedRecommender.java | 82 +
.../GenericItemBasedRecommender.java | 378 +
.../recommender/GenericRecommendedItem.java | 76 +
.../GenericUserBasedRecommender.java | 247 +
.../recommender/ItemAverageRecommender.java | 199 +
.../recommender/ItemUserAverageRecommender.java | 240 +
.../cf/taste/impl/recommender/NullRescorer.java | 86 +
...ItemsNeighborhoodCandidateItemsStrategy.java | 48 +
.../impl/recommender/RandomRecommender.java | 97 +
.../SamplingCandidateItemsStrategy.java | 165 +
.../cf/taste/impl/recommender/SimilarUser.java | 80 +
.../cf/taste/impl/recommender/TopItems.java | 212 +
.../impl/recommender/svd/ALSWRFactorizer.java | 313 +
.../recommender/svd/AbstractFactorizer.java | 94 +
.../impl/recommender/svd/Factorization.java | 137 +
.../taste/impl/recommender/svd/Factorizer.java | 30 +
.../svd/FilePersistenceStrategy.java | 148 +
.../recommender/svd/NoPersistenceStrategy.java | 37 +
.../recommender/svd/ParallelSGDFactorizer.java | 340 +
.../recommender/svd/PersistenceStrategy.java | 46 +
.../recommender/svd/RatingSGDFactorizer.java | 221 +
.../recommender/svd/SVDPlusPlusFactorizer.java | 178 +
.../impl/recommender/svd/SVDPreference.java | 41 +
.../impl/recommender/svd/SVDRecommender.java | 185 +
.../impl/similarity/AbstractItemSimilarity.java | 64 +
.../impl/similarity/AbstractSimilarity.java | 343 +
.../similarity/AveragingPreferenceInferrer.java | 85 +
.../impl/similarity/CachingItemSimilarity.java | 111 +
.../impl/similarity/CachingUserSimilarity.java | 104 +
.../impl/similarity/CityBlockSimilarity.java | 98 +
.../similarity/EuclideanDistanceSimilarity.java | 67 +
.../impl/similarity/GenericItemSimilarity.java | 358 +
.../impl/similarity/GenericUserSimilarity.java | 238 +
.../similarity/LogLikelihoodSimilarity.java | 121 +
.../impl/similarity/LongPairMatchPredicate.java | 40 +
.../PearsonCorrelationSimilarity.java | 93 +
.../SpearmanCorrelationSimilarity.java | 135 +
.../TanimotoCoefficientSimilarity.java | 126 +
.../similarity/UncenteredCosineSimilarity.java | 69 +
.../file/FileItemItemSimilarityIterable.java | 46 +
.../file/FileItemItemSimilarityIterator.java | 60 +
.../similarity/file/FileItemSimilarity.java | 137 +
.../precompute/FileSimilarItemsWriter.java | 67 +
.../MultithreadedBatchItemSimilarities.java | 230 +
.../apache/mahout/cf/taste/model/DataModel.java | 199 +
.../mahout/cf/taste/model/IDMigrator.java | 63 +
.../mahout/cf/taste/model/JDBCDataModel.java | 43 +
.../mahout/cf/taste/model/Preference.java | 48 +
.../mahout/cf/taste/model/PreferenceArray.java | 143 +
.../cf/taste/model/UpdatableIDMigrator.java | 47 +
.../cf/taste/neighborhood/UserNeighborhood.java | 40 +
.../recommender/CandidateItemsStrategy.java | 37 +
.../mahout/cf/taste/recommender/IDRescorer.java | 47 +
.../taste/recommender/ItemBasedRecommender.java | 145 +
.../MostSimilarItemsCandidateItemsStrategy.java | 31 +
.../cf/taste/recommender/RecommendedItem.java | 41 +
.../cf/taste/recommender/Recommender.java | 132 +
.../mahout/cf/taste/recommender/Rescorer.java | 52 +
.../taste/recommender/UserBasedRecommender.java | 54 +
.../cf/taste/similarity/ItemSimilarity.java | 64 +
.../cf/taste/similarity/PreferenceInferrer.java | 47 +
.../cf/taste/similarity/UserSimilarity.java | 58 +
.../precompute/BatchItemSimilarities.java | 56 +
.../similarity/precompute/SimilarItem.java | 56 +
.../similarity/precompute/SimilarItems.java | 84 +
.../precompute/SimilarItemsWriter.java | 33 +
.../classifier/AbstractVectorClassifier.java | 248 +
.../mahout/classifier/ClassifierResult.java | 74 +
.../mahout/classifier/ConfusionMatrix.java | 444 +
.../apache/mahout/classifier/OnlineLearner.java | 96 +
.../classifier/RegressionResultAnalyzer.java | 144 +
.../mahout/classifier/ResultAnalyzer.java | 132 +
.../apache/mahout/classifier/df/Bagging.java | 60 +
.../apache/mahout/classifier/df/DFUtils.java | 181 +
.../mahout/classifier/df/DecisionForest.java | 244 +
.../mahout/classifier/df/ErrorEstimate.java | 50 +
.../df/builder/DecisionTreeBuilder.java | 421 +
.../df/builder/DefaultTreeBuilder.java | 252 +
.../classifier/df/builder/TreeBuilder.java | 41 +
.../apache/mahout/classifier/df/data/Data.java | 280 +
.../classifier/df/data/DataConverter.java | 71 +
.../mahout/classifier/df/data/DataLoader.java | 253 +
.../mahout/classifier/df/data/DataUtils.java | 88 +
.../mahout/classifier/df/data/Dataset.java | 421 +
.../classifier/df/data/DescriptorException.java | 27 +
.../classifier/df/data/DescriptorUtils.java | 109 +
.../mahout/classifier/df/data/Instance.java | 74 +
.../df/data/conditions/Condition.java | 56 +
.../classifier/df/data/conditions/Equals.java | 41 +
.../df/data/conditions/GreaterOrEquals.java | 41 +
.../classifier/df/data/conditions/Lesser.java | 41 +
.../mahout/classifier/df/mapreduce/Builder.java | 332 +
.../classifier/df/mapreduce/Classifier.java | 237 +
.../classifier/df/mapreduce/MapredMapper.java | 74 +
.../classifier/df/mapreduce/MapredOutput.java | 119 +
.../df/mapreduce/inmem/InMemBuilder.java | 113 +
.../df/mapreduce/inmem/InMemInputFormat.java | 283 +
.../df/mapreduce/inmem/InMemMapper.java | 105 +
.../df/mapreduce/inmem/package-info.java | 22 +
.../df/mapreduce/partial/PartialBuilder.java | 157 +
.../df/mapreduce/partial/Step1Mapper.java | 167 +
.../classifier/df/mapreduce/partial/TreeID.java | 57 +
.../df/mapreduce/partial/package-info.java | 16 +
.../classifier/df/node/CategoricalNode.java | 134 +
.../apache/mahout/classifier/df/node/Leaf.java | 94 +
.../apache/mahout/classifier/df/node/Node.java | 95 +
.../classifier/df/node/NumericalNode.java | 114 +
.../classifier/df/ref/SequentialBuilder.java | 77 +
.../classifier/df/split/DefaultIgSplit.java | 117 +
.../mahout/classifier/df/split/IgSplit.java | 34 +
.../mahout/classifier/df/split/OptIgSplit.java | 231 +
.../classifier/df/split/RegressionSplit.java | 176 +
.../mahout/classifier/df/split/Split.java | 67 +
.../mahout/classifier/df/tools/Describe.java | 148 +
.../classifier/df/tools/ForestVisualizer.java | 157 +
.../mahout/classifier/df/tools/Frequencies.java | 121 +
.../classifier/df/tools/FrequenciesJob.java | 296 +
.../classifier/df/tools/TreeVisualizer.java | 263 +
.../mahout/classifier/df/tools/UDistrib.java | 211 +
.../mahout/classifier/evaluation/Auc.java | 233 +
.../classifier/mlp/MultilayerPerceptron.java | 90 +
.../mahout/classifier/mlp/NeuralNetwork.java | 743 +
.../classifier/mlp/NeuralNetworkFunctions.java | 150 +
.../classifier/mlp/RunMultilayerPerceptron.java | 227 +
.../mlp/TrainMultilayerPerceptron.java | 332 +
.../AbstractNaiveBayesClassifier.java | 82 +
.../classifier/naivebayes/BayesUtils.java | 167 +
.../ComplementaryNaiveBayesClassifier.java | 43 +
.../classifier/naivebayes/NaiveBayesModel.java | 176 +
.../StandardNaiveBayesClassifier.java | 40 +
.../naivebayes/test/BayesTestMapper.java | 76 +
.../naivebayes/test/TestNaiveBayesDriver.java | 179 +
.../training/ComplementaryThetaTrainer.java | 83 +
.../training/IndexInstancesMapper.java | 53 +
.../naivebayes/training/ThetaMapper.java | 61 +
.../naivebayes/training/TrainNaiveBayesJob.java | 186 +
.../naivebayes/training/WeightsMapper.java | 68 +
.../sequencelearning/hmm/BaumWelchTrainer.java | 165 +
.../sequencelearning/hmm/HmmAlgorithms.java | 306 +
.../sequencelearning/hmm/HmmEvaluator.java | 194 +
.../sequencelearning/hmm/HmmModel.java | 383 +
.../sequencelearning/hmm/HmmTrainer.java | 488 +
.../sequencelearning/hmm/HmmUtils.java | 361 +
.../hmm/LossyHmmSerializer.java | 62 +
.../hmm/RandomSequenceGenerator.java | 108 +
.../sequencelearning/hmm/ViterbiEvaluator.java | 127 +
.../sgd/AbstractOnlineLogisticRegression.java | 317 +
.../sgd/AdaptiveLogisticRegression.java | 586 +
.../mahout/classifier/sgd/CrossFoldLearner.java | 334 +
.../mahout/classifier/sgd/CsvRecordFactory.java | 393 +
.../mahout/classifier/sgd/DefaultGradient.java | 49 +
.../mahout/classifier/sgd/ElasticBandPrior.java | 76 +
.../apache/mahout/classifier/sgd/Gradient.java | 30 +
.../mahout/classifier/sgd/GradientMachine.java | 405 +
.../org/apache/mahout/classifier/sgd/L1.java | 59 +
.../org/apache/mahout/classifier/sgd/L2.java | 66 +
.../mahout/classifier/sgd/MixedGradient.java | 66 +
.../mahout/classifier/sgd/ModelDissector.java | 232 +
.../mahout/classifier/sgd/ModelSerializer.java | 76 +
.../sgd/OnlineLogisticRegression.java | 172 +
.../classifier/sgd/PassiveAggressive.java | 204 +
.../classifier/sgd/PolymorphicWritable.java | 46 +
.../mahout/classifier/sgd/PriorFunction.java | 45 +
.../mahout/classifier/sgd/RankingGradient.java | 85 +
.../mahout/classifier/sgd/RecordFactory.java | 47 +
.../apache/mahout/classifier/sgd/TPrior.java | 61 +
.../mahout/classifier/sgd/UniformPrior.java | 47 +
.../mahout/classifier/sgd/package-info.java | 23 +
.../mahout/clustering/AbstractCluster.java | 391 +
.../org/apache/mahout/clustering/Cluster.java | 90 +
.../mahout/clustering/ClusteringUtils.java | 305 +
.../mahout/clustering/GaussianAccumulator.java | 62 +
.../org/apache/mahout/clustering/Model.java | 93 +
.../mahout/clustering/ModelDistribution.java | 41 +
.../clustering/OnlineGaussianAccumulator.java | 107 +
.../RunningSumsGaussianAccumulator.java | 90 +
.../clustering/UncommonDistributions.java | 136 +
.../apache/mahout/clustering/canopy/Canopy.java | 60 +
.../clustering/canopy/CanopyClusterer.java | 220 +
.../clustering/canopy/CanopyConfigKeys.java | 70 +
.../mahout/clustering/canopy/CanopyDriver.java | 379 +
.../mahout/clustering/canopy/CanopyMapper.java | 66 +
.../mahout/clustering/canopy/CanopyReducer.java | 70 +
.../ClusterClassificationConfigKeys.java | 33 +
.../classify/ClusterClassificationDriver.java | 313 +
.../classify/ClusterClassificationMapper.java | 161 +
.../clustering/classify/ClusterClassifier.java | 240 +
.../WeightedPropertyVectorWritable.java | 95 +
.../classify/WeightedVectorWritable.java | 72 +
.../fuzzykmeans/FuzzyKMeansClusterer.java | 59 +
.../fuzzykmeans/FuzzyKMeansDriver.java | 324 +
.../clustering/fuzzykmeans/FuzzyKMeansUtil.java | 76 +
.../clustering/fuzzykmeans/SoftCluster.java | 60 +
.../iterator/AbstractClusteringPolicy.java | 72 +
.../mahout/clustering/iterator/CIMapper.java | 71 +
.../mahout/clustering/iterator/CIReducer.java | 64 +
.../iterator/CanopyClusteringPolicy.java | 52 +
.../clustering/iterator/ClusterIterator.java | 219 +
.../clustering/iterator/ClusterWritable.java | 56 +
.../clustering/iterator/ClusteringPolicy.java | 66 +
.../iterator/ClusteringPolicyWritable.java | 55 +
.../iterator/DistanceMeasureCluster.java | 91 +
.../iterator/FuzzyKMeansClusteringPolicy.java | 91 +
.../iterator/KMeansClusteringPolicy.java | 64 +
.../clustering/kernel/IKernelProfile.java | 27 +
.../kernel/TriangularKernelProfile.java | 27 +
.../mahout/clustering/kmeans/KMeansDriver.java | 257 +
.../mahout/clustering/kmeans/KMeansUtil.java | 74 +
.../mahout/clustering/kmeans/Kluster.java | 117 +
.../clustering/kmeans/RandomSeedGenerator.java | 139 +
.../mahout/clustering/kmeans/package-info.java | 5 +
.../lda/cvb/CVB0DocInferenceMapper.java | 51 +
.../mahout/clustering/lda/cvb/CVB0Driver.java | 536 +
.../CVB0TopicTermVectorNormalizerMapper.java | 38 +
.../clustering/lda/cvb/CachingCVB0Mapper.java | 133 +
.../lda/cvb/CachingCVB0PerplexityMapper.java | 108 +
.../cvb/InMemoryCollapsedVariationalBayes0.java | 515 +
.../mahout/clustering/lda/cvb/ModelTrainer.java | 301 +
.../mahout/clustering/lda/cvb/TopicModel.java | 513 +
.../apache/mahout/clustering/package-info.java | 13 +
.../spectral/AffinityMatrixInputJob.java | 84 +
.../spectral/AffinityMatrixInputMapper.java | 78 +
.../spectral/AffinityMatrixInputReducer.java | 59 +
.../spectral/IntDoublePairWritable.java | 75 +
.../apache/mahout/clustering/spectral/Keys.java | 31 +
.../spectral/MatrixDiagonalizeJob.java | 108 +
.../clustering/spectral/UnitVectorizerJob.java | 79 +
.../mahout/clustering/spectral/VectorCache.java | 123 +
.../spectral/VectorMatrixMultiplicationJob.java | 139 +
.../clustering/spectral/VertexWritable.java | 101 +
.../spectral/kmeans/EigenSeedGenerator.java | 124 +
.../spectral/kmeans/SpectralKMeansDriver.java | 243 +
.../streaming/cluster/BallKMeans.java | 456 +
.../streaming/cluster/StreamingKMeans.java | 368 +
.../streaming/mapreduce/CentroidWritable.java | 88 +
.../mapreduce/StreamingKMeansDriver.java | 493 +
.../mapreduce/StreamingKMeansMapper.java | 102 +
.../mapreduce/StreamingKMeansReducer.java | 109 +
.../mapreduce/StreamingKMeansThread.java | 92 +
.../mapreduce/StreamingKMeansUtilsMR.java | 163 +
.../streaming/tools/ResplitSequenceFiles.java | 149 +
.../clustering/topdown/PathDirectory.java | 94 +
.../postprocessor/ClusterCountReader.java | 103 +
.../ClusterOutputPostProcessor.java | 139 +
.../ClusterOutputPostProcessorDriver.java | 182 +
.../ClusterOutputPostProcessorMapper.java | 58 +
.../ClusterOutputPostProcessorReducer.java | 62 +
.../org/apache/mahout/common/AbstractJob.java | 658 +
.../org/apache/mahout/common/ClassUtils.java | 61 +
.../apache/mahout/common/CommandLineUtil.java | 68 +
.../org/apache/mahout/common/HadoopUtil.java | 442 +
.../apache/mahout/common/IntPairWritable.java | 270 +
.../org/apache/mahout/common/IntegerTuple.java | 176 +
.../java/org/apache/mahout/common/LongPair.java | 80 +
.../org/apache/mahout/common/MemoryUtil.java | 99 +
.../java/org/apache/mahout/common/Pair.java | 99 +
.../org/apache/mahout/common/Parameters.java | 98 +
.../org/apache/mahout/common/StringTuple.java | 177 +
.../org/apache/mahout/common/StringUtils.java | 63 +
.../apache/mahout/common/TimingStatistics.java | 154 +
.../commandline/DefaultOptionCreator.java | 417 +
.../distance/ChebyshevDistanceMeasure.java | 63 +
.../common/distance/CosineDistanceMeasure.java | 119 +
.../mahout/common/distance/DistanceMeasure.java | 48 +
.../distance/EuclideanDistanceMeasure.java | 41 +
.../distance/MahalanobisDistanceMeasure.java | 204 +
.../distance/ManhattanDistanceMeasure.java | 70 +
.../distance/MinkowskiDistanceMeasure.java | 93 +
.../SquaredEuclideanDistanceMeasure.java | 59 +
.../distance/TanimotoDistanceMeasure.java | 69 +
.../distance/WeightedDistanceMeasure.java | 97 +
.../WeightedEuclideanDistanceMeasure.java | 52 +
.../WeightedManhattanDistanceMeasure.java | 53 +
.../iterator/CopyConstructorIterator.java | 64 +
.../common/iterator/CountingIterator.java | 43 +
.../common/iterator/FileLineIterable.java | 88 +
.../common/iterator/FileLineIterator.java | 167 +
.../iterator/FixedSizeSamplingIterator.java | 59 +
.../common/iterator/SamplingIterable.java | 45 +
.../common/iterator/SamplingIterator.java | 73 +
.../StableFixedSizeSamplingIterator.java | 72 +
.../common/iterator/StringRecordIterator.java | 55 +
.../iterator/sequencefile/PathFilters.java | 81 +
.../common/iterator/sequencefile/PathType.java | 27 +
.../sequencefile/SequenceFileDirIterable.java | 84 +
.../sequencefile/SequenceFileDirIterator.java | 136 +
.../SequenceFileDirValueIterable.java | 83 +
.../SequenceFileDirValueIterator.java | 159 +
.../sequencefile/SequenceFileIterable.java | 68 +
.../sequencefile/SequenceFileIterator.java | 118 +
.../sequencefile/SequenceFileValueIterable.java | 67 +
.../sequencefile/SequenceFileValueIterator.java | 97 +
.../mahout/common/lucene/AnalyzerUtils.java | 61 +
.../common/lucene/IteratorTokenStream.java | 45 +
.../common/lucene/TokenStreamIterator.java | 57 +
.../common/mapreduce/MergeVectorsCombiner.java | 34 +
.../common/mapreduce/MergeVectorsReducer.java | 40 +
.../common/mapreduce/TransposeMapper.java | 49 +
.../common/mapreduce/VectorSumCombiner.java | 38 +
.../common/mapreduce/VectorSumReducer.java | 35 +
.../org/apache/mahout/common/nlp/NGrams.java | 94 +
.../common/parameters/AbstractParameter.java | 120 +
.../common/parameters/ClassParameter.java | 44 +
.../common/parameters/DoubleParameter.java | 33 +
.../mahout/common/parameters/Parameter.java | 62 +
.../mahout/common/parameters/Parametered.java | 206 +
.../mahout/common/parameters/PathParameter.java | 33 +
.../org/apache/mahout/driver/MahoutDriver.java | 244 +
.../apache/mahout/ep/EvolutionaryProcess.java | 228 +
.../main/java/org/apache/mahout/ep/Mapping.java | 206 +
.../main/java/org/apache/mahout/ep/Payload.java | 36 +
.../main/java/org/apache/mahout/ep/State.java | 302 +
.../java/org/apache/mahout/ep/package-info.java | 26 +
.../mahout/math/DistributedRowMatrixWriter.java | 47 +
.../org/apache/mahout/math/MatrixUtils.java | 114 +
.../mahout/math/MultiLabelVectorWritable.java | 88 +
.../math/hadoop/DistributedRowMatrix.java | 385 +
.../math/hadoop/MatrixColumnMeansJob.java | 236 +
.../math/hadoop/MatrixMultiplicationJob.java | 177 +
.../mahout/math/hadoop/TimesSquaredJob.java | 251 +
.../apache/mahout/math/hadoop/TransposeJob.java | 85 +
.../decomposer/DistributedLanczosSolver.java | 298 +
.../math/hadoop/decomposer/EigenVector.java | 76 +
.../hadoop/decomposer/EigenVerificationJob.java | 332 +
.../decomposer/HdfsBackedLanczosState.java | 237 +
.../math/hadoop/similarity/SeedVectorUtil.java | 104 +
.../VectorDistanceInvertedMapper.java | 71 +
.../hadoop/similarity/VectorDistanceMapper.java | 80 +
.../similarity/VectorDistanceSimilarityJob.java | 153 +
.../similarity/cooccurrence/MutableElement.java | 50 +
.../cooccurrence/RowSimilarityJob.java | 562 +
.../cooccurrence/TopElementsQueue.java | 59 +
.../hadoop/similarity/cooccurrence/Vectors.java | 199 +
.../measures/CityBlockSimilarity.java | 26 +
.../measures/CooccurrenceCountSimilarity.java | 32 +
.../cooccurrence/measures/CosineSimilarity.java | 50 +
.../measures/CountbasedMeasure.java | 44 +
.../measures/EuclideanDistanceSimilarity.java | 57 +
.../measures/LoglikelihoodSimilarity.java | 34 +
.../measures/PearsonCorrelationSimilarity.java | 37 +
.../measures/TanimotoCoefficientSimilarity.java | 34 +
.../measures/VectorSimilarityMeasure.java | 32 +
.../measures/VectorSimilarityMeasures.java | 46 +
.../DistributedConjugateGradientSolver.java | 172 +
.../mahout/math/hadoop/stats/BasicStats.java | 148 +
.../StandardDeviationCalculatorMapper.java | 55 +
.../StandardDeviationCalculatorReducer.java | 37 +
.../math/hadoop/stats/VarianceTotals.java | 68 +
.../hadoop/stochasticsvd/ABtDenseOutJob.java | 585 +
.../math/hadoop/stochasticsvd/ABtJob.java | 494 +
.../mahout/math/hadoop/stochasticsvd/BtJob.java | 628 +
.../stochasticsvd/DenseBlockWritable.java | 83 +
.../mahout/math/hadoop/stochasticsvd/Omega.java | 257 +
.../mahout/math/hadoop/stochasticsvd/QJob.java | 237 +
.../math/hadoop/stochasticsvd/SSVDCli.java | 201 +
.../math/hadoop/stochasticsvd/SSVDHelper.java | 322 +
.../math/hadoop/stochasticsvd/SSVDSolver.java | 662 +
.../SparseRowBlockAccumulator.java | 90 +
.../stochasticsvd/SparseRowBlockWritable.java | 159 +
.../stochasticsvd/SplitPartitionedWritable.java | 151 +
.../mahout/math/hadoop/stochasticsvd/UJob.java | 170 +
.../mahout/math/hadoop/stochasticsvd/VJob.java | 224 +
.../math/hadoop/stochasticsvd/YtYJob.java | 220 +
.../stochasticsvd/qr/GivensThinSolver.java | 638 +
.../hadoop/stochasticsvd/qr/GramSchmidt.java | 52 +
.../hadoop/stochasticsvd/qr/QRFirstStep.java | 284 +
.../hadoop/stochasticsvd/qr/QRLastStep.java | 144 +
.../mahout/math/neighborhood/BruteSearch.java | 186 +
.../math/neighborhood/FastProjectionSearch.java | 326 +
.../mahout/math/neighborhood/HashedVector.java | 103 +
.../LocalitySensitiveHashSearch.java | 295 +
.../math/neighborhood/ProjectionSearch.java | 233 +
.../mahout/math/neighborhood/Searcher.java | 155 +
.../math/neighborhood/UpdatableSearcher.java | 37 +
.../mahout/math/random/RandomProjector.java | 133 +
.../math/ssvd/SequentialOutOfCoreSvd.java | 233 +
.../mahout/math/stats/GlobalOnlineAuc.java | 168 +
.../mahout/math/stats/GroupedOnlineAuc.java | 113 +
.../org/apache/mahout/math/stats/OnlineAuc.java | 38 +
.../org/apache/mahout/math/stats/Sampler.java | 79 +
.../mahout/vectorizer/DictionaryVectorizer.java | 416 +
.../mahout/vectorizer/DocumentProcessor.java | 99 +
.../EncodedVectorsFromSequenceFiles.java | 104 +
.../mahout/vectorizer/EncodingMapper.java | 92 +
.../mahout/vectorizer/HighDFWordsPruner.java | 147 +
.../SimpleTextEncodingVectorizer.java | 72 +
.../SparseVectorsFromSequenceFiles.java | 369 +
.../java/org/apache/mahout/vectorizer/TF.java | 30 +
.../org/apache/mahout/vectorizer/TFIDF.java | 31 +
.../apache/mahout/vectorizer/Vectorizer.java | 29 +
.../mahout/vectorizer/VectorizerConfig.java | 179 +
.../org/apache/mahout/vectorizer/Weight.java | 32 +
.../collocations/llr/CollocCombiner.java | 46 +
.../collocations/llr/CollocDriver.java | 284 +
.../collocations/llr/CollocMapper.java | 178 +
.../collocations/llr/CollocReducer.java | 176 +
.../vectorizer/collocations/llr/Gram.java | 239 +
.../vectorizer/collocations/llr/GramKey.java | 133 +
.../llr/GramKeyGroupComparator.java | 43 +
.../collocations/llr/GramKeyPartitioner.java | 40 +
.../vectorizer/collocations/llr/LLRReducer.java | 170 +
.../common/PartialVectorMergeReducer.java | 89 +
.../vectorizer/common/PartialVectorMerger.java | 144 +
.../document/SequenceFileTokenizerMapper.java | 70 +
.../encoders/AdaptiveWordValueEncoder.java | 69 +
.../encoders/CachingContinuousValueEncoder.java | 64 +
.../encoders/CachingStaticWordValueEncoder.java | 66 +
.../encoders/CachingTextValueEncoder.java | 25 +
.../encoders/CachingValueEncoder.java | 64 +
.../encoders/ConstantValueEncoder.java | 57 +
.../encoders/ContinuousValueEncoder.java | 76 +
.../mahout/vectorizer/encoders/Dictionary.java | 55 +
.../encoders/FeatureVectorEncoder.java | 279 +
.../encoders/InteractionValueEncoder.java | 126 +
.../encoders/LuceneTextValueEncoder.java | 133 +
.../encoders/StaticWordValueEncoder.java | 80 +
.../vectorizer/encoders/TextValueEncoder.java | 142 +
.../vectorizer/encoders/WordValueEncoder.java | 81 +
.../pruner/PrunedPartialVectorMergeReducer.java | 65 +
.../vectorizer/pruner/WordsPrunerReducer.java | 86 +
.../vectorizer/term/TFPartialVectorReducer.java | 139 +
.../vectorizer/term/TermCountCombiner.java | 41 +
.../mahout/vectorizer/term/TermCountMapper.java | 58 +
.../vectorizer/term/TermCountReducer.java | 55 +
.../term/TermDocumentCountMapper.java | 50 +
.../term/TermDocumentCountReducer.java | 41 +
.../mahout/vectorizer/tfidf/TFIDFConverter.java | 361 +
.../tfidf/TFIDFPartialVectorReducer.java | 114 +
mr/src/main/resources/version | 1 +
.../mahout/cf/taste/common/CommonTest.java | 60 +
.../cf/taste/hadoop/TasteHadoopUtilsTest.java | 40 +
.../cf/taste/hadoop/TopItemsQueueTest.java | 72 +
.../als/ParallelALSFactorizationJobTest.java | 379 +
.../cf/taste/hadoop/item/IDReaderTest.java | 66 +
.../taste/hadoop/item/RecommenderJobTest.java | 928 +
.../hadoop/item/ToUserVectorsReducerTest.java | 74 +
.../similarity/item/ItemSimilarityJobTest.java | 269 +
.../mahout/cf/taste/impl/TasteTestCase.java | 98 +
.../mahout/cf/taste/impl/common/BitSetTest.java | 74 +
.../mahout/cf/taste/impl/common/CacheTest.java | 61 +
.../cf/taste/impl/common/FastByIDMapTest.java | 147 +
.../cf/taste/impl/common/FastIDSetTest.java | 162 +
.../cf/taste/impl/common/FastMapTest.java | 228 +
.../impl/common/InvertedRunningAverageTest.java | 88 +
.../common/LongPrimitiveArrayIteratorTest.java | 56 +
.../cf/taste/impl/common/MockRefreshable.java | 45 +
.../cf/taste/impl/common/RefreshHelperTest.java | 70 +
.../common/RunningAverageAndStdDevTest.java | 107 +
.../taste/impl/common/RunningAverageTest.java | 75 +
.../SamplingLongPrimitiveIteratorTest.java | 91 +
.../impl/common/WeightedRunningAverageTest.java | 85 +
...ericRecommenderIRStatsEvaluatorImplTest.java | 73 +
.../taste/impl/eval/LoadEvaluationRunner.java | 68 +
.../model/BooleanItemPreferenceArrayTest.java | 89 +
.../model/BooleanUserPreferenceArrayTest.java | 89 +
.../taste/impl/model/GenericDataModelTest.java | 51 +
.../model/GenericItemPreferenceArrayTest.java | 110 +
.../model/GenericUserPreferenceArrayTest.java | 110 +
.../taste/impl/model/MemoryIDMigratorTest.java | 57 +
...lusAnonymousConcurrentUserDataModelTest.java | 313 +
.../impl/model/file/FileDataModelTest.java | 216 +
.../impl/model/file/FileIDMigratorTest.java | 103 +
.../impl/neighborhood/DummySimilarity.java | 68 +
.../neighborhood/NearestNNeighborhoodTest.java | 53 +
.../neighborhood/ThresholdNeighborhoodTest.java | 51 +
...lUnknownItemsCandidateItemsStrategyTest.java | 65 +
.../recommender/CachingRecommenderTest.java | 78 +
.../GenericItemBasedRecommenderTest.java | 324 +
.../GenericUserBasedRecommenderTest.java | 174 +
.../recommender/ItemAverageRecommenderTest.java | 43 +
.../ItemUserAverageRecommenderTest.java | 43 +
.../taste/impl/recommender/MockRecommender.java | 89 +
.../impl/recommender/NullRescorerTest.java | 47 +
...sNeighborhoodCandidateItemsStrategyTest.java | 75 +
.../impl/recommender/RandomRecommenderTest.java | 41 +
.../impl/recommender/ReversingRescorer.java | 46 +
.../SamplingCandidateItemsStrategyTest.java | 71 +
.../cf/taste/impl/recommender/TopItemsTest.java | 158 +
.../recommender/svd/ALSWRFactorizerTest.java | 208 +
.../svd/FilePersistenceStrategyTest.java | 53 +
.../svd/ParallelSGDFactorizerTest.java | 355 +
.../recommender/svd/SVDRecommenderTest.java | 86 +
.../AveragingPreferenceInferrerTest.java | 37 +
.../EuclideanDistanceSimilarityTest.java | 236 +
.../similarity/GenericItemSimilarityTest.java | 104 +
.../similarity/LogLikelihoodSimilarityTest.java | 80 +
.../PearsonCorrelationSimilarityTest.java | 265 +
.../impl/similarity/SimilarityTestCase.java | 35 +
.../SpearmanCorrelationSimilarityTest.java | 80 +
.../TanimotoCoefficientSimilarityTest.java | 121 +
.../similarity/file/FileItemSimilarityTest.java | 142 +
.../MultithreadedBatchItemSimilaritiesTest.java | 80 +
.../mahout/classifier/ClassifierData.java | 102 +
.../mahout/classifier/ConfusionMatrixTest.java | 119 +
.../RegressionResultAnalyzerTest.java | 128 +
.../classifier/df/DecisionForestTest.java | 206 +
.../df/builder/DecisionTreeBuilderTest.java | 78 +
.../df/builder/DefaultTreeBuilderTest.java | 74 +
.../df/builder/InfiniteRecursionTest.java | 60 +
.../classifier/df/data/DataConverterTest.java | 60 +
.../classifier/df/data/DataLoaderTest.java | 350 +
.../mahout/classifier/df/data/DataTest.java | 396 +
.../mahout/classifier/df/data/DatasetTest.java | 72 +
.../classifier/df/data/DescriptorUtilsTest.java | 92 +
.../apache/mahout/classifier/df/data/Utils.java | 283 +
.../mapreduce/inmem/InMemInputFormatTest.java | 109 +
.../df/mapreduce/inmem/InMemInputSplitTest.java | 77 +
.../mapreduce/partial/PartialBuilderTest.java | 197 +
.../df/mapreduce/partial/Step1MapperTest.java | 160 +
.../df/mapreduce/partial/TreeIDTest.java | 48 +
.../mahout/classifier/df/node/NodeTest.java | 108 +
.../classifier/df/split/DefaultIgSplitTest.java | 78 +
.../df/split/RegressionSplitTest.java | 87 +
.../classifier/df/tools/VisualizerTest.java | 211 +
.../mahout/classifier/evaluation/AucTest.java | 86 +
.../apache/mahout/classifier/mlp/Datasets.java | 866 +
.../mlp/RunMultilayerPerceptronTest.java | 66 +
.../mlp/TestMultilayerPerceptron.java | 88 +
.../classifier/mlp/TestNeuralNetwork.java | 353 +
.../mlp/TrainMultilayerPerceptronTest.java | 105 +
.../ComplementaryNaiveBayesClassifierTest.java | 47 +
.../naivebayes/NaiveBayesModelTest.java | 36 +
.../classifier/naivebayes/NaiveBayesTest.java | 135 +
.../naivebayes/NaiveBayesTestBase.java | 135 +
.../StandardNaiveBayesClassifierTest.java | 47 +
.../training/IndexInstancesMapperTest.java | 85 +
.../naivebayes/training/ThetaMapperTest.java | 61 +
.../naivebayes/training/WeightsMapperTest.java | 60 +
.../sequencelearning/hmm/HMMAlgorithmsTest.java | 164 +
.../sequencelearning/hmm/HMMEvaluatorTest.java | 63 +
.../sequencelearning/hmm/HMMModelTest.java | 32 +
.../sequencelearning/hmm/HMMTestBase.java | 73 +
.../sequencelearning/hmm/HMMTrainerTest.java | 163 +
.../sequencelearning/hmm/HMMUtilsTest.java | 161 +
.../sgd/AdaptiveLogisticRegressionTest.java | 186 +
.../classifier/sgd/CsvRecordFactoryTest.java | 90 +
.../classifier/sgd/GradientMachineTest.java | 41 +
.../classifier/sgd/ModelSerializerTest.java | 162 +
.../mahout/classifier/sgd/OnlineBaseTest.java | 160 +
.../sgd/OnlineLogisticRegressionTest.java | 330 +
.../classifier/sgd/PassiveAggressiveTest.java | 35 +
.../mahout/clustering/ClusteringTestUtils.java | 152 +
.../mahout/clustering/TestClusterInterface.java | 83 +
.../clustering/TestGaussianAccumulators.java | 186 +
.../clustering/canopy/TestCanopyCreation.java | 674 +
.../ClusterClassificationDriverTest.java | 255 +
.../fuzzykmeans/TestFuzzyKmeansClustering.java | 202 +
.../iterator/TestClusterClassifier.java | 238 +
.../clustering/kmeans/TestKmeansClustering.java | 385 +
.../kmeans/TestRandomSeedGenerator.java | 169 +
.../clustering/lda/cvb/TestCVBModelTrainer.java | 138 +
.../spectral/TestAffinityMatrixInputJob.java | 145 +
.../spectral/TestMatrixDiagonalizeJob.java | 116 +
.../spectral/TestUnitVectorizerJob.java | 65 +
.../clustering/spectral/TestVectorCache.java | 110 +
.../TestVectorMatrixMultiplicationJob.java | 75 +
.../spectral/kmeans/TestEigenSeedGenerator.java | 100 +
.../streaming/cluster/BallKMeansTest.java | 196 +
.../clustering/streaming/cluster/DataUtils.java | 92 +
.../streaming/cluster/StreamingKMeansTest.java | 169 +
.../mapreduce/StreamingKMeansTestMR.java | 283 +
.../tools/ResplitSequenceFilesTest.java | 80 +
.../clustering/topdown/PathDirectoryTest.java | 65 +
.../postprocessor/ClusterCountReaderTest.java | 121 +
.../ClusterOutputPostProcessorTest.java | 205 +
.../apache/mahout/common/AbstractJobTest.java | 240 +
.../DistributedCacheFileLocationTest.java | 46 +
.../mahout/common/DummyOutputCollector.java | 57 +
.../apache/mahout/common/DummyRecordWriter.java | 223 +
.../mahout/common/DummyRecordWriterTest.java | 45 +
.../mahout/common/DummyStatusReporter.java | 76 +
.../mahout/common/IntPairWritableTest.java | 114 +
.../apache/mahout/common/MahoutTestCase.java | 148 +
.../org/apache/mahout/common/MockIterator.java | 51 +
.../apache/mahout/common/StringUtilsTest.java | 70 +
.../distance/CosineDistanceMeasureTest.java | 66 +
.../distance/DefaultDistanceMeasureTest.java | 103 +
.../DefaultWeightedDistanceMeasureTest.java | 56 +
.../common/distance/TestChebyshevMeasure.java | 55 +
.../distance/TestEuclideanDistanceMeasure.java | 26 +
.../TestMahalanobisDistanceMeasure.java | 56 +
.../distance/TestManhattanDistanceMeasure.java | 26 +
.../common/distance/TestMinkowskiMeasure.java | 64 +
.../distance/TestTanimotoDistanceMeasure.java | 25 +
...estWeightedEuclideanDistanceMeasureTest.java | 25 +
.../TestWeightedManhattanDistanceMeasure.java | 26 +
.../common/iterator/CountingIteratorTest.java | 44 +
.../mahout/common/iterator/SamplerCase.java | 101 +
.../common/iterator/TestFixedSizeSampler.java | 33 +
.../common/iterator/TestSamplingIterator.java | 77 +
.../iterator/TestStableFixedSizeSampler.java | 33 +
.../mahout/common/lucene/AnalyzerUtilsTest.java | 38 +
.../apache/mahout/driver/MahoutDriverTest.java | 32 +
.../mahout/ep/EvolutionaryProcessTest.java | 81 +
.../apache/mahout/math/MatrixWritableTest.java | 148 +
.../java/org/apache/mahout/math/VarintTest.java | 189 +
.../apache/mahout/math/VectorWritableTest.java | 123 +
.../apache/mahout/math/hadoop/MathHelper.java | 236 +
.../math/hadoop/TestDistributedRowMatrix.java | 395 +
.../TestDistributedLanczosSolver.java | 132 +
.../TestDistributedLanczosSolverCLI.java | 190 +
.../TestVectorDistanceSimilarityJob.java | 238 +
.../cooccurrence/RowSimilarityJobTest.java | 214 +
.../measures/VectorSimilarityMeasuresTest.java | 133 +
.../TestDistributedConjugateGradientSolver.java | 59 +
...stDistributedConjugateGradientSolverCLI.java | 111 +
.../math/hadoop/stats/BasicStatsTest.java | 121 +
.../stochasticsvd/LocalSSVDPCASparseTest.java | 296 +
.../stochasticsvd/LocalSSVDSolverDenseTest.java | 206 +
.../LocalSSVDSolverSparseSequentialTest.java | 209 +
.../hadoop/stochasticsvd/SSVDCommonTest.java | 105 +
.../hadoop/stochasticsvd/SSVDTestsHelper.java | 172 +
.../LocalitySensitiveHashSearchTest.java | 119 +
.../mahout/math/neighborhood/LumpyData.java | 77 +
.../math/neighborhood/SearchQualityTest.java | 178 +
.../math/neighborhood/SearchSanityTest.java | 244 +
.../math/ssvd/SequentialOutOfCoreSvdTest.java | 195 +
.../apache/mahout/math/stats/OnlineAucTest.java | 127 +
.../apache/mahout/math/stats/SamplerTest.java | 45 +
.../vectorizer/DictionaryVectorizerTest.java | 220 +
.../vectorizer/DocumentProcessorTest.java | 81 +
.../EncodedVectorsFromSequenceFilesTest.java | 126 +
.../vectorizer/HighDFWordsPrunerTest.java | 154 +
.../vectorizer/RandomDocumentGenerator.java | 69 +
.../SparseVectorsFromSequenceFilesTest.java | 203 +
.../collocations/llr/CollocMapperTest.java | 180 +
.../collocations/llr/CollocReducerTest.java | 86 +
.../llr/GramKeyGroupComparatorTest.java | 45 +
.../llr/GramKeyPartitionerTest.java | 54 +
.../collocations/llr/GramKeyTest.java | 106 +
.../vectorizer/collocations/llr/GramTest.java | 216 +
.../collocations/llr/LLRReducerTest.java | 116 +
.../vectorizer/encoders/CachingEncoderTest.java | 48 +
.../encoders/ConstantValueEncoderTest.java | 74 +
.../encoders/ContinuousValueEncoderTest.java | 88 +
.../encoders/InteractionValueEncoderTest.java | 103 +
.../encoders/TextValueEncoderTest.java | 100 +
.../encoders/WordLikeValueEncoderTest.java | 99 +
mr/src/test/resources/FPGsynth.dat | 193 +
mr/src/test/resources/cancer.csv | 684 +
mr/src/test/resources/iris.csv | 151 +
mr/src/test/resources/retail.dat | 88162 +++++++++++++++++
.../retail_results_with_min_sup_100.dat | 6438 ++
mr/src/test/resources/sgd.csv | 61 +
mr/src/test/resources/word-list.txt | 512 +
mrlegacy/pom.xml | 236 -
mrlegacy/src/main/assembly/job.xml | 61 -
.../main/java/org/apache/mahout/Version.java | 41 -
.../cf/taste/common/NoSuchItemException.java | 32 -
.../cf/taste/common/NoSuchUserException.java | 32 -
.../mahout/cf/taste/common/Refreshable.java | 53 -
.../mahout/cf/taste/common/TasteException.java | 41 -
.../mahout/cf/taste/common/Weighting.java | 31 -
.../mahout/cf/taste/eval/DataModelBuilder.java | 45 -
.../mahout/cf/taste/eval/IRStatistics.java | 80 -
.../cf/taste/eval/RecommenderBuilder.java | 45 -
.../cf/taste/eval/RecommenderEvaluator.java | 105 -
.../taste/eval/RecommenderIRStatsEvaluator.java | 64 -
.../taste/eval/RelevantItemsDataSplitter.java | 62 -
.../cf/taste/hadoop/EntityEntityWritable.java | 98 -
.../cf/taste/hadoop/EntityPrefWritable.java | 89 -
.../cf/taste/hadoop/MutableRecommendedItem.java | 81 -
.../taste/hadoop/RecommendedItemsWritable.java | 96 -
.../cf/taste/hadoop/TasteHadoopUtils.java | 84 -
.../cf/taste/hadoop/ToEntityPrefsMapper.java | 78 -
.../cf/taste/hadoop/ToItemPrefsMapper.java | 46 -
.../mahout/cf/taste/hadoop/TopItemsQueue.java | 60 -
.../apache/mahout/cf/taste/hadoop/als/ALS.java | 107 -
.../cf/taste/hadoop/als/DatasetSplitter.java | 158 -
.../hadoop/als/FactorizationEvaluator.java | 172 -
.../hadoop/als/MultithreadedSharingMapper.java | 62 -
.../hadoop/als/ParallelALSFactorizationJob.java | 419 -
.../cf/taste/hadoop/als/PredictionMapper.java | 145 -
.../cf/taste/hadoop/als/RecommenderJob.java | 110 -
.../cf/taste/hadoop/als/SharingMapper.java | 59 -
.../hadoop/als/SolveExplicitFeedbackMapper.java | 61 -
.../hadoop/als/SolveImplicitFeedbackMapper.java | 58 -
.../item/AggregateAndRecommendReducer.java | 220 -
.../mahout/cf/taste/hadoop/item/IDReader.java | 250 -
.../item/ItemFilterAsVectorAndPrefsReducer.java | 62 -
.../cf/taste/hadoop/item/ItemFilterMapper.java | 47 -
.../cf/taste/hadoop/item/ItemIDIndexMapper.java | 56 -
.../taste/hadoop/item/ItemIDIndexReducer.java | 48 -
.../hadoop/item/PartialMultiplyMapper.java | 57 -
.../item/PrefAndSimilarityColumnWritable.java | 85 -
.../cf/taste/hadoop/item/RecommenderJob.java | 337 -
.../item/SimilarityMatrixRowWrapperMapper.java | 54 -
.../taste/hadoop/item/ToUserVectorsReducer.java | 84 -
.../hadoop/item/ToVectorAndPrefReducer.java | 63 -
.../hadoop/item/UserVectorSplitterMapper.java | 116 -
.../hadoop/item/VectorAndPrefsWritable.java | 92 -
.../taste/hadoop/item/VectorOrPrefWritable.java | 104 -
.../preparation/PreparePreferenceMatrixJob.java | 115 -
.../hadoop/preparation/ToItemVectorsMapper.java | 56 -
.../preparation/ToItemVectorsReducer.java | 38 -
.../similarity/item/ItemSimilarityJob.java | 233 -
.../similarity/item/TopSimilarItemsQueue.java | 60 -
.../common/AbstractLongPrimitiveIterator.java | 27 -
.../mahout/cf/taste/impl/common/BitSet.java | 93 -
.../mahout/cf/taste/impl/common/Cache.java | 178 -
.../cf/taste/impl/common/FastByIDMap.java | 661 -
.../mahout/cf/taste/impl/common/FastIDSet.java | 426 -
.../mahout/cf/taste/impl/common/FastMap.java | 729 -
.../taste/impl/common/FixedRunningAverage.java | 83 -
.../common/FixedRunningAverageAndStdDev.java | 51 -
.../taste/impl/common/FullRunningAverage.java | 109 -
.../common/FullRunningAverageAndStdDev.java | 107 -
.../impl/common/InvertedRunningAverage.java | 58 -
.../common/InvertedRunningAverageAndStdDev.java | 63 -
.../impl/common/LongPrimitiveArrayIterator.java | 93 -
.../impl/common/LongPrimitiveIterator.java | 39 -
.../cf/taste/impl/common/RefreshHelper.java | 122 -
.../mahout/cf/taste/impl/common/Retriever.java | 36 -
.../cf/taste/impl/common/RunningAverage.java | 67 -
.../impl/common/RunningAverageAndStdDev.java | 36 -
.../common/SamplingLongPrimitiveIterator.java | 111 -
.../cf/taste/impl/common/SkippingIterator.java | 35 -
.../impl/common/WeightedRunningAverage.java | 100 -
.../common/WeightedRunningAverageAndStdDev.java | 89 -
.../impl/common/jdbc/AbstractJDBCComponent.java | 88 -
.../taste/impl/common/jdbc/EachRowIterator.java | 92 -
.../impl/common/jdbc/ResultSetIterator.java | 66 -
.../AbstractDifferenceRecommenderEvaluator.java | 277 -
...eAbsoluteDifferenceRecommenderEvaluator.java | 59 -
.../GenericRecommenderIRStatsEvaluator.java | 237 -
.../eval/GenericRelevantItemsDataSplitter.java | 83 -
.../cf/taste/impl/eval/IRStatisticsImpl.java | 95 -
.../mahout/cf/taste/impl/eval/LoadCallable.java | 40 -
.../cf/taste/impl/eval/LoadEvaluator.java | 61 -
.../cf/taste/impl/eval/LoadStatistics.java | 34 -
.../eval/OrderBasedRecommenderEvaluator.java | 431 -
.../impl/eval/RMSRecommenderEvaluator.java | 56 -
.../cf/taste/impl/eval/StatsCallable.java | 64 -
.../cf/taste/impl/model/AbstractDataModel.java | 53 -
.../cf/taste/impl/model/AbstractIDMigrator.java | 67 -
.../impl/model/AbstractJDBCIDMigrator.java | 108 -
.../impl/model/BooleanItemPreferenceArray.java | 234 -
.../cf/taste/impl/model/BooleanPreference.java | 64 -
.../impl/model/BooleanUserPreferenceArray.java | 234 -
.../impl/model/GenericBooleanPrefDataModel.java | 320 -
.../cf/taste/impl/model/GenericDataModel.java | 361 -
.../impl/model/GenericItemPreferenceArray.java | 301 -
.../cf/taste/impl/model/GenericPreference.java | 70 -
.../impl/model/GenericUserPreferenceArray.java | 307 -
.../cf/taste/impl/model/MemoryIDMigrator.java | 55 -
.../taste/impl/model/MySQLJDBCIDMigrator.java | 67 -
.../PlusAnonymousConcurrentUserDataModel.java | 352 -
.../impl/model/PlusAnonymousUserDataModel.java | 320 -
.../PlusAnonymousUserLongPrimitiveIterator.java | 90 -
.../cf/taste/impl/model/file/FileDataModel.java | 759 -
.../taste/impl/model/file/FileIDMigrator.java | 117 -
.../neighborhood/AbstractUserNeighborhood.java | 71 -
.../neighborhood/CachingUserNeighborhood.java | 69 -
.../neighborhood/NearestNUserNeighborhood.java | 122 -
.../neighborhood/ThresholdUserNeighborhood.java | 104 -
.../AbstractCandidateItemsStrategy.java | 57 -
.../impl/recommender/AbstractRecommender.java | 140 -
.../AllSimilarItemsCandidateItemsStrategy.java | 50 -
.../AllUnknownItemsCandidateItemsStrategy.java | 41 -
.../impl/recommender/ByRescoreComparator.java | 65 -
.../ByValueRecommendedItemComparator.java | 43 -
.../impl/recommender/CachingRecommender.java | 251 -
.../recommender/EstimatedPreferenceCapper.java | 46 -
.../GenericBooleanPrefItemBasedRecommender.java | 71 -
.../GenericBooleanPrefUserBasedRecommender.java | 82 -
.../GenericItemBasedRecommender.java | 378 -
.../recommender/GenericRecommendedItem.java | 76 -
.../GenericUserBasedRecommender.java | 247 -
.../recommender/ItemAverageRecommender.java | 199 -
.../recommender/ItemUserAverageRecommender.java | 240 -
.../cf/taste/impl/recommender/NullRescorer.java | 86 -
...ItemsNeighborhoodCandidateItemsStrategy.java | 48 -
.../impl/recommender/RandomRecommender.java | 97 -
.../SamplingCandidateItemsStrategy.java | 165 -
.../cf/taste/impl/recommender/SimilarUser.java | 80 -
.../cf/taste/impl/recommender/TopItems.java | 212 -
.../impl/recommender/svd/ALSWRFactorizer.java | 313 -
.../recommender/svd/AbstractFactorizer.java | 94 -
.../impl/recommender/svd/Factorization.java | 137 -
.../taste/impl/recommender/svd/Factorizer.java | 30 -
.../svd/FilePersistenceStrategy.java | 148 -
.../recommender/svd/NoPersistenceStrategy.java | 37 -
.../recommender/svd/ParallelSGDFactorizer.java | 340 -
.../recommender/svd/PersistenceStrategy.java | 46 -
.../recommender/svd/RatingSGDFactorizer.java | 221 -
.../recommender/svd/SVDPlusPlusFactorizer.java | 178 -
.../impl/recommender/svd/SVDPreference.java | 41 -
.../impl/recommender/svd/SVDRecommender.java | 185 -
.../impl/similarity/AbstractItemSimilarity.java | 64 -
.../impl/similarity/AbstractSimilarity.java | 343 -
.../similarity/AveragingPreferenceInferrer.java | 85 -
.../impl/similarity/CachingItemSimilarity.java | 111 -
.../impl/similarity/CachingUserSimilarity.java | 104 -
.../impl/similarity/CityBlockSimilarity.java | 98 -
.../similarity/EuclideanDistanceSimilarity.java | 67 -
.../impl/similarity/GenericItemSimilarity.java | 358 -
.../impl/similarity/GenericUserSimilarity.java | 238 -
.../similarity/LogLikelihoodSimilarity.java | 121 -
.../impl/similarity/LongPairMatchPredicate.java | 40 -
.../PearsonCorrelationSimilarity.java | 93 -
.../SpearmanCorrelationSimilarity.java | 135 -
.../TanimotoCoefficientSimilarity.java | 126 -
.../similarity/UncenteredCosineSimilarity.java | 69 -
.../file/FileItemItemSimilarityIterable.java | 46 -
.../file/FileItemItemSimilarityIterator.java | 60 -
.../similarity/file/FileItemSimilarity.java | 137 -
.../precompute/FileSimilarItemsWriter.java | 67 -
.../MultithreadedBatchItemSimilarities.java | 230 -
.../apache/mahout/cf/taste/model/DataModel.java | 199 -
.../mahout/cf/taste/model/IDMigrator.java | 63 -
.../mahout/cf/taste/model/JDBCDataModel.java | 43 -
.../mahout/cf/taste/model/Preference.java | 48 -
.../mahout/cf/taste/model/PreferenceArray.java | 143 -
.../cf/taste/model/UpdatableIDMigrator.java | 47 -
.../cf/taste/neighborhood/UserNeighborhood.java | 40 -
.../recommender/CandidateItemsStrategy.java | 37 -
.../mahout/cf/taste/recommender/IDRescorer.java | 47 -
.../taste/recommender/ItemBasedRecommender.java | 145 -
.../MostSimilarItemsCandidateItemsStrategy.java | 31 -
.../cf/taste/recommender/RecommendedItem.java | 41 -
.../cf/taste/recommender/Recommender.java | 132 -
.../mahout/cf/taste/recommender/Rescorer.java | 52 -
.../taste/recommender/UserBasedRecommender.java | 54 -
.../cf/taste/similarity/ItemSimilarity.java | 64 -
.../cf/taste/similarity/PreferenceInferrer.java | 47 -
.../cf/taste/similarity/UserSimilarity.java | 58 -
.../precompute/BatchItemSimilarities.java | 56 -
.../similarity/precompute/SimilarItem.java | 56 -
.../similarity/precompute/SimilarItems.java | 84 -
.../precompute/SimilarItemsWriter.java | 33 -
.../classifier/AbstractVectorClassifier.java | 248 -
.../mahout/classifier/ClassifierResult.java | 74 -
.../mahout/classifier/ConfusionMatrix.java | 444 -
.../apache/mahout/classifier/OnlineLearner.java | 96 -
.../classifier/RegressionResultAnalyzer.java | 144 -
.../mahout/classifier/ResultAnalyzer.java | 132 -
.../apache/mahout/classifier/df/Bagging.java | 60 -
.../apache/mahout/classifier/df/DFUtils.java | 181 -
.../mahout/classifier/df/DecisionForest.java | 244 -
.../mahout/classifier/df/ErrorEstimate.java | 50 -
.../df/builder/DecisionTreeBuilder.java | 421 -
.../df/builder/DefaultTreeBuilder.java | 252 -
.../classifier/df/builder/TreeBuilder.java | 41 -
.../apache/mahout/classifier/df/data/Data.java | 280 -
.../classifier/df/data/DataConverter.java | 71 -
.../mahout/classifier/df/data/DataLoader.java | 253 -
.../mahout/classifier/df/data/DataUtils.java | 88 -
.../mahout/classifier/df/data/Dataset.java | 421 -
.../classifier/df/data/DescriptorException.java | 27 -
.../classifier/df/data/DescriptorUtils.java | 109 -
.../mahout/classifier/df/data/Instance.java | 74 -
.../df/data/conditions/Condition.java | 56 -
.../classifier/df/data/conditions/Equals.java | 41 -
.../df/data/conditions/GreaterOrEquals.java | 41 -
.../classifier/df/data/conditions/Lesser.java | 41 -
.../mahout/classifier/df/mapreduce/Builder.java | 332 -
.../classifier/df/mapreduce/Classifier.java | 237 -
.../classifier/df/mapreduce/MapredMapper.java | 74 -
.../classifier/df/mapreduce/MapredOutput.java | 119 -
.../df/mapreduce/inmem/InMemBuilder.java | 113 -
.../df/mapreduce/inmem/InMemInputFormat.java | 283 -
.../df/mapreduce/inmem/InMemMapper.java | 105 -
.../df/mapreduce/inmem/package-info.java | 22 -
.../df/mapreduce/partial/PartialBuilder.java | 157 -
.../df/mapreduce/partial/Step1Mapper.java | 167 -
.../classifier/df/mapreduce/partial/TreeID.java | 57 -
.../df/mapreduce/partial/package-info.java | 16 -
.../classifier/df/node/CategoricalNode.java | 134 -
.../apache/mahout/classifier/df/node/Leaf.java | 94 -
.../apache/mahout/classifier/df/node/Node.java | 95 -
.../classifier/df/node/NumericalNode.java | 114 -
.../classifier/df/ref/SequentialBuilder.java | 77 -
.../classifier/df/split/DefaultIgSplit.java | 117 -
.../mahout/classifier/df/split/IgSplit.java | 34 -
.../mahout/classifier/df/split/OptIgSplit.java | 231 -
.../classifier/df/split/RegressionSplit.java | 176 -
.../mahout/classifier/df/split/Split.java | 67 -
.../mahout/classifier/df/tools/Describe.java | 148 -
.../classifier/df/tools/ForestVisualizer.java | 157 -
.../mahout/classifier/df/tools/Frequencies.java | 121 -
.../classifier/df/tools/FrequenciesJob.java | 296 -
.../classifier/df/tools/TreeVisualizer.java | 263 -
.../mahout/classifier/df/tools/UDistrib.java | 211 -
.../mahout/classifier/evaluation/Auc.java | 233 -
.../classifier/mlp/MultilayerPerceptron.java | 90 -
.../mahout/classifier/mlp/NeuralNetwork.java | 743 -
.../classifier/mlp/NeuralNetworkFunctions.java | 150 -
.../classifier/mlp/RunMultilayerPerceptron.java | 227 -
.../mlp/TrainMultilayerPerceptron.java | 332 -
.../AbstractNaiveBayesClassifier.java | 82 -
.../classifier/naivebayes/BayesUtils.java | 167 -
.../ComplementaryNaiveBayesClassifier.java | 43 -
.../classifier/naivebayes/NaiveBayesModel.java | 176 -
.../StandardNaiveBayesClassifier.java | 40 -
.../naivebayes/test/BayesTestMapper.java | 76 -
.../naivebayes/test/TestNaiveBayesDriver.java | 179 -
.../training/ComplementaryThetaTrainer.java | 83 -
.../training/IndexInstancesMapper.java | 53 -
.../naivebayes/training/ThetaMapper.java | 61 -
.../naivebayes/training/TrainNaiveBayesJob.java | 186 -
.../naivebayes/training/WeightsMapper.java | 68 -
.../sequencelearning/hmm/BaumWelchTrainer.java | 165 -
.../sequencelearning/hmm/HmmAlgorithms.java | 306 -
.../sequencelearning/hmm/HmmEvaluator.java | 194 -
.../sequencelearning/hmm/HmmModel.java | 383 -
.../sequencelearning/hmm/HmmTrainer.java | 488 -
.../sequencelearning/hmm/HmmUtils.java | 361 -
.../hmm/LossyHmmSerializer.java | 62 -
.../hmm/RandomSequenceGenerator.java | 108 -
.../sequencelearning/hmm/ViterbiEvaluator.java | 127 -
.../sgd/AbstractOnlineLogisticRegression.java | 317 -
.../sgd/AdaptiveLogisticRegression.java | 586 -
.../mahout/classifier/sgd/CrossFoldLearner.java | 334 -
.../mahout/classifier/sgd/CsvRecordFactory.java | 393 -
.../mahout/classifier/sgd/DefaultGradient.java | 49 -
.../mahout/classifier/sgd/ElasticBandPrior.java | 76 -
.../apache/mahout/classifier/sgd/Gradient.java | 30 -
.../mahout/classifier/sgd/GradientMachine.java | 405 -
.../org/apache/mahout/classifier/sgd/L1.java | 59 -
.../org/apache/mahout/classifier/sgd/L2.java | 66 -
.../mahout/classifier/sgd/MixedGradient.java | 66 -
.../mahout/classifier/sgd/ModelDissector.java | 232 -
.../mahout/classifier/sgd/ModelSerializer.java | 76 -
.../sgd/OnlineLogisticRegression.java | 172 -
.../classifier/sgd/PassiveAggressive.java | 204 -
.../classifier/sgd/PolymorphicWritable.java | 46 -
.../mahout/classifier/sgd/PriorFunction.java | 45 -
.../mahout/classifier/sgd/RankingGradient.java | 85 -
.../mahout/classifier/sgd/RecordFactory.java | 47 -
.../apache/mahout/classifier/sgd/TPrior.java | 61 -
.../mahout/classifier/sgd/UniformPrior.java | 47 -
.../mahout/classifier/sgd/package-info.java | 23 -
.../mahout/clustering/AbstractCluster.java | 391 -
.../org/apache/mahout/clustering/Cluster.java | 90 -
.../mahout/clustering/ClusteringUtils.java | 305 -
.../mahout/clustering/GaussianAccumulator.java | 62 -
.../org/apache/mahout/clustering/Model.java | 93 -
.../mahout/clustering/ModelDistribution.java | 41 -
.../clustering/OnlineGaussianAccumulator.java | 107 -
.../RunningSumsGaussianAccumulator.java | 90 -
.../clustering/UncommonDistributions.java | 136 -
.../apache/mahout/clustering/canopy/Canopy.java | 60 -
.../clustering/canopy/CanopyClusterer.java | 220 -
.../clustering/canopy/CanopyConfigKeys.java | 70 -
.../mahout/clustering/canopy/CanopyDriver.java | 379 -
.../mahout/clustering/canopy/CanopyMapper.java | 66 -
.../mahout/clustering/canopy/CanopyReducer.java | 70 -
.../ClusterClassificationConfigKeys.java | 33 -
.../classify/ClusterClassificationDriver.java | 313 -
.../classify/ClusterClassificationMapper.java | 161 -
.../clustering/classify/ClusterClassifier.java | 240 -
.../WeightedPropertyVectorWritable.java | 95 -
.../classify/WeightedVectorWritable.java | 72 -
.../fuzzykmeans/FuzzyKMeansClusterer.java | 59 -
.../fuzzykmeans/FuzzyKMeansDriver.java | 324 -
.../clustering/fuzzykmeans/FuzzyKMeansUtil.java | 76 -
.../clustering/fuzzykmeans/SoftCluster.java | 60 -
.../iterator/AbstractClusteringPolicy.java | 72 -
.../mahout/clustering/iterator/CIMapper.java | 71 -
.../mahout/clustering/iterator/CIReducer.java | 64 -
.../iterator/CanopyClusteringPolicy.java | 52 -
.../clustering/iterator/ClusterIterator.java | 219 -
.../clustering/iterator/ClusterWritable.java | 56 -
.../clustering/iterator/ClusteringPolicy.java | 66 -
.../iterator/ClusteringPolicyWritable.java | 55 -
.../iterator/DistanceMeasureCluster.java | 91 -
.../iterator/FuzzyKMeansClusteringPolicy.java | 91 -
.../iterator/KMeansClusteringPolicy.java | 64 -
.../clustering/kernel/IKernelProfile.java | 27 -
.../kernel/TriangularKernelProfile.java | 27 -
.../mahout/clustering/kmeans/KMeansDriver.java | 257 -
.../mahout/clustering/kmeans/KMeansUtil.java | 74 -
.../mahout/clustering/kmeans/Kluster.java | 117 -
.../clustering/kmeans/RandomSeedGenerator.java | 140 -
.../mahout/clustering/kmeans/package-info.java | 5 -
.../lda/cvb/CVB0DocInferenceMapper.java | 51 -
.../mahout/clustering/lda/cvb/CVB0Driver.java | 536 -
.../CVB0TopicTermVectorNormalizerMapper.java | 38 -
.../clustering/lda/cvb/CachingCVB0Mapper.java | 133 -
.../lda/cvb/CachingCVB0PerplexityMapper.java | 108 -
.../cvb/InMemoryCollapsedVariationalBayes0.java | 515 -
.../mahout/clustering/lda/cvb/ModelTrainer.java | 301 -
.../mahout/clustering/lda/cvb/TopicModel.java | 513 -
.../apache/mahout/clustering/package-info.java | 13 -
.../spectral/AffinityMatrixInputJob.java | 84 -
.../spectral/AffinityMatrixInputMapper.java | 78 -
.../spectral/AffinityMatrixInputReducer.java | 59 -
.../spectral/IntDoublePairWritable.java | 75 -
.../apache/mahout/clustering/spectral/Keys.java | 31 -
.../spectral/MatrixDiagonalizeJob.java | 108 -
.../clustering/spectral/UnitVectorizerJob.java | 79 -
.../mahout/clustering/spectral/VectorCache.java | 123 -
.../spectral/VectorMatrixMultiplicationJob.java | 139 -
.../clustering/spectral/VertexWritable.java | 101 -
.../spectral/kmeans/EigenSeedGenerator.java | 124 -
.../spectral/kmeans/SpectralKMeansDriver.java | 243 -
.../streaming/cluster/BallKMeans.java | 456 -
.../streaming/cluster/StreamingKMeans.java | 368 -
.../streaming/mapreduce/CentroidWritable.java | 88 -
.../mapreduce/StreamingKMeansDriver.java | 493 -
.../mapreduce/StreamingKMeansMapper.java | 102 -
.../mapreduce/StreamingKMeansReducer.java | 109 -
.../mapreduce/StreamingKMeansThread.java | 92 -
.../mapreduce/StreamingKMeansUtilsMR.java | 163 -
.../streaming/tools/ResplitSequenceFiles.java | 149 -
.../clustering/topdown/PathDirectory.java | 94 -
.../postprocessor/ClusterCountReader.java | 103 -
.../ClusterOutputPostProcessor.java | 139 -
.../ClusterOutputPostProcessorDriver.java | 182 -
.../ClusterOutputPostProcessorMapper.java | 58 -
.../ClusterOutputPostProcessorReducer.java | 62 -
.../org/apache/mahout/common/AbstractJob.java | 658 -
.../org/apache/mahout/common/ClassUtils.java | 61 -
.../apache/mahout/common/CommandLineUtil.java | 68 -
.../org/apache/mahout/common/HadoopUtil.java | 442 -
.../java/org/apache/mahout/common/IOUtils.java | 194 -
.../apache/mahout/common/IntPairWritable.java | 270 -
.../org/apache/mahout/common/IntegerTuple.java | 176 -
.../java/org/apache/mahout/common/LongPair.java | 80 -
.../org/apache/mahout/common/MemoryUtil.java | 99 -
.../java/org/apache/mahout/common/Pair.java | 99 -
.../org/apache/mahout/common/Parameters.java | 98 -
.../org/apache/mahout/common/StringTuple.java | 177 -
.../org/apache/mahout/common/StringUtils.java | 63 -
.../apache/mahout/common/TimingStatistics.java | 154 -
.../commandline/DefaultOptionCreator.java | 417 -
.../distance/ChebyshevDistanceMeasure.java | 63 -
.../common/distance/CosineDistanceMeasure.java | 119 -
.../mahout/common/distance/DistanceMeasure.java | 48 -
.../distance/EuclideanDistanceMeasure.java | 41 -
.../distance/MahalanobisDistanceMeasure.java | 204 -
.../distance/ManhattanDistanceMeasure.java | 70 -
.../distance/MinkowskiDistanceMeasure.java | 93 -
.../SquaredEuclideanDistanceMeasure.java | 59 -
.../distance/TanimotoDistanceMeasure.java | 69 -
.../distance/WeightedDistanceMeasure.java | 97 -
.../WeightedEuclideanDistanceMeasure.java | 52 -
.../WeightedManhattanDistanceMeasure.java | 53 -
.../iterator/CopyConstructorIterator.java | 64 -
.../common/iterator/CountingIterator.java | 43 -
.../common/iterator/FileLineIterable.java | 88 -
.../common/iterator/FileLineIterator.java | 167 -
.../iterator/FixedSizeSamplingIterator.java | 59 -
.../common/iterator/SamplingIterable.java | 45 -
.../common/iterator/SamplingIterator.java | 73 -
.../StableFixedSizeSamplingIterator.java | 72 -
.../common/iterator/StringRecordIterator.java | 55 -
.../iterator/sequencefile/PathFilters.java | 81 -
.../common/iterator/sequencefile/PathType.java | 27 -
.../sequencefile/SequenceFileDirIterable.java | 84 -
.../sequencefile/SequenceFileDirIterator.java | 136 -
.../SequenceFileDirValueIterable.java | 83 -
.../SequenceFileDirValueIterator.java | 159 -
.../sequencefile/SequenceFileIterable.java | 68 -
.../sequencefile/SequenceFileIterator.java | 118 -
.../sequencefile/SequenceFileValueIterable.java | 67 -
.../sequencefile/SequenceFileValueIterator.java | 97 -
.../mahout/common/lucene/AnalyzerUtils.java | 61 -
.../common/lucene/IteratorTokenStream.java | 45 -
.../common/lucene/TokenStreamIterator.java | 57 -
.../common/mapreduce/MergeVectorsCombiner.java | 34 -
.../common/mapreduce/MergeVectorsReducer.java | 40 -
.../common/mapreduce/TransposeMapper.java | 49 -
.../common/mapreduce/VectorSumCombiner.java | 38 -
.../common/mapreduce/VectorSumReducer.java | 35 -
.../org/apache/mahout/common/nlp/NGrams.java | 94 -
.../common/parameters/AbstractParameter.java | 120 -
.../common/parameters/ClassParameter.java | 44 -
.../common/parameters/DoubleParameter.java | 33 -
.../mahout/common/parameters/Parameter.java | 62 -
.../mahout/common/parameters/Parametered.java | 206 -
.../mahout/common/parameters/PathParameter.java | 33 -
.../org/apache/mahout/driver/MahoutDriver.java | 244 -
.../apache/mahout/ep/EvolutionaryProcess.java | 228 -
.../main/java/org/apache/mahout/ep/Mapping.java | 206 -
.../main/java/org/apache/mahout/ep/Payload.java | 36 -
.../main/java/org/apache/mahout/ep/State.java | 302 -
.../java/org/apache/mahout/ep/package-info.java | 26 -
.../mahout/math/DistributedRowMatrixWriter.java | 47 -
.../org/apache/mahout/math/MatrixUtils.java | 114 -
.../org/apache/mahout/math/MatrixWritable.java | 202 -
.../mahout/math/MultiLabelVectorWritable.java | 88 -
.../org/apache/mahout/math/VarIntWritable.java | 86 -
.../org/apache/mahout/math/VarLongWritable.java | 83 -
.../java/org/apache/mahout/math/Varint.java | 167 -
.../org/apache/mahout/math/VectorWritable.java | 267 -
.../math/hadoop/DistributedRowMatrix.java | 385 -
.../math/hadoop/MatrixColumnMeansJob.java | 236 -
.../math/hadoop/MatrixMultiplicationJob.java | 177 -
.../mahout/math/hadoop/TimesSquaredJob.java | 251 -
.../apache/mahout/math/hadoop/TransposeJob.java | 85 -
.../decomposer/DistributedLanczosSolver.java | 298 -
.../math/hadoop/decomposer/EigenVector.java | 76 -
.../hadoop/decomposer/EigenVerificationJob.java | 332 -
.../decomposer/HdfsBackedLanczosState.java | 237 -
.../math/hadoop/similarity/SeedVectorUtil.java | 104 -
.../VectorDistanceInvertedMapper.java | 71 -
.../hadoop/similarity/VectorDistanceMapper.java | 80 -
.../similarity/VectorDistanceSimilarityJob.java | 153 -
.../similarity/cooccurrence/MutableElement.java | 50 -
.../cooccurrence/RowSimilarityJob.java | 562 -
.../cooccurrence/TopElementsQueue.java | 59 -
.../hadoop/similarity/cooccurrence/Vectors.java | 199 -
.../measures/CityBlockSimilarity.java | 26 -
.../measures/CooccurrenceCountSimilarity.java | 32 -
.../cooccurrence/measures/CosineSimilarity.java | 50 -
.../measures/CountbasedMeasure.java | 44 -
.../measures/EuclideanDistanceSimilarity.java | 57 -
.../measures/LoglikelihoodSimilarity.java | 34 -
.../measures/PearsonCorrelationSimilarity.java | 37 -
.../measures/TanimotoCoefficientSimilarity.java | 34 -
.../measures/VectorSimilarityMeasure.java | 32 -
.../measures/VectorSimilarityMeasures.java | 46 -
.../DistributedConjugateGradientSolver.java | 172 -
.../mahout/math/hadoop/stats/BasicStats.java | 148 -
.../StandardDeviationCalculatorMapper.java | 55 -
.../StandardDeviationCalculatorReducer.java | 37 -
.../math/hadoop/stats/VarianceTotals.java | 68 -
.../hadoop/stochasticsvd/ABtDenseOutJob.java | 585 -
.../math/hadoop/stochasticsvd/ABtJob.java | 494 -
.../mahout/math/hadoop/stochasticsvd/BtJob.java | 628 -
.../stochasticsvd/DenseBlockWritable.java | 83 -
.../mahout/math/hadoop/stochasticsvd/Omega.java | 257 -
.../mahout/math/hadoop/stochasticsvd/QJob.java | 237 -
.../math/hadoop/stochasticsvd/SSVDCli.java | 201 -
.../math/hadoop/stochasticsvd/SSVDHelper.java | 322 -
.../math/hadoop/stochasticsvd/SSVDSolver.java | 662 -
.../SparseRowBlockAccumulator.java | 90 -
.../stochasticsvd/SparseRowBlockWritable.java | 159 -
.../stochasticsvd/SplitPartitionedWritable.java | 151 -
.../mahout/math/hadoop/stochasticsvd/UJob.java | 170 -
.../mahout/math/hadoop/stochasticsvd/VJob.java | 224 -
.../math/hadoop/stochasticsvd/YtYJob.java | 220 -
.../stochasticsvd/qr/GivensThinSolver.java | 638 -
.../hadoop/stochasticsvd/qr/GramSchmidt.java | 52 -
.../hadoop/stochasticsvd/qr/QRFirstStep.java | 284 -
.../hadoop/stochasticsvd/qr/QRLastStep.java | 144 -
.../mahout/math/neighborhood/BruteSearch.java | 186 -
.../math/neighborhood/FastProjectionSearch.java | 326 -
.../mahout/math/neighborhood/HashedVector.java | 103 -
.../LocalitySensitiveHashSearch.java | 295 -
.../math/neighborhood/ProjectionSearch.java | 233 -
.../mahout/math/neighborhood/Searcher.java | 155 -
.../math/neighborhood/UpdatableSearcher.java | 37 -
.../mahout/math/random/RandomProjector.java | 133 -
.../math/ssvd/SequentialOutOfCoreSvd.java | 233 -
.../mahout/math/stats/GlobalOnlineAuc.java | 168 -
.../mahout/math/stats/GroupedOnlineAuc.java | 113 -
.../org/apache/mahout/math/stats/OnlineAuc.java | 38 -
.../org/apache/mahout/math/stats/Sampler.java | 79 -
.../mahout/vectorizer/DictionaryVectorizer.java | 416 -
.../mahout/vectorizer/DocumentProcessor.java | 99 -
.../EncodedVectorsFromSequenceFiles.java | 104 -
.../mahout/vectorizer/EncodingMapper.java | 92 -
.../mahout/vectorizer/HighDFWordsPruner.java | 147 -
.../SimpleTextEncodingVectorizer.java | 72 -
.../SparseVectorsFromSequenceFiles.java | 369 -
.../java/org/apache/mahout/vectorizer/TF.java | 30 -
.../org/apache/mahout/vectorizer/TFIDF.java | 31 -
.../apache/mahout/vectorizer/Vectorizer.java | 29 -
.../mahout/vectorizer/VectorizerConfig.java | 179 -
.../org/apache/mahout/vectorizer/Weight.java | 32 -
.../collocations/llr/CollocCombiner.java | 46 -
.../collocations/llr/CollocDriver.java | 284 -
.../collocations/llr/CollocMapper.java | 178 -
.../collocations/llr/CollocReducer.java | 176 -
.../vectorizer/collocations/llr/Gram.java | 239 -
.../vectorizer/collocations/llr/GramKey.java | 133 -
.../llr/GramKeyGroupComparator.java | 43 -
.../collocations/llr/GramKeyPartitioner.java | 40 -
.../vectorizer/collocations/llr/LLRReducer.java | 170 -
.../common/PartialVectorMergeReducer.java | 89 -
.../vectorizer/common/PartialVectorMerger.java | 144 -
.../document/SequenceFileTokenizerMapper.java | 70 -
.../encoders/AdaptiveWordValueEncoder.java | 69 -
.../encoders/CachingContinuousValueEncoder.java | 64 -
.../encoders/CachingStaticWordValueEncoder.java | 66 -
.../encoders/CachingTextValueEncoder.java | 25 -
.../encoders/CachingValueEncoder.java | 64 -
.../encoders/ConstantValueEncoder.java | 57 -
.../encoders/ContinuousValueEncoder.java | 76 -
.../mahout/vectorizer/encoders/Dictionary.java | 55 -
.../encoders/FeatureVectorEncoder.java | 279 -
.../encoders/InteractionValueEncoder.java | 126 -
.../encoders/LuceneTextValueEncoder.java | 133 -
.../encoders/StaticWordValueEncoder.java | 80 -
.../vectorizer/encoders/TextValueEncoder.java | 142 -
.../vectorizer/encoders/WordValueEncoder.java | 81 -
.../pruner/PrunedPartialVectorMergeReducer.java | 65 -
.../vectorizer/pruner/WordsPrunerReducer.java | 86 -
.../vectorizer/term/TFPartialVectorReducer.java | 139 -
.../vectorizer/term/TermCountCombiner.java | 41 -
.../mahout/vectorizer/term/TermCountMapper.java | 58 -
.../vectorizer/term/TermCountReducer.java | 55 -
.../term/TermDocumentCountMapper.java | 50 -
.../term/TermDocumentCountReducer.java | 41 -
.../mahout/vectorizer/tfidf/TFIDFConverter.java | 361 -
.../tfidf/TFIDFPartialVectorReducer.java | 114 -
mrlegacy/src/main/resources/version | 1 -
.../mahout/cf/taste/common/CommonTest.java | 60 -
.../cf/taste/hadoop/TasteHadoopUtilsTest.java | 40 -
.../cf/taste/hadoop/TopItemsQueueTest.java | 72 -
.../als/ParallelALSFactorizationJobTest.java | 379 -
.../cf/taste/hadoop/item/IDReaderTest.java | 66 -
.../taste/hadoop/item/RecommenderJobTest.java | 928 -
.../hadoop/item/ToUserVectorsReducerTest.java | 74 -
.../similarity/item/ItemSimilarityJobTest.java | 269 -
.../mahout/cf/taste/impl/TasteTestCase.java | 98 -
.../mahout/cf/taste/impl/common/BitSetTest.java | 74 -
.../mahout/cf/taste/impl/common/CacheTest.java | 61 -
.../cf/taste/impl/common/FastByIDMapTest.java | 147 -
.../cf/taste/impl/common/FastIDSetTest.java | 162 -
.../cf/taste/impl/common/FastMapTest.java | 228 -
.../impl/common/InvertedRunningAverageTest.java | 88 -
.../common/LongPrimitiveArrayIteratorTest.java | 56 -
.../cf/taste/impl/common/MockRefreshable.java | 45 -
.../cf/taste/impl/common/RefreshHelperTest.java | 70 -
.../common/RunningAverageAndStdDevTest.java | 107 -
.../taste/impl/common/RunningAverageTest.java | 75 -
.../SamplingLongPrimitiveIteratorTest.java | 91 -
.../impl/common/WeightedRunningAverageTest.java | 85 -
...ericRecommenderIRStatsEvaluatorImplTest.java | 73 -
.../taste/impl/eval/LoadEvaluationRunner.java | 68 -
.../model/BooleanItemPreferenceArrayTest.java | 89 -
.../model/BooleanUserPreferenceArrayTest.java | 89 -
.../taste/impl/model/GenericDataModelTest.java | 51 -
.../model/GenericItemPreferenceArrayTest.java | 110 -
.../model/GenericUserPreferenceArrayTest.java | 110 -
.../taste/impl/model/MemoryIDMigratorTest.java | 57 -
...lusAnonymousConcurrentUserDataModelTest.java | 313 -
.../impl/model/file/FileDataModelTest.java | 216 -
.../impl/model/file/FileIDMigratorTest.java | 103 -
.../impl/neighborhood/DummySimilarity.java | 68 -
.../neighborhood/NearestNNeighborhoodTest.java | 53 -
.../neighborhood/ThresholdNeighborhoodTest.java | 51 -
...lUnknownItemsCandidateItemsStrategyTest.java | 65 -
.../recommender/CachingRecommenderTest.java | 78 -
.../GenericItemBasedRecommenderTest.java | 324 -
.../GenericUserBasedRecommenderTest.java | 174 -
.../recommender/ItemAverageRecommenderTest.java | 43 -
.../ItemUserAverageRecommenderTest.java | 43 -
.../taste/impl/recommender/MockRecommender.java | 89 -
.../impl/recommender/NullRescorerTest.java | 47 -
...sNeighborhoodCandidateItemsStrategyTest.java | 75 -
.../impl/recommender/RandomRecommenderTest.java | 41 -
.../impl/recommender/ReversingRescorer.java | 46 -
.../SamplingCandidateItemsStrategyTest.java | 71 -
.../cf/taste/impl/recommender/TopItemsTest.java | 158 -
.../recommender/svd/ALSWRFactorizerTest.java | 208 -
.../svd/FilePersistenceStrategyTest.java | 53 -
.../svd/ParallelSGDFactorizerTest.java | 355 -
.../recommender/svd/SVDRecommenderTest.java | 86 -
.../AveragingPreferenceInferrerTest.java | 37 -
.../EuclideanDistanceSimilarityTest.java | 236 -
.../similarity/GenericItemSimilarityTest.java | 104 -
.../similarity/LogLikelihoodSimilarityTest.java | 80 -
.../PearsonCorrelationSimilarityTest.java | 265 -
.../impl/similarity/SimilarityTestCase.java | 35 -
.../SpearmanCorrelationSimilarityTest.java | 80 -
.../TanimotoCoefficientSimilarityTest.java | 121 -
.../similarity/file/FileItemSimilarityTest.java | 142 -
.../MultithreadedBatchItemSimilaritiesTest.java | 80 -
.../mahout/classifier/ClassifierData.java | 102 -
.../mahout/classifier/ConfusionMatrixTest.java | 119 -
.../RegressionResultAnalyzerTest.java | 128 -
.../classifier/df/DecisionForestTest.java | 206 -
.../df/builder/DecisionTreeBuilderTest.java | 78 -
.../df/builder/DefaultTreeBuilderTest.java | 74 -
.../df/builder/InfiniteRecursionTest.java | 60 -
.../classifier/df/data/DataConverterTest.java | 60 -
.../classifier/df/data/DataLoaderTest.java | 350 -
.../mahout/classifier/df/data/DataTest.java | 396 -
.../mahout/classifier/df/data/DatasetTest.java | 72 -
.../classifier/df/data/DescriptorUtilsTest.java | 92 -
.../apache/mahout/classifier/df/data/Utils.java | 283 -
.../mapreduce/inmem/InMemInputFormatTest.java | 109 -
.../df/mapreduce/inmem/InMemInputSplitTest.java | 77 -
.../mapreduce/partial/PartialBuilderTest.java | 197 -
.../df/mapreduce/partial/Step1MapperTest.java | 160 -
.../df/mapreduce/partial/TreeIDTest.java | 48 -
.../mahout/classifier/df/node/NodeTest.java | 108 -
.../classifier/df/split/DefaultIgSplitTest.java | 78 -
.../df/split/RegressionSplitTest.java | 87 -
.../classifier/df/tools/VisualizerTest.java | 211 -
.../mahout/classifier/evaluation/AucTest.java | 86 -
.../apache/mahout/classifier/mlp/Datasets.java | 866 -
.../mlp/RunMultilayerPerceptronTest.java | 66 -
.../mlp/TestMultilayerPerceptron.java | 88 -
.../classifier/mlp/TestNeuralNetwork.java | 353 -
.../mlp/TrainMultilayerPerceptronTest.java | 105 -
.../ComplementaryNaiveBayesClassifierTest.java | 47 -
.../naivebayes/NaiveBayesModelTest.java | 36 -
.../classifier/naivebayes/NaiveBayesTest.java | 135 -
.../naivebayes/NaiveBayesTestBase.java | 135 -
.../StandardNaiveBayesClassifierTest.java | 47 -
.../training/IndexInstancesMapperTest.java | 85 -
.../naivebayes/training/ThetaMapperTest.java | 61 -
.../naivebayes/training/WeightsMapperTest.java | 60 -
.../sequencelearning/hmm/HMMAlgorithmsTest.java | 164 -
.../sequencelearning/hmm/HMMEvaluatorTest.java | 63 -
.../sequencelearning/hmm/HMMModelTest.java | 32 -
.../sequencelearning/hmm/HMMTestBase.java | 73 -
.../sequencelearning/hmm/HMMTrainerTest.java | 163 -
.../sequencelearning/hmm/HMMUtilsTest.java | 161 -
.../sgd/AdaptiveLogisticRegressionTest.java | 186 -
.../classifier/sgd/CsvRecordFactoryTest.java | 90 -
.../classifier/sgd/GradientMachineTest.java | 41 -
.../classifier/sgd/ModelSerializerTest.java | 162 -
.../mahout/classifier/sgd/OnlineBaseTest.java | 160 -
.../sgd/OnlineLogisticRegressionTest.java | 330 -
.../classifier/sgd/PassiveAggressiveTest.java | 35 -
.../mahout/clustering/ClusteringTestUtils.java | 152 -
.../mahout/clustering/TestClusterInterface.java | 83 -
.../clustering/TestGaussianAccumulators.java | 186 -
.../clustering/canopy/TestCanopyCreation.java | 674 -
.../ClusterClassificationDriverTest.java | 255 -
.../fuzzykmeans/TestFuzzyKmeansClustering.java | 202 -
.../iterator/TestClusterClassifier.java | 238 -
.../clustering/kmeans/TestKmeansClustering.java | 385 -
.../kmeans/TestRandomSeedGenerator.java | 169 -
.../clustering/lda/cvb/TestCVBModelTrainer.java | 138 -
.../spectral/TestAffinityMatrixInputJob.java | 145 -
.../spectral/TestMatrixDiagonalizeJob.java | 116 -
.../spectral/TestUnitVectorizerJob.java | 65 -
.../clustering/spectral/TestVectorCache.java | 110 -
.../TestVectorMatrixMultiplicationJob.java | 75 -
.../spectral/kmeans/TestEigenSeedGenerator.java | 100 -
.../streaming/cluster/BallKMeansTest.java | 196 -
.../clustering/streaming/cluster/DataUtils.java | 92 -
.../streaming/cluster/StreamingKMeansTest.java | 169 -
.../mapreduce/StreamingKMeansTestMR.java | 283 -
.../tools/ResplitSequenceFilesTest.java | 80 -
.../clustering/topdown/PathDirectoryTest.java | 65 -
.../postprocessor/ClusterCountReaderTest.java | 121 -
.../ClusterOutputPostProcessorTest.java | 205 -
.../apache/mahout/common/AbstractJobTest.java | 240 -
.../DistributedCacheFileLocationTest.java | 46 -
.../mahout/common/DummyOutputCollector.java | 57 -
.../apache/mahout/common/DummyRecordWriter.java | 223 -
.../mahout/common/DummyRecordWriterTest.java | 45 -
.../mahout/common/DummyStatusReporter.java | 76 -
.../mahout/common/IntPairWritableTest.java | 114 -
.../apache/mahout/common/MahoutTestCase.java | 148 -
.../org/apache/mahout/common/MockIterator.java | 51 -
.../apache/mahout/common/StringUtilsTest.java | 70 -
.../distance/CosineDistanceMeasureTest.java | 66 -
.../distance/DefaultDistanceMeasureTest.java | 103 -
.../DefaultWeightedDistanceMeasureTest.java | 56 -
.../common/distance/TestChebyshevMeasure.java | 55 -
.../distance/TestEuclideanDistanceMeasure.java | 26 -
.../TestMahalanobisDistanceMeasure.java | 56 -
.../distance/TestManhattanDistanceMeasure.java | 26 -
.../common/distance/TestMinkowskiMeasure.java | 64 -
.../distance/TestTanimotoDistanceMeasure.java | 25 -
...estWeightedEuclideanDistanceMeasureTest.java | 25 -
.../TestWeightedManhattanDistanceMeasure.java | 26 -
.../common/iterator/CountingIteratorTest.java | 44 -
.../mahout/common/iterator/SamplerCase.java | 101 -
.../common/iterator/TestFixedSizeSampler.java | 33 -
.../common/iterator/TestSamplingIterator.java | 77 -
.../iterator/TestStableFixedSizeSampler.java | 33 -
.../mahout/common/lucene/AnalyzerUtilsTest.java | 38 -
.../apache/mahout/driver/MahoutDriverTest.java | 32 -
.../mahout/ep/EvolutionaryProcessTest.java | 81 -
.../apache/mahout/math/MatrixWritableTest.java | 148 -
.../java/org/apache/mahout/math/VarintTest.java | 189 -
.../apache/mahout/math/VectorWritableTest.java | 123 -
.../apache/mahout/math/hadoop/MathHelper.java | 236 -
.../math/hadoop/TestDistributedRowMatrix.java | 395 -
.../TestDistributedLanczosSolver.java | 132 -
.../TestDistributedLanczosSolverCLI.java | 190 -
.../TestVectorDistanceSimilarityJob.java | 238 -
.../cooccurrence/RowSimilarityJobTest.java | 214 -
.../measures/VectorSimilarityMeasuresTest.java | 133 -
.../TestDistributedConjugateGradientSolver.java | 59 -
...stDistributedConjugateGradientSolverCLI.java | 111 -
.../math/hadoop/stats/BasicStatsTest.java | 121 -
.../stochasticsvd/LocalSSVDPCASparseTest.java | 296 -
.../stochasticsvd/LocalSSVDSolverDenseTest.java | 206 -
.../LocalSSVDSolverSparseSequentialTest.java | 209 -
.../hadoop/stochasticsvd/SSVDCommonTest.java | 105 -
.../hadoop/stochasticsvd/SSVDTestsHelper.java | 172 -
.../LocalitySensitiveHashSearchTest.java | 119 -
.../mahout/math/neighborhood/LumpyData.java | 77 -
.../math/neighborhood/SearchQualityTest.java | 178 -
.../math/neighborhood/SearchSanityTest.java | 244 -
.../math/ssvd/SequentialOutOfCoreSvdTest.java | 195 -
.../apache/mahout/math/stats/OnlineAucTest.java | 127 -
.../apache/mahout/math/stats/SamplerTest.java | 45 -
.../vectorizer/DictionaryVectorizerTest.java | 220 -
.../vectorizer/DocumentProcessorTest.java | 81 -
.../EncodedVectorsFromSequenceFilesTest.java | 126 -
.../vectorizer/HighDFWordsPrunerTest.java | 154 -
.../vectorizer/RandomDocumentGenerator.java | 69 -
.../SparseVectorsFromSequenceFilesTest.java | 203 -
.../collocations/llr/CollocMapperTest.java | 180 -
.../collocations/llr/CollocReducerTest.java | 86 -
.../llr/GramKeyGroupComparatorTest.java | 45 -
.../llr/GramKeyPartitionerTest.java | 54 -
.../collocations/llr/GramKeyTest.java | 106 -
.../vectorizer/collocations/llr/GramTest.java | 216 -
.../collocations/llr/LLRReducerTest.java | 116 -
.../vectorizer/encoders/CachingEncoderTest.java | 48 -
.../encoders/ConstantValueEncoderTest.java | 74 -
.../encoders/ContinuousValueEncoderTest.java | 88 -
.../encoders/InteractionValueEncoderTest.java | 103 -
.../encoders/TextValueEncoderTest.java | 100 -
.../encoders/WordLikeValueEncoderTest.java | 99 -
mrlegacy/src/test/resources/FPGsynth.dat | 193 -
mrlegacy/src/test/resources/cancer.csv | 684 -
mrlegacy/src/test/resources/iris.csv | 151 -
mrlegacy/src/test/resources/retail.dat | 88162 -----------------
.../retail_results_with_min_sup_100.dat | 6438 --
mrlegacy/src/test/resources/sgd.csv | 61 -
mrlegacy/src/test/resources/word-list.txt | 512 -
pom.xml | 22 +-
spark/pom.xml | 105 +-
.../io/MahoutKryoRegistrator.scala | 1 -
.../apache/mahout/sparkbindings/package.scala | 8 +-
1563 files changed, 201330 insertions(+), 200686 deletions(-)
----------------------------------------------------------------------
[08/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
Posted by pa...@apache.org.
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java
new file mode 100644
index 0000000..dce23db
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java
@@ -0,0 +1,350 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Collection;
+import java.util.Random;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+import org.junit.Test;
+
+public final class DataLoaderTest extends MahoutTestCase {
+
+ private Random rng;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+ }
+
+ @Test
+ public void testLoadDataWithDescriptor() throws Exception {
+ int nbAttributes = 10;
+ int datasize = 100;
+
+ // prepare the descriptors
+ String descriptor = Utils.randomDescriptor(rng, nbAttributes);
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+ // prepare the data
+ double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize);
+ Collection<Integer> missings = Lists.newArrayList();
+ String[] sData = prepareData(data, attrs, missings);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+ Data loaded = DataLoader.loadData(dataset, sData);
+
+ testLoadedData(data, attrs, missings, loaded);
+ testLoadedDataset(data, attrs, missings, loaded);
+
+ // regression
+ data = Utils.randomDoubles(rng, descriptor, true, datasize);
+ missings = Lists.newArrayList();
+ sData = prepareData(data, attrs, missings);
+ dataset = DataLoader.generateDataset(descriptor, true, sData);
+ loaded = DataLoader.loadData(dataset, sData);
+
+ testLoadedData(data, attrs, missings, loaded);
+ testLoadedDataset(data, attrs, missings, loaded);
+ }
+
+ /**
+ * Test method for
+ * {@link DataLoader#generateDataset(CharSequence, boolean, String[])}
+ */
+ @Test
+ public void testGenerateDataset() throws Exception {
+ int nbAttributes = 10;
+ int datasize = 100;
+
+ // prepare the descriptors
+ String descriptor = Utils.randomDescriptor(rng, nbAttributes);
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+ // prepare the data
+ double[][] data = Utils.randomDoubles(rng, descriptor, false, datasize);
+ Collection<Integer> missings = Lists.newArrayList();
+ String[] sData = prepareData(data, attrs, missings);
+ Dataset expected = DataLoader.generateDataset(descriptor, false, sData);
+
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+
+ assertEquals(expected, dataset);
+
+ // regression
+ data = Utils.randomDoubles(rng, descriptor, true, datasize);
+ missings = Lists.newArrayList();
+ sData = prepareData(data, attrs, missings);
+ expected = DataLoader.generateDataset(descriptor, true, sData);
+
+ dataset = DataLoader.generateDataset(descriptor, true, sData);
+
+ assertEquals(expected, dataset);
+}
+
+ /**
+ * Converts the data to an array of comma-separated strings and adds some
+ * missing values in all but IGNORED attributes
+ *
+ * @param missings indexes of vectors with missing values
+ */
+ private String[] prepareData(double[][] data, Attribute[] attrs, Collection<Integer> missings) {
+ int nbAttributes = attrs.length;
+
+ String[] sData = new String[data.length];
+
+ for (int index = 0; index < data.length; index++) {
+ int missingAttr;
+ if (rng.nextDouble() < 0.0) {
+ // add a missing value
+ missings.add(index);
+
+ // choose a random attribute (not IGNORED)
+ do {
+ missingAttr = rng.nextInt(nbAttributes);
+ } while (attrs[missingAttr].isIgnored());
+ } else {
+ missingAttr = -1;
+ }
+
+ StringBuilder builder = new StringBuilder();
+
+ for (int attr = 0; attr < nbAttributes; attr++) {
+ if (attr == missingAttr) {
+ // add a missing value here
+ builder.append('?').append(',');
+ } else {
+ builder.append(data[index][attr]).append(',');
+ }
+ }
+
+ sData[index] = builder.toString();
+ }
+
+ return sData;
+ }
+
+ /**
+ * Test if the loaded data matches the source data
+ *
+ * @param missings indexes of instance with missing values
+ */
+ static void testLoadedData(double[][] data, Attribute[] attrs, Collection<Integer> missings, Data loaded) {
+ int nbAttributes = attrs.length;
+
+ // check the vectors
+ assertEquals("number of instance", data.length - missings.size(), loaded .size());
+
+ // make sure that the attributes are loaded correctly
+ int lind = 0;
+ for (int index = 0; index < data.length; index++) {
+ if (missings.contains(index)) {
+ continue;
+ }// this vector won't be loaded
+
+ double[] vector = data[index];
+ Instance instance = loaded.get(lind);
+
+ int aId = 0;
+ for (int attr = 0; attr < nbAttributes; attr++) {
+ if (attrs[attr].isIgnored()) {
+ continue;
+ }
+
+ if (attrs[attr].isNumerical()) {
+ assertEquals(vector[attr], instance.get(aId), EPSILON);
+ aId++;
+ } else if (attrs[attr].isCategorical()) {
+ checkCategorical(data, missings, loaded, attr, aId, vector[attr],
+ instance.get(aId));
+ aId++;
+ } else if (attrs[attr].isLabel()) {
+ if (loaded.getDataset().isNumerical(aId)) {
+ assertEquals(vector[attr], instance.get(aId), EPSILON);
+ } else {
+ checkCategorical(data, missings, loaded, attr, aId, vector[attr],
+ instance.get(aId));
+ }
+ aId++;
+ }
+ }
+
+ lind++;
+ }
+
+ }
+
+ /**
+ * Test if the loaded dataset matches the source data
+ *
+ * @param missings indexes of instance with missing values
+ */
+ static void testLoadedDataset(double[][] data,
+ Attribute[] attrs,
+ Collection<Integer> missings,
+ Data loaded) {
+ int nbAttributes = attrs.length;
+
+ int iId = 0;
+ for (int index = 0; index < data.length; index++) {
+ if (missings.contains(index)) {
+ continue;
+ }
+
+ Instance instance = loaded.get(iId++);
+
+ int aId = 0;
+ for (int attr = 0; attr < nbAttributes; attr++) {
+ if (attrs[attr].isIgnored()) {
+ continue;
+ }
+
+ if (attrs[attr].isLabel()) {
+ if (!loaded.getDataset().isNumerical(aId)) {
+ double nValue = instance.get(aId);
+ String oValue = Double.toString(data[index][attr]);
+ assertEquals(loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON);
+ }
+ } else {
+ assertEquals(attrs[attr].isNumerical(), loaded.getDataset().isNumerical(aId));
+
+ if (attrs[attr].isCategorical()) {
+ double nValue = instance.get(aId);
+ String oValue = Double.toString(data[index][attr]);
+ assertEquals(loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON);
+ }
+ }
+ aId++;
+ }
+ }
+
+ }
+
+ @Test
+ public void testLoadDataFromFile() throws Exception {
+ int nbAttributes = 10;
+ int datasize = 100;
+
+ // prepare the descriptors
+ String descriptor = Utils.randomDescriptor(rng, nbAttributes);
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+ // prepare the data
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize);
+ Collection<Integer> missings = Lists.newArrayList();
+ String[] sData = prepareData(source, attrs, missings);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+
+ Path dataPath = Utils.writeDataToTestFile(sData);
+ FileSystem fs = dataPath.getFileSystem(getConfiguration());
+ Data loaded = DataLoader.loadData(dataset, fs, dataPath);
+
+ testLoadedData(source, attrs, missings, loaded);
+
+ // regression
+ source = Utils.randomDoubles(rng, descriptor, true, datasize);
+ missings = Lists.newArrayList();
+ sData = prepareData(source, attrs, missings);
+ dataset = DataLoader.generateDataset(descriptor, true, sData);
+
+ dataPath = Utils.writeDataToTestFile(sData);
+ fs = dataPath.getFileSystem(getConfiguration());
+ loaded = DataLoader.loadData(dataset, fs, dataPath);
+
+ testLoadedData(source, attrs, missings, loaded);
+}
+
+ /**
+ * Test method for
+ * {@link DataLoader#generateDataset(CharSequence, boolean, FileSystem, Path)}
+ */
+ @Test
+ public void testGenerateDatasetFromFile() throws Exception {
+ int nbAttributes = 10;
+ int datasize = 100;
+
+ // prepare the descriptors
+ String descriptor = Utils.randomDescriptor(rng, nbAttributes);
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+ // prepare the data
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, datasize);
+ Collection<Integer> missings = Lists.newArrayList();
+ String[] sData = prepareData(source, attrs, missings);
+ Dataset expected = DataLoader.generateDataset(descriptor, false, sData);
+
+ Path path = Utils.writeDataToTestFile(sData);
+ FileSystem fs = path.getFileSystem(getConfiguration());
+
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, fs, path);
+
+ assertEquals(expected, dataset);
+
+ // regression
+ source = Utils.randomDoubles(rng, descriptor, false, datasize);
+ missings = Lists.newArrayList();
+ sData = prepareData(source, attrs, missings);
+ expected = DataLoader.generateDataset(descriptor, false, sData);
+
+ path = Utils.writeDataToTestFile(sData);
+ fs = path.getFileSystem(getConfiguration());
+
+ dataset = DataLoader.generateDataset(descriptor, false, fs, path);
+
+ assertEquals(expected, dataset);
+ }
+
+ /**
+ * each time oValue appears in data for the attribute 'attr', the nValue must
+ * appear in vectors for the same attribute.
+ *
+ * @param attr attribute's index in source
+ * @param aId attribute's index in loaded
+ * @param oValue old value in source
+ * @param nValue new value in loaded
+ */
+ static void checkCategorical(double[][] source,
+ Collection<Integer> missings,
+ Data loaded,
+ int attr,
+ int aId,
+ double oValue,
+ double nValue) {
+ int lind = 0;
+
+ for (int index = 0; index < source.length; index++) {
+ if (missings.contains(index)) {
+ continue;
+ }
+
+ if (source[index][attr] == oValue) {
+ assertEquals(nValue, loaded.get(lind).get(aId), EPSILON);
+ } else {
+ assertFalse(nValue == loaded.get(lind).get(aId));
+ }
+
+ lind++;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java
new file mode 100644
index 0000000..86e4461
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java
@@ -0,0 +1,396 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.junit.Test;
+
+public class DataTest extends MahoutTestCase {
+
+ private static final int ATTRIBUTE_COUNT = 10;
+
+ private static final int DATA_SIZE = 100;
+
+ private Random rng;
+
+ private Data classifierData;
+
+ private Data regressionData;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+ classifierData = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+ regressionData = Utils.randomData(rng, ATTRIBUTE_COUNT, true, DATA_SIZE);
+ }
+
+ /**
+ * Test method for
+ * {@link org.apache.mahout.classifier.df.data.Data#subset(org.apache.mahout.classifier.df.data.conditions.Condition)}.
+ */
+ @Test
+ public void testSubset() {
+ int n = 10;
+
+ for (int nloop = 0; nloop < n; nloop++) {
+ int attr = rng.nextInt(classifierData.getDataset().nbAttributes());
+
+ double[] values = classifierData.values(attr);
+ double value = values[rng.nextInt(values.length)];
+
+ Data eSubset = classifierData.subset(Condition.equals(attr, value));
+ Data lSubset = classifierData.subset(Condition.lesser(attr, value));
+ Data gSubset = classifierData.subset(Condition.greaterOrEquals(attr, value));
+
+ for (int index = 0; index < DATA_SIZE; index++) {
+ Instance instance = classifierData.get(index);
+
+ if (instance.get(attr) < value) {
+ assertTrue(lSubset.contains(instance));
+ assertFalse(eSubset.contains(instance));
+ assertFalse(gSubset.contains(instance));
+ } else if (instance.get(attr) == value) {
+ assertFalse(lSubset.contains(instance));
+ assertTrue(eSubset.contains(instance));
+ assertTrue(gSubset.contains(instance));
+ } else {
+ assertFalse(lSubset.contains(instance));
+ assertFalse(eSubset.contains(instance));
+ assertTrue(gSubset.contains(instance));
+ }
+ }
+
+ // regression
+ attr = rng.nextInt(regressionData.getDataset().nbAttributes());
+
+ values = regressionData.values(attr);
+ value = values[rng.nextInt(values.length)];
+
+ eSubset = regressionData.subset(Condition.equals(attr, value));
+ lSubset = regressionData.subset(Condition.lesser(attr, value));
+ gSubset = regressionData.subset(Condition.greaterOrEquals(attr, value));
+
+ for (int index = 0; index < DATA_SIZE; index++) {
+ Instance instance = regressionData.get(index);
+
+ if (instance.get(attr) < value) {
+ assertTrue(lSubset.contains(instance));
+ assertFalse(eSubset.contains(instance));
+ assertFalse(gSubset.contains(instance));
+ } else if (instance.get(attr) == value) {
+ assertFalse(lSubset.contains(instance));
+ assertTrue(eSubset.contains(instance));
+ assertTrue(gSubset.contains(instance));
+ } else {
+ assertFalse(lSubset.contains(instance));
+ assertFalse(eSubset.contains(instance));
+ assertTrue(gSubset.contains(instance));
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testValues() throws Exception {
+ for (int attr = 0; attr < classifierData.getDataset().nbAttributes(); attr++) {
+ double[] values = classifierData.values(attr);
+
+ // each value of the attribute should appear exactly one time in values
+ for (int index = 0; index < DATA_SIZE; index++) {
+ assertEquals(1, count(values, classifierData.get(index).get(attr)));
+ }
+ }
+
+ for (int attr = 0; attr < regressionData.getDataset().nbAttributes(); attr++) {
+ double[] values = regressionData.values(attr);
+
+ // each value of the attribute should appear exactly one time in values
+ for (int index = 0; index < DATA_SIZE; index++) {
+ assertEquals(1, count(values, regressionData.get(index).get(attr)));
+ }
+ }
+ }
+
+ private static int count(double[] values, double value) {
+ int count = 0;
+ for (double v : values) {
+ if (v == value) {
+ count++;
+ }
+ }
+ return count;
+ }
+
+ @Test
+ public void testIdenticalTrue() throws Exception {
+ // generate a small data, only to get the dataset
+ Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
+
+ // test empty data
+ Data empty = new Data(dataset);
+ assertTrue(empty.isIdentical());
+
+ // test identical data, except for the labels
+ Data identical = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+ Instance model = identical.get(0);
+ for (int index = 1; index < DATA_SIZE; index++) {
+ for (int attr = 0; attr < identical.getDataset().nbAttributes(); attr++) {
+ identical.get(index).set(attr, model.get(attr));
+ }
+ }
+
+ assertTrue(identical.isIdentical());
+ }
+
+ @Test
+ public void testIdenticalFalse() throws Exception {
+ int n = 10;
+
+ for (int nloop = 0; nloop < n; nloop++) {
+ Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE);
+
+ // choose a random instance
+ int index = rng.nextInt(DATA_SIZE);
+ Instance instance = data.get(index);
+
+ // change a random attribute
+ int attr = rng.nextInt(data.getDataset().nbAttributes());
+ instance.set(attr, instance.get(attr) + 1);
+
+ assertFalse(data.isIdentical());
+ }
+ }
+
+ @Test
+ public void testIdenticalLabelTrue() throws Exception {
+ // generate a small data, only to get a dataset
+ Dataset dataset = Utils.randomData(rng, ATTRIBUTE_COUNT, false, 1).getDataset();
+
+ // test empty data
+ Data empty = new Data(dataset);
+ assertTrue(empty.identicalLabel());
+
+ // test identical labels
+ String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+ double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false,
+ DATA_SIZE, rng.nextInt());
+ String[] sData = Utils.double2String(source);
+
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
+ Data data = DataLoader.loadData(dataset, sData);
+
+ assertTrue(data.identicalLabel());
+ }
+
+ @Test
+ public void testIdenticalLabelFalse() throws Exception {
+ int n = 10;
+
+ for (int nloop = 0; nloop < n; nloop++) {
+ String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+ int label = Utils.findLabel(descriptor);
+ double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false,
+ DATA_SIZE, rng.nextInt());
+ // choose a random vector and change its label
+ int index = rng.nextInt(DATA_SIZE);
+ source[index][label]++;
+
+ String[] sData = Utils.double2String(source);
+
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+ Data data = DataLoader.loadData(dataset, sData);
+
+ assertFalse(data.identicalLabel());
+ }
+ }
+
+ /**
+ * Test method for
+ * {@link org.apache.mahout.classifier.df.data.Data#bagging(java.util.Random)}.
+ */
+ @Test
+ public void testBagging() {
+ Data bag = classifierData.bagging(rng);
+
+ // the bag should have the same size as the data
+ assertEquals(classifierData.size(), bag.size());
+
+ // at least one element from the data should not be in the bag
+ boolean found = false;
+ for (int index = 0; index < classifierData.size() && !found; index++) {
+ found = !bag.contains(classifierData.get(index));
+ }
+
+ assertTrue("some instances from data should not be in the bag", found);
+
+ // regression
+ bag = regressionData.bagging(rng);
+
+ // the bag should have the same size as the data
+ assertEquals(regressionData.size(), bag.size());
+
+ // at least one element from the data should not be in the bag
+ found = false;
+ for (int index = 0; index < regressionData.size() && !found; index++) {
+ found = !bag.contains(regressionData.get(index));
+ }
+
+ assertTrue("some instances from data should not be in the bag", found);
+}
+
+ /**
+ * Test method for
+ * {@link org.apache.mahout.classifier.df.data.Data#rsplit(java.util.Random, int)}.
+ */
+ @Test
+ public void testRsplit() {
+
+ // rsplit should handle empty subsets
+ Data source = classifierData.clone();
+ Data subset = source.rsplit(rng, 0);
+ assertTrue("subset should be empty", subset.isEmpty());
+ assertEquals("source.size is incorrect", DATA_SIZE, source.size());
+
+ // rsplit should handle full size subsets
+ source = classifierData.clone();
+ subset = source.rsplit(rng, DATA_SIZE);
+ assertEquals("subset.size is incorrect", DATA_SIZE, subset.size());
+ assertTrue("source should be empty", source.isEmpty());
+
+ // random case
+ int subsize = rng.nextInt(DATA_SIZE);
+ source = classifierData.clone();
+ subset = source.rsplit(rng, subsize);
+ assertEquals("subset.size is incorrect", subsize, subset.size());
+ assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size());
+
+ // regression
+ // rsplit should handle empty subsets
+ source = regressionData.clone();
+ subset = source.rsplit(rng, 0);
+ assertTrue("subset should be empty", subset.isEmpty());
+ assertEquals("source.size is incorrect", DATA_SIZE, source.size());
+
+ // rsplit should handle full size subsets
+ source = regressionData.clone();
+ subset = source.rsplit(rng, DATA_SIZE);
+ assertEquals("subset.size is incorrect", DATA_SIZE, subset.size());
+ assertTrue("source should be empty", source.isEmpty());
+
+ // random case
+ subsize = rng.nextInt(DATA_SIZE);
+ source = regressionData.clone();
+ subset = source.rsplit(rng, subsize);
+ assertEquals("subset.size is incorrect", subsize, subset.size());
+ assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size());
+}
+
+ @Test
+ public void testCountLabel() throws Exception {
+ Dataset dataset = classifierData.getDataset();
+ int[] counts = new int[dataset.nblabels()];
+
+ int n = 10;
+
+ for (int nloop = 0; nloop < n; nloop++) {
+ Arrays.fill(counts, 0);
+ classifierData.countLabels(counts);
+
+ for (int index = 0; index < classifierData.size(); index++) {
+ counts[(int) dataset.getLabel(classifierData.get(index))]--;
+ }
+
+ for (int label = 0; label < classifierData.getDataset().nblabels(); label++) {
+ assertEquals("Wrong label 'equals' count", 0, counts[0]);
+ }
+ }
+ }
+
+ @Test
+ public void testMajorityLabel() throws Exception {
+
+ // all instances have the same label
+ String descriptor = Utils.randomDescriptor(rng, ATTRIBUTE_COUNT);
+ int label = Utils.findLabel(descriptor);
+
+ int label1 = rng.nextInt();
+ double[][] source = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100,
+ label1);
+ String[] sData = Utils.double2String(source);
+
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+ Data data = DataLoader.loadData(dataset, sData);
+
+ int code1 = dataset.labelCode(Double.toString(label1));
+
+ assertEquals(code1, data.majorityLabel(rng));
+
+ // 51/100 vectors have label2
+ int label2 = label1 + 1;
+ int nblabel2 = 51;
+ while (nblabel2 > 0) {
+ double[] vector = source[rng.nextInt(100)];
+ if (vector[label] != label2) {
+ vector[label] = label2;
+ nblabel2--;
+ }
+ }
+ sData = Utils.double2String(source);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
+ data = DataLoader.loadData(dataset, sData);
+ int code2 = dataset.labelCode(Double.toString(label2));
+
+ // label2 should be the majority label
+ assertEquals(code2, data.majorityLabel(rng));
+
+ // 50 vectors with label1 and 50 vectors with label2
+ do {
+ double[] vector = source[rng.nextInt(100)];
+ if (vector[label] == label2) {
+ vector[label] = label1;
+ break;
+ }
+ } while (true);
+ sData = Utils.double2String(source);
+
+ data = DataLoader.loadData(dataset, sData);
+ code1 = dataset.labelCode(Double.toString(label1));
+ code2 = dataset.labelCode(Double.toString(label2));
+
+ // majorityLabel should return label1 and label2 at random
+ boolean found1 = false;
+ boolean found2 = false;
+ for (int index = 0; index < 10 && (!found1 || !found2); index++) {
+ int major = data.majorityLabel(rng);
+ if (major == code1) {
+ found1 = true;
+ }
+ if (major == code2) {
+ found2 = true;
+ }
+ }
+ assertTrue(found1 && found2);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java
new file mode 100644
index 0000000..3cdf65a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to You under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package org.apache.mahout.classifier.df.data;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public final class DatasetTest extends MahoutTestCase {
+
+ @Test
+ public void jsonEncoding() throws DescriptorException {
+ Dataset to = DataLoader.generateDataset("N C I L", true, new String[]{"1 foo 2 3", "4 bar 5 6"});
+
+ // to JSON
+ //assertEquals(json, to.toJSON());
+ assertEquals(3, to.nbAttributes());
+ assertEquals(1, to.getIgnored().length);
+ assertEquals(2, to.getIgnored()[0]);
+ assertEquals(2, to.getLabelId());
+ assertTrue(to.isNumerical(0));
+
+ // from JSON
+ Dataset fromJson = Dataset.fromJSON(to.toJSON());
+ assertEquals(3, fromJson.nbAttributes());
+ assertEquals(1, fromJson.getIgnored().length);
+ assertEquals(2, fromJson.getIgnored()[0]);
+ assertTrue(fromJson.isNumerical(0));
+
+ // read values for a nominal
+ assertNotEquals(fromJson.valueOf(1, "bar"), fromJson.valueOf(1, "foo"));
+ }
+
+ @Test
+ public void jsonEncodingIgnoreFeatures() throws DescriptorException {;
+ Dataset to = DataLoader.generateDataset("N C I L", false, new String[]{"1 foo 2 Red", "4 bar 5 Blue"});
+
+ // to JSON
+ //assertEquals(json, to.toJSON());
+ assertEquals(3, to.nbAttributes());
+ assertEquals(1, to.getIgnored().length);
+ assertEquals(2, to.getIgnored()[0]);
+ assertEquals(2, to.getLabelId());
+ assertTrue(to.isNumerical(0));
+ assertNotEquals(to.valueOf(1, "bar"), to.valueOf(1, "foo"));
+ assertNotEquals(to.valueOf(2, "Red"), to.valueOf(2, "Blue"));
+
+ // from JSON
+ Dataset fromJson = Dataset.fromJSON(to.toJSON());
+ assertEquals(3, fromJson.nbAttributes());
+ assertEquals(1, fromJson.getIgnored().length);
+ assertEquals(2, fromJson.getIgnored()[0]);
+ assertTrue(fromJson.isNumerical(0));
+
+ // read values for a nominal, one before and one after the ignore feature
+ assertNotEquals(fromJson.valueOf(1, "bar"), fromJson.valueOf(1, "foo"));
+ assertNotEquals(fromJson.valueOf(2, "Red"), fromJson.valueOf(2, "Blue"));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java
new file mode 100644
index 0000000..121e1f8
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/data/DescriptorUtilsTest.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+import org.junit.Test;
+
+public final class DescriptorUtilsTest extends MahoutTestCase {
+
+ /**
+ * Test method for
+ * {@link org.apache.mahout.classifier.df.data.DescriptorUtils#parseDescriptor(java.lang.CharSequence)}.
+ */
+ @Test
+ public void testParseDescriptor() throws Exception {
+ int n = 10;
+ int maxnbAttributes = 100;
+
+ Random rng = RandomUtils.getRandom();
+
+ for (int nloop = 0; nloop < n; nloop++) {
+ int nbAttributes = rng.nextInt(maxnbAttributes) + 1;
+
+ char[] tokens = Utils.randomTokens(rng, nbAttributes);
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(Utils.generateDescriptor(tokens));
+
+ // verify that the attributes matches the token list
+ assertEquals("attributes size", nbAttributes, attrs.length);
+
+ for (int attr = 0; attr < nbAttributes; attr++) {
+ switch (tokens[attr]) {
+ case 'I':
+ assertTrue(attrs[attr].isIgnored());
+ break;
+ case 'N':
+ assertTrue(attrs[attr].isNumerical());
+ break;
+ case 'C':
+ assertTrue(attrs[attr].isCategorical());
+ break;
+ case 'L':
+ assertTrue(attrs[attr].isLabel());
+ break;
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testGenerateDescription() throws Exception {
+ validate("", "");
+ validate("I L C C N N N C", "I L C C N N N C");
+ validate("I L C C N N N C", "I L 2 C 3 N C");
+ validate("I L C C N N N C", " I L 2 C 3 N C ");
+
+ try {
+ validate("", "I L 2 2 C 2 N C");
+ fail("2 consecutive multiplicators");
+ } catch (DescriptorException e) {
+ }
+
+ try {
+ validate("", "I L 2 C -2 N C");
+ fail("negative multiplicator");
+ } catch (DescriptorException e) {
+ }
+ }
+
+ private static void validate(String descriptor, CharSequence description) throws DescriptorException {
+ assertEquals(descriptor, DescriptorUtils.generateDescriptor(description));
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java b/mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java
new file mode 100644
index 0000000..1cf8b6a
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/data/Utils.java
@@ -0,0 +1,283 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.data;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Random;
+
+import com.google.common.base.Charsets;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.classifier.df.data.Dataset.Attribute;
+import org.apache.mahout.common.MahoutTestCase;
+
+/**
+ * Helper methods used by the tests
+ *
+ */
+public final class Utils {
+
+ private Utils() {}
+
+ /** Used when generating random CATEGORICAL values */
+ private static final int CATEGORICAL_RANGE = 100;
+
+ /**
+ * Generates a random list of tokens
+ * <ul>
+ * <li>each attribute has 50% chance to be NUMERICAL ('N') or CATEGORICAL
+ * ('C')</li>
+ * <li>10% of the attributes are IGNORED ('I')</li>
+ * <li>one randomly chosen attribute becomes the LABEL ('L')</li>
+ * </ul>
+ *
+ * @param rng Random number generator
+ * @param nbTokens number of tokens to generate
+ */
+ public static char[] randomTokens(Random rng, int nbTokens) {
+ char[] result = new char[nbTokens];
+
+ for (int token = 0; token < nbTokens; token++) {
+ double rand = rng.nextDouble();
+ if (rand < 0.1) {
+ result[token] = 'I'; // IGNORED
+ } else if (rand >= 0.5) {
+ result[token] = 'C';
+ } else {
+ result[token] = 'N'; // NUMERICAL
+ } // CATEGORICAL
+ }
+
+ // choose the label
+ result[rng.nextInt(nbTokens)] = 'L';
+
+ return result;
+ }
+
+ /**
+ * Generates a space-separated String that contains all the tokens
+ */
+ public static String generateDescriptor(char[] tokens) {
+ StringBuilder builder = new StringBuilder();
+
+ for (char token : tokens) {
+ builder.append(token).append(' ');
+ }
+
+ return builder.toString();
+ }
+
+ /**
+ * Generates a random descriptor as follows:<br>
+ * <ul>
+ * <li>each attribute has 50% chance to be NUMERICAL or CATEGORICAL</li>
+ * <li>10% of the attributes are IGNORED</li>
+ * <li>one randomly chosen attribute becomes the LABEL</li>
+ * </ul>
+ */
+ public static String randomDescriptor(Random rng, int nbAttributes) {
+ return generateDescriptor(randomTokens(rng, nbAttributes));
+ }
+
+ /**
+ * generates random data based on the given descriptor
+ *
+ * @param rng Random number generator
+ * @param descriptor attributes description
+ * @param number number of data lines to generate
+ */
+ public static double[][] randomDoubles(Random rng, CharSequence descriptor, boolean regression, int number)
+ throws DescriptorException {
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+
+ double[][] data = new double[number][];
+
+ for (int index = 0; index < number; index++) {
+ data[index] = randomVector(rng, attrs, regression);
+ }
+
+ return data;
+ }
+
+ /**
+ * Generates random data
+ *
+ * @param rng Random number generator
+ * @param nbAttributes number of attributes
+ * @param regression true is the label should be numerical
+ * @param size data size
+ */
+ public static Data randomData(Random rng, int nbAttributes, boolean regression, int size) throws DescriptorException {
+ String descriptor = randomDescriptor(rng, nbAttributes);
+ double[][] source = randomDoubles(rng, descriptor, regression, size);
+ String[] sData = double2String(source);
+ Dataset dataset = DataLoader.generateDataset(descriptor, regression, sData);
+
+ return DataLoader.loadData(dataset, sData);
+ }
+
+ /**
+ * generates a random vector based on the given attributes.<br>
+ * the attributes' values are generated as follows :<br>
+ * <ul>
+ * <li>each IGNORED attribute receives a Double.NaN</li>
+ * <li>each NUMERICAL attribute receives a random double</li>
+ * <li>each CATEGORICAL and LABEL attribute receives a random integer in the
+ * range [0, CATEGORICAL_RANGE[</li>
+ * </ul>
+ *
+ * @param attrs attributes description
+ */
+ private static double[] randomVector(Random rng, Attribute[] attrs, boolean regression) {
+ double[] vector = new double[attrs.length];
+
+ for (int attr = 0; attr < attrs.length; attr++) {
+ if (attrs[attr].isIgnored()) {
+ vector[attr] = Double.NaN;
+ } else if (attrs[attr].isNumerical()) {
+ vector[attr] = rng.nextDouble();
+ } else if (attrs[attr].isCategorical()) {
+ vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+ } else { // LABEL
+ if (regression) {
+ vector[attr] = rng.nextDouble();
+ } else {
+ vector[attr] = rng.nextInt(CATEGORICAL_RANGE);
+ }
+ }
+ }
+
+ return vector;
+ }
+
+ /**
+ * converts a double array to a comma-separated string
+ *
+ * @param v double array
+ * @return comma-separated string
+ */
+ private static String double2String(double[] v) {
+ StringBuilder builder = new StringBuilder();
+
+ for (double aV : v) {
+ builder.append(aV).append(',');
+ }
+
+ return builder.toString();
+ }
+
+ /**
+ * converts an array of double arrays to an array of comma-separated strings
+ *
+ * @param source array of double arrays
+ * @return array of comma-separated strings
+ */
+ public static String[] double2String(double[][] source) {
+ String[] output = new String[source.length];
+
+ for (int index = 0; index < source.length; index++) {
+ output[index] = double2String(source[index]);
+ }
+
+ return output;
+ }
+
+ /**
+ * Generates random data with same label value
+ *
+ * @param number data size
+ * @param value label value
+ */
+ public static double[][] randomDoublesWithSameLabel(Random rng,
+ CharSequence descriptor,
+ boolean regression,
+ int number,
+ int value) throws DescriptorException {
+ int label = findLabel(descriptor);
+
+ double[][] source = randomDoubles(rng, descriptor, regression, number);
+
+ for (int index = 0; index < number; index++) {
+ source[index][label] = value;
+ }
+
+ return source;
+ }
+
+ /**
+ * finds the label attribute's index
+ */
+ public static int findLabel(CharSequence descriptor) throws DescriptorException {
+ Attribute[] attrs = DescriptorUtils.parseDescriptor(descriptor);
+ return ArrayUtils.indexOf(attrs, Attribute.LABEL);
+ }
+
+ private static void writeDataToFile(String[] sData, Path path) throws IOException {
+ BufferedWriter output = null;
+ try {
+ output = Files.newWriter(new File(path.toString()), Charsets.UTF_8);
+ for (String line : sData) {
+ output.write(line);
+ output.write('\n');
+ }
+ } finally {
+ Closeables.close(output, false);
+ }
+
+ }
+
+ public static Path writeDataToTestFile(String[] sData) throws IOException {
+ Path testData = new Path("testdata/Data");
+ MahoutTestCase ca = new MahoutTestCase();
+ FileSystem fs = testData.getFileSystem(ca.getConfiguration());
+ if (!fs.exists(testData)) {
+ fs.mkdirs(testData);
+ }
+
+ Path path = new Path(testData, "DataLoaderTest.data");
+
+ writeDataToFile(sData, path);
+
+ return path;
+ }
+
+ /**
+ * Split the data into numMaps splits
+ */
+ public static String[][] splitData(String[] sData, int numMaps) {
+ int nbInstances = sData.length;
+ int partitionSize = nbInstances / numMaps;
+
+ String[][] splits = new String[numMaps][];
+
+ for (int partition = 0; partition < numMaps; partition++) {
+ int from = partition * partitionSize;
+ int to = partition == (numMaps - 1) ? nbInstances : (partition + 1) * partitionSize;
+
+ splits[partition] = Arrays.copyOfRange(sData, from, to);
+ }
+
+ return splits;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java
new file mode 100644
index 0000000..0a4a034
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputFormatTest.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.mapreduce.InputSplit;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.mapreduce.Builder;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemRecordReader;
+import org.junit.Test;
+
+public final class InMemInputFormatTest extends MahoutTestCase {
+
+ @Test
+ public void testSplits() throws Exception {
+ int n = 1;
+ int maxNumSplits = 100;
+ int maxNbTrees = 1000;
+
+ Random rng = RandomUtils.getRandom();
+
+ for (int nloop = 0; nloop < n; nloop++) {
+ int numSplits = rng.nextInt(maxNumSplits) + 1;
+ int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+ Configuration conf = getConfiguration();
+ Builder.setNbTrees(conf, nbTrees);
+
+ InMemInputFormat inputFormat = new InMemInputFormat();
+ List<InputSplit> splits = inputFormat.getSplits(conf, numSplits);
+
+ assertEquals(numSplits, splits.size());
+
+ int nbTreesPerSplit = nbTrees / numSplits;
+ int totalTrees = 0;
+ int expectedId = 0;
+
+ for (int index = 0; index < numSplits; index++) {
+ assertTrue(splits.get(index) instanceof InMemInputSplit);
+
+ InMemInputSplit split = (InMemInputSplit) splits.get(index);
+
+ assertEquals(expectedId, split.getFirstId());
+
+ if (index < numSplits - 1) {
+ assertEquals(nbTreesPerSplit, split.getNbTrees());
+ } else {
+ assertEquals(nbTrees - totalTrees, split.getNbTrees());
+ }
+
+ totalTrees += split.getNbTrees();
+ expectedId += split.getNbTrees();
+ }
+ }
+ }
+
+ @Test
+ public void testRecordReader() throws Exception {
+ int n = 1;
+ int maxNumSplits = 100;
+ int maxNbTrees = 1000;
+
+ Random rng = RandomUtils.getRandom();
+
+ for (int nloop = 0; nloop < n; nloop++) {
+ int numSplits = rng.nextInt(maxNumSplits) + 1;
+ int nbTrees = rng.nextInt(maxNbTrees) + 1;
+
+ Configuration conf = getConfiguration();
+ Builder.setNbTrees(conf, nbTrees);
+
+ InMemInputFormat inputFormat = new InMemInputFormat();
+ List<InputSplit> splits = inputFormat.getSplits(conf, numSplits);
+
+ for (int index = 0; index < numSplits; index++) {
+ InMemInputSplit split = (InMemInputSplit) splits.get(index);
+ InMemRecordReader reader = new InMemRecordReader(split);
+
+ reader.initialize(split, null);
+
+ for (int tree = 0; tree < split.getNbTrees(); tree++) {
+ // reader.next() should return true until there is no tree left
+ assertEquals(tree < split.getNbTrees(), reader.nextKeyValue());
+ assertEquals(split.getFirstId() + tree, reader.getCurrentKey().get());
+ }
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java
new file mode 100644
index 0000000..f94841d
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/inmem/InMemInputSplitTest.java
@@ -0,0 +1,77 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.inmem;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.mapreduce.inmem.InMemInputFormat.InMemInputSplit;
+import org.junit.Before;
+import org.junit.Test;
+
+public final class InMemInputSplitTest extends MahoutTestCase {
+
+ private Random rng;
+ private ByteArrayOutputStream byteOutStream;
+ private DataOutput out;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+ byteOutStream = new ByteArrayOutputStream();
+ out = new DataOutputStream(byteOutStream);
+ }
+
+ /**
+ * Make sure that all the fields are processed correctly
+ */
+ @Test
+ public void testWritable() throws Exception {
+ InMemInputSplit split = new InMemInputSplit(rng.nextInt(), rng.nextInt(1000), rng.nextLong());
+
+ split.write(out);
+ assertEquals(split, readSplit());
+ }
+
+ /**
+ * test the case seed == null
+ */
+ @Test
+ public void testNullSeed() throws Exception {
+ InMemInputSplit split = new InMemInputSplit(rng.nextInt(), rng.nextInt(1000), null);
+
+ split.write(out);
+ assertEquals(split, readSplit());
+ }
+
+ private InMemInputSplit readSplit() throws IOException {
+ ByteArrayInputStream byteInStream = new ByteArrayInputStream(byteOutStream.toByteArray());
+ DataInput in = new DataInputStream(byteInStream);
+ return InMemInputSplit.read(in);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java
new file mode 100644
index 0000000..3903c33
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialBuilderTest.java
@@ -0,0 +1,197 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Random;
+
+import com.google.common.collect.Lists;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.SequenceFile.Writer;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.mapreduce.MapredOutput;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.junit.Test;
+
+public final class PartialBuilderTest extends MahoutTestCase {
+
+ private static final int NUM_MAPS = 5;
+
+ private static final int NUM_TREES = 32;
+
+ /** instances per partition */
+ private static final int NUM_INSTANCES = 20;
+
+ @Test
+ public void testProcessOutput() throws Exception {
+ Configuration conf = getConfiguration();
+ conf.setInt("mapred.map.tasks", NUM_MAPS);
+
+ Random rng = RandomUtils.getRandom();
+
+ // prepare the output
+ TreeID[] keys = new TreeID[NUM_TREES];
+ MapredOutput[] values = new MapredOutput[NUM_TREES];
+ int[] firstIds = new int[NUM_MAPS];
+ randomKeyValues(rng, keys, values, firstIds);
+
+ // store the output in a sequence file
+ Path base = getTestTempDirPath("testdata");
+ FileSystem fs = base.getFileSystem(conf);
+
+ Path outputFile = new Path(base, "PartialBuilderTest.seq");
+ Writer writer = SequenceFile.createWriter(fs, conf, outputFile,
+ TreeID.class, MapredOutput.class);
+
+ try {
+ for (int index = 0; index < NUM_TREES; index++) {
+ writer.append(keys[index], values[index]);
+ }
+ } finally {
+ Closeables.close(writer, false);
+ }
+
+ // load the output and make sure its valid
+ TreeID[] newKeys = new TreeID[NUM_TREES];
+ Node[] newTrees = new Node[NUM_TREES];
+
+ PartialBuilder.processOutput(new Job(conf), base, newKeys, newTrees);
+
+ // check the forest
+ for (int tree = 0; tree < NUM_TREES; tree++) {
+ assertEquals(values[tree].getTree(), newTrees[tree]);
+ }
+
+ assertTrue("keys not equal", Arrays.deepEquals(keys, newKeys));
+ }
+
+ /**
+ * Make sure that the builder passes the good parameters to the job
+ *
+ */
+ @Test
+ public void testConfigure() {
+ TreeBuilder treeBuilder = new DefaultTreeBuilder();
+ Path dataPath = new Path("notUsedDataPath");
+ Path datasetPath = new Path("notUsedDatasetPath");
+ Long seed = 5L;
+
+ new PartialBuilderChecker(treeBuilder, dataPath, datasetPath, seed);
+ }
+
+ /**
+ * Generates random (key, value) pairs. Shuffles the partition's order
+ *
+ * @param rng
+ * @param keys
+ * @param values
+ * @param firstIds partitions's first ids in hadoop's order
+ */
+ private static void randomKeyValues(Random rng, TreeID[] keys, MapredOutput[] values, int[] firstIds) {
+ int index = 0;
+ int firstId = 0;
+ Collection<Integer> partitions = Lists.newArrayList();
+
+ for (int p = 0; p < NUM_MAPS; p++) {
+ // select a random partition, not yet selected
+ int partition;
+ do {
+ partition = rng.nextInt(NUM_MAPS);
+ } while (partitions.contains(partition));
+
+ partitions.add(partition);
+
+ int nbTrees = Step1Mapper.nbTrees(NUM_MAPS, NUM_TREES, partition);
+
+ for (int treeId = 0; treeId < nbTrees; treeId++) {
+ Node tree = new Leaf(rng.nextInt(100));
+
+ keys[index] = new TreeID(partition, treeId);
+ values[index] = new MapredOutput(tree, nextIntArray(rng, NUM_INSTANCES));
+
+ index++;
+ }
+
+ firstIds[p] = firstId;
+ firstId += NUM_INSTANCES;
+ }
+
+ }
+
+ private static int[] nextIntArray(Random rng, int size) {
+ int[] array = new int[size];
+ for (int index = 0; index < size; index++) {
+ array[index] = rng.nextInt(101) - 1;
+ }
+
+ return array;
+ }
+
+ static class PartialBuilderChecker extends PartialBuilder {
+
+ private final Long seed;
+
+ private final TreeBuilder treeBuilder;
+
+ private final Path datasetPath;
+
+ PartialBuilderChecker(TreeBuilder treeBuilder, Path dataPath,
+ Path datasetPath, Long seed) {
+ super(treeBuilder, dataPath, datasetPath, seed);
+
+ this.seed = seed;
+ this.treeBuilder = treeBuilder;
+ this.datasetPath = datasetPath;
+ }
+
+ @Override
+ protected boolean runJob(Job job) throws IOException {
+ // no need to run the job, just check if the params are correct
+
+ Configuration conf = job.getConfiguration();
+
+ assertEquals(seed, getRandomSeed(conf));
+
+ // PartialBuilder should detect the 'local' mode and overrides the number
+ // of map tasks
+ assertEquals(1, conf.getInt("mapred.map.tasks", -1));
+
+ assertEquals(NUM_TREES, getNbTrees(conf));
+
+ assertFalse(isOutput(conf));
+
+ assertEquals(treeBuilder, getTreeBuilder(conf));
+
+ assertEquals(datasetPath, getDistributedCacheFile(conf, 0));
+
+ return true;
+ }
+
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
new file mode 100644
index 0000000..a4c1bfd
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/Step1MapperTest.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import org.easymock.EasyMock;
+import java.util.Random;
+
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Utils;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.common.MahoutTestCase;
+import org.easymock.Capture;
+import org.easymock.CaptureType;
+import org.junit.Test;
+
+public final class Step1MapperTest extends MahoutTestCase {
+
+ /**
+ * Make sure that the data used to build the trees is from the mapper's
+ * partition
+ *
+ */
+ private static class MockTreeBuilder implements TreeBuilder {
+
+ private Data expected;
+
+ public void setExpected(Data data) {
+ expected = data;
+ }
+
+ @Override
+ public Node build(Random rng, Data data) {
+ for (int index = 0; index < data.size(); index++) {
+ assertTrue(expected.contains(data.get(index)));
+ }
+
+ return new Leaf(Double.NaN);
+ }
+ }
+
+ /**
+ * Special Step1Mapper that can be configured without using a Configuration
+ *
+ */
+ private static class MockStep1Mapper extends Step1Mapper {
+ private MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
+ int partition, int numMapTasks, int numTrees) {
+ configure(false, treeBuilder, dataset);
+ configure(seed, partition, numMapTasks, numTrees);
+ }
+ }
+
+ private static class TreeIDCapture extends Capture<TreeID> {
+
+ private TreeIDCapture() {
+ super(CaptureType.ALL);
+ }
+
+ @Override
+ public void setValue(final TreeID value) {
+ super.setValue(value.clone());
+ }
+ }
+
+ /** nb attributes per generated data instance */
+ static final int NUM_ATTRIBUTES = 4;
+
+ /** nb generated data instances */
+ static final int NUM_INSTANCES = 100;
+
+ /** nb trees to build */
+ static final int NUM_TREES = 10;
+
+ /** nb mappers to use */
+ static final int NUM_MAPPERS = 2;
+
+ @SuppressWarnings({ "rawtypes", "unchecked" })
+ @Test
+ public void testMapper() throws Exception {
+ Random rng = RandomUtils.getRandom();
+
+ // prepare the data
+ String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
+ double[][] source = Utils.randomDoubles(rng, descriptor, false, NUM_INSTANCES);
+ String[] sData = Utils.double2String(source);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+ String[][] splits = Utils.splitData(sData, NUM_MAPPERS);
+
+ MockTreeBuilder treeBuilder = new MockTreeBuilder();
+
+ LongWritable key = new LongWritable();
+ Text value = new Text();
+
+ int treeIndex = 0;
+
+ for (int partition = 0; partition < NUM_MAPPERS; partition++) {
+ String[] split = splits[partition];
+ treeBuilder.setExpected(DataLoader.loadData(dataset, split));
+
+ // expected number of trees that this mapper will build
+ int mapNbTrees = Step1Mapper.nbTrees(NUM_MAPPERS, NUM_TREES, partition);
+
+ Mapper.Context context = EasyMock.createNiceMock(Mapper.Context.class);
+ Capture<TreeID> capturedKeys = new TreeIDCapture();
+ context.write(EasyMock.capture(capturedKeys), EasyMock.anyObject());
+ EasyMock.expectLastCall().anyTimes();
+
+ EasyMock.replay(context);
+
+ MockStep1Mapper mapper = new MockStep1Mapper(treeBuilder, dataset, null,
+ partition, NUM_MAPPERS, NUM_TREES);
+
+ // make sure the mapper computed firstTreeId correctly
+ assertEquals(treeIndex, mapper.getFirstTreeId());
+
+ for (int index = 0; index < split.length; index++) {
+ key.set(index);
+ value.set(split[index]);
+ mapper.map(key, value, context);
+ }
+
+ mapper.cleanup(context);
+ EasyMock.verify(context);
+
+ // make sure the mapper built all its trees
+ assertEquals(mapNbTrees, capturedKeys.getValues().size());
+
+ // check the returned keys
+ for (TreeID k : capturedKeys.getValues()) {
+ assertEquals(partition, k.partition());
+ assertEquals(treeIndex, k.treeId());
+
+ treeIndex++;
+ }
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java
new file mode 100644
index 0000000..d3c30d4
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/TreeIDTest.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.mapreduce.partial;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Test;
+
+public final class TreeIDTest extends MahoutTestCase {
+
+ @Test
+ public void testTreeID() {
+ Random rng = RandomUtils.getRandom();
+
+ for (int nloop = 0; nloop < 1000000; nloop++) {
+ int partition = Math.abs(rng.nextInt());
+ int treeId = rng.nextInt(TreeID.MAX_TREEID);
+
+ TreeID t1 = new TreeID(partition, treeId);
+
+ assertEquals(partition, t1.partition());
+ assertEquals(treeId, t1.treeId());
+
+ TreeID t2 = new TreeID();
+ t2.set(partition, treeId);
+
+ assertEquals(partition, t2.partition());
+ assertEquals(treeId, t2.treeId());
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java
new file mode 100644
index 0000000..236a2e0
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java
@@ -0,0 +1,108 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.node;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.DataInput;
+import java.io.DataInputStream;
+import java.io.DataOutput;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+public final class NodeTest extends MahoutTestCase {
+
+ private Random rng;
+
+ private ByteArrayOutputStream byteOutStream;
+ private DataOutput out;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+ rng = RandomUtils.getRandom();
+
+ byteOutStream = new ByteArrayOutputStream();
+ out = new DataOutputStream(byteOutStream);
+ }
+
+ /**
+ * Test method for
+ * {@link org.apache.mahout.classifier.df.node.Node#read(java.io.DataInput)}.
+ */
+ @Test
+ public void testReadTree() throws Exception {
+ Node node1 = new CategoricalNode(rng.nextInt(),
+ new double[] { rng.nextDouble(), rng.nextDouble() },
+ new Node[] { new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()) });
+ Node node2 = new NumericalNode(rng.nextInt(), rng.nextDouble(),
+ new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()));
+
+ Node root = new CategoricalNode(rng.nextInt(),
+ new double[] { rng.nextDouble(), rng.nextDouble(), rng.nextDouble() },
+ new Node[] { node1, node2, new Leaf(rng.nextDouble()) });
+
+ // write the node to a DataOutput
+ root.write(out);
+
+ // read the node back
+ assertEquals(root, readNode());
+ }
+
+ Node readNode() throws IOException {
+ ByteArrayInputStream byteInStream = new ByteArrayInputStream(byteOutStream.toByteArray());
+ DataInput in = new DataInputStream(byteInStream);
+ return Node.read(in);
+ }
+
+ @Test
+ public void testReadLeaf() throws Exception {
+
+ Node leaf = new Leaf(rng.nextDouble());
+ leaf.write(out);
+ assertEquals(leaf, readNode());
+ }
+
+ @Test
+ public void testParseNumerical() throws Exception {
+
+ Node node = new NumericalNode(rng.nextInt(), rng.nextDouble(), new Leaf(rng
+ .nextInt()), new Leaf(rng.nextDouble()));
+ node.write(out);
+ assertEquals(node, readNode());
+ }
+
+ @Test
+ public void testCategoricalNode() throws Exception {
+
+ Node node = new CategoricalNode(rng.nextInt(), new double[]{rng.nextDouble(),
+ rng.nextDouble(), rng.nextDouble()}, new Node[]{
+ new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()),
+ new Leaf(rng.nextDouble())});
+
+ node.write(out);
+ assertEquals(node, readNode());
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java
new file mode 100644
index 0000000..c5eb635
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/split/DefaultIgSplitTest.java
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import java.util.Random;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Utils;
+import org.junit.Test;
+
+public final class DefaultIgSplitTest extends MahoutTestCase {
+
+ private static final int NUM_ATTRIBUTES = 10;
+
+ @Test
+ public void testEntropy() throws Exception {
+ Random rng = RandomUtils.getRandom();
+ String descriptor = Utils.randomDescriptor(rng, NUM_ATTRIBUTES);
+ int label = Utils.findLabel(descriptor);
+
+ // all the vectors have the same label (0)
+ double[][] temp = Utils.randomDoublesWithSameLabel(rng, descriptor, false, 100, 0);
+ String[] sData = Utils.double2String(temp);
+ Dataset dataset = DataLoader.generateDataset(descriptor, false, sData);
+ Data data = DataLoader.loadData(dataset, sData);
+ DefaultIgSplit iG = new DefaultIgSplit();
+
+ double expected = 0.0 - 1.0 * Math.log(1.0) / Math.log(2.0);
+ assertEquals(expected, iG.entropy(data), EPSILON);
+
+ // 50/100 of the vectors have the label (1)
+ // 50/100 of the vectors have the label (0)
+ for (int index = 0; index < 50; index++) {
+ temp[index][label] = 1.0;
+ }
+ sData = Utils.double2String(temp);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
+ data = DataLoader.loadData(dataset, sData);
+ iG = new DefaultIgSplit();
+
+ expected = 2.0 * -0.5 * Math.log(0.5) / Math.log(2.0);
+ assertEquals(expected, iG.entropy(data), EPSILON);
+
+ // 15/100 of the vectors have the label (2)
+ // 35/100 of the vectors have the label (1)
+ // 50/100 of the vectors have the label (0)
+ for (int index = 0; index < 15; index++) {
+ temp[index][label] = 2.0;
+ }
+ sData = Utils.double2String(temp);
+ dataset = DataLoader.generateDataset(descriptor, false, sData);
+ data = DataLoader.loadData(dataset, sData);
+ iG = new DefaultIgSplit();
+
+ expected = -0.15 * Math.log(0.15) / Math.log(2.0) - 0.35 * Math.log(0.35)
+ / Math.log(2.0) - 0.5 * Math.log(0.5) / Math.log(2.0);
+ assertEquals(expected, iG.entropy(data), EPSILON);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
new file mode 100644
index 0000000..dbd1ef7
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/split/RegressionSplitTest.java
@@ -0,0 +1,87 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.split;
+
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.DescriptorException;
+import org.apache.mahout.classifier.df.data.conditions.Condition;
+import org.apache.mahout.common.MahoutTestCase;
+import org.junit.Test;
+
+public final class RegressionSplitTest extends MahoutTestCase {
+
+ private static Data[] generateTrainingData() throws DescriptorException {
+ // Training data
+ String[] trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 3 == 0) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ } else if (i % 3 == 1) {
+ trainData[i] = "B," + (i + 20) + ',' + (40 - i);
+ } else {
+ trainData[i] = "C," + (i + 20) + ',' + (i + 20);
+ }
+ }
+ // Dataset
+ Dataset dataset = DataLoader.generateDataset("C N L", true, trainData);
+ Data[] datas = new Data[3];
+ datas[0] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[20];
+ for (int i = 0; i < trainData.length; i++) {
+ if (i % 2 == 0) {
+ trainData[i] = "A," + (50 - i) + ',' + (i + 10);
+ } else {
+ trainData[i] = "B," + (i + 10) + ',' + (50 - i);
+ }
+ }
+ datas[1] = DataLoader.loadData(dataset, trainData);
+
+ // Training data
+ trainData = new String[10];
+ for (int i = 0; i < trainData.length; i++) {
+ trainData[i] = "A," + (40 - i) + ',' + (i + 20);
+ }
+ datas[2] = DataLoader.loadData(dataset, trainData);
+
+ return datas;
+ }
+
+ @Test
+ public void testComputeSplit() throws DescriptorException {
+ Data[] datas = generateTrainingData();
+
+ RegressionSplit igSplit = new RegressionSplit();
+ Split split = igSplit.computeSplit(datas[0], 1);
+ assertEquals(180.0, split.getIg(), EPSILON);
+ assertEquals(38.0, split.getSplit(), EPSILON);
+ split = igSplit.computeSplit(datas[0].subset(Condition.lesser(1, 38.0)), 1);
+ assertEquals(76.5, split.getIg(), EPSILON);
+ assertEquals(21.5, split.getSplit(), EPSILON);
+
+ split = igSplit.computeSplit(datas[1], 0);
+ assertEquals(2205.0, split.getIg(), EPSILON);
+ assertEquals(Double.NaN, split.getSplit(), EPSILON);
+ split = igSplit.computeSplit(datas[1].subset(Condition.equals(0, 0.0)), 1);
+ assertEquals(250.0, split.getIg(), EPSILON);
+ assertEquals(41.0, split.getSplit(), EPSILON);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java b/mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
new file mode 100644
index 0000000..482682d
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java
@@ -0,0 +1,211 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.df.tools;
+
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.classifier.df.DecisionForest;
+import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
+import org.apache.mahout.classifier.df.data.Data;
+import org.apache.mahout.classifier.df.data.DataLoader;
+import org.apache.mahout.classifier.df.data.Dataset;
+import org.apache.mahout.classifier.df.data.Instance;
+import org.apache.mahout.classifier.df.node.CategoricalNode;
+import org.apache.mahout.classifier.df.node.Leaf;
+import org.apache.mahout.classifier.df.node.Node;
+import org.apache.mahout.classifier.df.node.NumericalNode;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.common.collect.Lists;
+
+public final class VisualizerTest extends MahoutTestCase {
+
+ private static final String[] TRAIN_DATA = {"sunny,85,85,FALSE,no",
+ "sunny,80,90,TRUE,no", "overcast,83,86,FALSE,yes",
+ "rainy,70,96,FALSE,yes", "rainy,68,80,FALSE,yes", "rainy,65,70,TRUE,no",
+ "overcast,64,65,TRUE,yes", "sunny,72,95,FALSE,no",
+ "sunny,69,70,FALSE,yes", "rainy,75,80,FALSE,yes", "sunny,75,70,TRUE,yes",
+ "overcast,72,90,TRUE,yes", "overcast,81,75,FALSE,yes",
+ "rainy,71,91,TRUE,no"};
+
+ private static final String[] TEST_DATA = {"rainy,70,96,TRUE,-",
+ "overcast,64,65,TRUE,-", "sunny,75,90,TRUE,-",};
+
+ private static final String[] ATTR_NAMES = {"outlook", "temperature",
+ "humidity", "windy", "play"};
+
+ private Random rng;
+
+ private Data data;
+
+ private Data testData;
+
+ @Override
+ @Before
+ public void setUp() throws Exception {
+ super.setUp();
+
+ rng = RandomUtils.getRandom(1);
+
+ // Dataset
+ Dataset dataset = DataLoader
+ .generateDataset("C N N C L", false, TRAIN_DATA);
+
+ // Training data
+ data = DataLoader.loadData(dataset, TRAIN_DATA);
+
+ // Test data
+ testData = DataLoader.loadData(dataset, TEST_DATA);
+ }
+
+ @Test
+ public void testTreeVisualize() throws Exception {
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ builder.setM(data.getDataset().nbAttributes() - 1);
+ Node tree = builder.build(rng, data);
+
+ String visualization = TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES);
+
+ assertTrue(
+ ("\n" +
+ "outlook = rainy\n" +
+ "| windy = FALSE : yes\n" +
+ "| windy = TRUE : no\n" +
+ "outlook = sunny\n" +
+ "| humidity < 77.5 : yes\n" +
+ "| humidity >= 77.5 : no\n" +
+ "outlook = overcast : yes").equals(visualization) ||
+ ("\n" +
+ "outlook = rainy\n" +
+ "| windy = TRUE : no\n" +
+ "| windy = FALSE : yes\n" +
+ "outlook = overcast : yes\n" +
+ "outlook = sunny\n" +
+ "| humidity < 77.5 : yes\n" +
+ "| humidity >= 77.5 : no").equals(visualization));
+ }
+
+ @Test
+ public void testPredictTrace() throws Exception {
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ builder.setM(data.getDataset().nbAttributes() - 1);
+ Node tree = builder.build(rng, data);
+
+ String[] prediction = TreeVisualizer.predictTrace(tree, testData,
+ ATTR_NAMES);
+ Assert.assertArrayEquals(new String[] {
+ "outlook = rainy -> windy = TRUE -> no", "outlook = overcast -> yes",
+ "outlook = sunny -> (humidity = 90) >= 77.5 -> no"}, prediction);
+ }
+
+ @Test
+ public void testForestVisualize() throws Exception {
+ // Tree
+ NumericalNode root = new NumericalNode(2, 90, new Leaf(0),
+ new CategoricalNode(0, new double[] {0, 1, 2}, new Node[] {
+ new NumericalNode(1, 71, new Leaf(0), new Leaf(1)), new Leaf(1),
+ new Leaf(0)}));
+ List<Node> trees = Lists.newArrayList();
+ trees.add(root);
+
+ // Forest
+ DecisionForest forest = new DecisionForest(trees);
+ String visualization = ForestVisualizer.toString(forest, data.getDataset(), null);
+ assertTrue(
+ ("Tree[1]:\n2 < 90 : yes\n2 >= 90\n" +
+ "| 0 = rainy\n" +
+ "| | 1 < 71 : yes\n" +
+ "| | 1 >= 71 : no\n" +
+ "| 0 = sunny : no\n" +
+ "| 0 = overcast : yes\n").equals(visualization) ||
+ ("Tree[1]:\n" +
+ "2 < 90 : no\n" +
+ "2 >= 90\n" +
+ "| 0 = rainy\n" +
+ "| | 1 < 71 : no\n" +
+ "| | 1 >= 71 : yes\n" +
+ "| 0 = overcast : yes\n" +
+ "| 0 = sunny : no\n").equals(visualization));
+
+ visualization = ForestVisualizer.toString(forest, data.getDataset(), ATTR_NAMES);
+ assertTrue(
+ ("Tree[1]:\n" +
+ "humidity < 90 : yes\n" +
+ "humidity >= 90\n" +
+ "| outlook = rainy\n" +
+ "| | temperature < 71 : yes\n" +
+ "| | temperature >= 71 : no\n" +
+ "| outlook = sunny : no\n" +
+ "| outlook = overcast : yes\n").equals(visualization) ||
+ ("Tree[1]:\n" +
+ "humidity < 90 : no\n" +
+ "humidity >= 90\n" +
+ "| outlook = rainy\n" +
+ "| | temperature < 71 : no\n" +
+ "| | temperature >= 71 : yes\n" +
+ "| outlook = overcast : yes\n" +
+ "| outlook = sunny : no\n").equals(visualization));
+ }
+
+ @Test
+ public void testLeafless() throws Exception {
+ List<Instance> instances = Lists.newArrayList();
+ for (int i = 0; i < data.size(); i++) {
+ if (data.get(i).get(0) != 0.0d) {
+ instances.add(data.get(i));
+ }
+ }
+ Data lessData = new Data(data.getDataset(), instances);
+
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ builder.setM(data.getDataset().nbAttributes() - 1);
+ builder.setMinSplitNum(0);
+ builder.setComplemented(false);
+ Node tree = builder.build(rng, lessData);
+
+ String visualization = TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES);
+ assertTrue(
+ ("\noutlook = sunny\n" +
+ "| humidity < 77.5 : yes\n" +
+ "| humidity >= 77.5 : no\n" +
+ "outlook = overcast : yes").equals(visualization) ||
+ ("\noutlook = overcast : yes\n" +
+ "outlook = sunny\n" +
+ "| humidity < 77.5 : yes\n" +
+ "| humidity >= 77.5 : no").equals(visualization));
+ }
+
+ @Test
+ public void testEmpty() throws Exception {
+ Data emptyData = new Data(data.getDataset());
+
+ // build tree
+ DecisionTreeBuilder builder = new DecisionTreeBuilder();
+ Node tree = builder.build(rng, emptyData);
+
+ assertEquals(" : unknown", TreeVisualizer.toString(tree, data.getDataset(), ATTR_NAMES));
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java
----------------------------------------------------------------------
diff --git a/mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java b/mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java
new file mode 100644
index 0000000..66fe97b
--- /dev/null
+++ b/mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.evaluation;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.jet.random.Normal;
+import org.junit.Test;
+
+import java.util.Random;
+
+public final class AucTest extends MahoutTestCase {
+
+ @Test
+ public void testAuc() {
+ Auc auc = new Auc();
+ Random gen = RandomUtils.getRandom();
+ auc.setProbabilityScore(false);
+ for (int i=0; i<100000; i++) {
+ auc.add(0, gen.nextGaussian());
+ auc.add(1, gen.nextGaussian() + 1);
+ }
+ assertEquals(0.76, auc.auc(), 0.01);
+ }
+
+ @Test
+ public void testTies() {
+ Auc auc = new Auc();
+ Random gen = RandomUtils.getRandom();
+ auc.setProbabilityScore(false);
+ for (int i=0; i<100000; i++) {
+ auc.add(0, gen.nextGaussian());
+ auc.add(1, gen.nextGaussian() + 1);
+ }
+
+ // ties outside the normal range could cause index out of range
+ auc.add(0, 5.0);
+ auc.add(0, 5.0);
+ auc.add(0, 5.0);
+ auc.add(0, 5.0);
+
+ auc.add(1, 5.0);
+ auc.add(1, 5.0);
+ auc.add(1, 5.0);
+
+ assertEquals(0.76, auc.auc(), 0.05);
+ }
+
+ @Test
+ public void testEntropy() {
+ Auc auc = new Auc();
+ Random gen = RandomUtils.getRandom();
+ Normal n0 = new Normal(-1, 1, gen);
+ Normal n1 = new Normal(1, 1, gen);
+ for (int i=0; i<100000; i++) {
+ double score = n0.nextDouble();
+ double p = n1.pdf(score) / (n0.pdf(score) + n1.pdf(score));
+ auc.add(0, p);
+
+ score = n1.nextDouble();
+ p = n1.pdf(score) / (n0.pdf(score) + n1.pdf(score));
+ auc.add(1, p);
+ }
+ Matrix m = auc.entropy();
+ assertEquals(-0.35, m.get(0, 0), 0.02);
+ assertEquals(-2.36, m.get(0, 1), 0.02);
+ assertEquals(-2.36, m.get(1, 0), 0.02);
+ assertEquals(-0.35, m.get(1, 1), 0.02);
+ }
+}