You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by pa...@apache.org on 2015/04/01 20:07:51 UTC
[20/51] [partial] mahout git commit: MAHOUT-1655 Refactors mr-legacy
into mahout-hdfs and mahout-mr, closes apache/mahout#86
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java
new file mode 100644
index 0000000..f1874a8
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/decomposer/HdfsBackedLanczosState.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.decomposer;
+
+import java.io.IOException;
+import java.util.Map;
+
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configurable;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.SequenceFile;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.decomposer.lanczos.LanczosState;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class HdfsBackedLanczosState extends LanczosState implements Configurable {
+
+ private static final Logger log = LoggerFactory.getLogger(HdfsBackedLanczosState.class);
+
+ public static final String BASIS_PREFIX = "basis";
+ public static final String SINGULAR_PREFIX = "singular";
+ //public static final String METADATA_FILE = "metadata";
+
+ private Configuration conf;
+ private final Path baseDir;
+ private final Path basisPath;
+ private final Path singularVectorPath;
+ private FileSystem fs;
+
+ public HdfsBackedLanczosState(VectorIterable corpus, int desiredRank, Vector initialVector, Path dir) {
+ super(corpus, desiredRank, initialVector);
+ baseDir = dir;
+ //Path metadataPath = new Path(dir, METADATA_FILE);
+ basisPath = new Path(dir, BASIS_PREFIX);
+ singularVectorPath = new Path(dir, SINGULAR_PREFIX);
+ if (corpus instanceof Configurable) {
+ setConf(((Configurable)corpus).getConf());
+ }
+ }
+
+ @Override public void setConf(Configuration configuration) {
+ conf = configuration;
+ try {
+ setupDirs();
+ updateHdfsState();
+ } catch (IOException e) {
+ log.error("Could not retrieve filesystem: {}", conf, e);
+ }
+ }
+
+ @Override public Configuration getConf() {
+ return conf;
+ }
+
+ private void setupDirs() throws IOException {
+ fs = baseDir.getFileSystem(conf);
+ createDirIfNotExist(baseDir);
+ createDirIfNotExist(basisPath);
+ createDirIfNotExist(singularVectorPath);
+ }
+
+ private void createDirIfNotExist(Path path) throws IOException {
+ if (!fs.exists(path) && !fs.mkdirs(path)) {
+ throw new IOException("Unable to create: " + path);
+ }
+ }
+
+ @Override
+ public void setIterationNumber(int i) {
+ super.setIterationNumber(i);
+ try {
+ updateHdfsState();
+ } catch (IOException e) {
+ log.error("Could not update HDFS state: ", e);
+ }
+ }
+
+ protected void updateHdfsState() throws IOException {
+ if (conf == null) {
+ return;
+ }
+ int numBasisVectorsOnDisk = 0;
+ Path nextBasisVectorPath = new Path(basisPath, BASIS_PREFIX + '_' + numBasisVectorsOnDisk);
+ while (fs.exists(nextBasisVectorPath)) {
+ nextBasisVectorPath = new Path(basisPath, BASIS_PREFIX + '_' + ++numBasisVectorsOnDisk);
+ }
+ Vector nextVector;
+ while (numBasisVectorsOnDisk < iterationNumber
+ && (nextVector = getBasisVector(numBasisVectorsOnDisk)) != null) {
+ persistVector(nextBasisVectorPath, numBasisVectorsOnDisk, nextVector);
+ nextBasisVectorPath = new Path(basisPath, BASIS_PREFIX + '_' + ++numBasisVectorsOnDisk);
+ }
+ if (scaleFactor <= 0) {
+ scaleFactor = getScaleFactor(); // load from disk if possible
+ }
+ diagonalMatrix = getDiagonalMatrix(); // load from disk if possible
+ Vector norms = new DenseVector(diagonalMatrix.numCols() - 1);
+ Vector projections = new DenseVector(diagonalMatrix.numCols());
+ int i = 0;
+ while (i < diagonalMatrix.numCols() - 1) {
+ norms.set(i, diagonalMatrix.get(i, i + 1));
+ projections.set(i, diagonalMatrix.get(i, i));
+ i++;
+ }
+ projections.set(i, diagonalMatrix.get(i, i));
+ persistVector(new Path(baseDir, "projections"), 0, projections);
+ persistVector(new Path(baseDir, "norms"), 0, norms);
+ persistVector(new Path(baseDir, "scaleFactor"), 0, new DenseVector(new double[] {scaleFactor}));
+ for (Map.Entry<Integer, Vector> entry : singularVectors.entrySet()) {
+ persistVector(new Path(singularVectorPath, SINGULAR_PREFIX + '_' + entry.getKey()),
+ entry.getKey(), entry.getValue());
+ }
+ super.setIterationNumber(numBasisVectorsOnDisk);
+ }
+
+ protected void persistVector(Path p, int key, Vector vector) throws IOException {
+ SequenceFile.Writer writer = null;
+ try {
+ if (fs.exists(p)) {
+ log.warn("{} exists, will overwrite", p);
+ fs.delete(p, true);
+ }
+ writer = new SequenceFile.Writer(fs, conf, p,
+ IntWritable.class, VectorWritable.class);
+ writer.append(new IntWritable(key), new VectorWritable(vector));
+ } finally {
+ Closeables.close(writer, false);
+ }
+ }
+
+ protected Vector fetchVector(Path p, int keyIndex) throws IOException {
+ if (!fs.exists(p)) {
+ return null;
+ }
+ SequenceFile.Reader reader = new SequenceFile.Reader(fs, p, conf);
+ IntWritable key = new IntWritable();
+ VectorWritable vw = new VectorWritable();
+ while (reader.next(key, vw)) {
+ if (key.get() == keyIndex) {
+ return vw.get();
+ }
+ }
+ return null;
+ }
+
+ @Override
+ public Vector getBasisVector(int i) {
+ if (!basis.containsKey(i)) {
+ try {
+ Vector v = fetchVector(new Path(basisPath, BASIS_PREFIX + '_' + i), i);
+ basis.put(i, v);
+ } catch (IOException e) {
+ log.error("Could not load basis vector: {}", i, e);
+ }
+ }
+ return super.getBasisVector(i);
+ }
+
+ @Override
+ public Vector getRightSingularVector(int i) {
+ if (!singularVectors.containsKey(i)) {
+ try {
+ Vector v = fetchVector(new Path(singularVectorPath, BASIS_PREFIX + '_' + i), i);
+ singularVectors.put(i, v);
+ } catch (IOException e) {
+ log.error("Could not load singular vector: {}", i, e);
+ }
+ }
+ return super.getRightSingularVector(i);
+ }
+
+ @Override
+ public double getScaleFactor() {
+ if (scaleFactor <= 0) {
+ try {
+ Vector v = fetchVector(new Path(baseDir, "scaleFactor"), 0);
+ if (v != null && v.size() > 0) {
+ scaleFactor = v.get(0);
+ }
+ } catch (IOException e) {
+ log.error("could not load scaleFactor:", e);
+ }
+ }
+ return scaleFactor;
+ }
+
+ @Override
+ public Matrix getDiagonalMatrix() {
+ if (diagonalMatrix == null) {
+ diagonalMatrix = new DenseMatrix(desiredRank, desiredRank);
+ }
+ if (diagonalMatrix.get(0, 1) <= 0) {
+ try {
+ Vector norms = fetchVector(new Path(baseDir, "norms"), 0);
+ Vector projections = fetchVector(new Path(baseDir, "projections"), 0);
+ if (norms != null && projections != null) {
+ int i = 0;
+ while (i < projections.size() - 1) {
+ diagonalMatrix.set(i, i, projections.get(i));
+ diagonalMatrix.set(i, i + 1, norms.get(i));
+ diagonalMatrix.set(i + 1, i, norms.get(i));
+ i++;
+ }
+ diagonalMatrix.set(i, i, projections.get(i));
+ }
+ } catch (IOException e) {
+ log.error("Could not load diagonal matrix of norms and projections: ", e);
+ }
+ }
+ return diagonalMatrix;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
new file mode 100644
index 0000000..9119f69
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
@@ -0,0 +1,104 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.kmeans.Kluster;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Collections;
+import java.util.List;
+
+final class SeedVectorUtil {
+
+ private static final Logger log = LoggerFactory.getLogger(SeedVectorUtil.class);
+
+ private SeedVectorUtil() {
+ }
+
+ public static List<NamedVector> loadSeedVectors(Configuration conf) {
+
+ String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
+ if (seedPathStr == null || seedPathStr.isEmpty()) {
+ return Collections.emptyList();
+ }
+
+ List<NamedVector> seedVectors = Lists.newArrayList();
+ long item = 0;
+ for (Writable value
+ : new SequenceFileDirValueIterable<>(new Path(seedPathStr),
+ PathType.LIST,
+ PathFilters.partFilter(),
+ conf)) {
+ Class<? extends Writable> valueClass = value.getClass();
+ if (valueClass.equals(Kluster.class)) {
+ // get the cluster info
+ Kluster cluster = (Kluster) value;
+ Vector vector = cluster.getCenter();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
+ }
+ } else if (valueClass.equals(Canopy.class)) {
+ // get the cluster info
+ Canopy canopy = (Canopy) value;
+ Vector vector = canopy.getCenter();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
+ }
+ } else if (valueClass.equals(Vector.class)) {
+ Vector vector = (Vector) value;
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, seedPathStr + '.' + item++));
+ }
+ } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class)) {
+ VectorWritable vw = (VectorWritable) value;
+ Vector vector = vw.get();
+ if (vector instanceof NamedVector) {
+ seedVectors.add((NamedVector) vector);
+ } else {
+ seedVectors.add(new NamedVector(vector, seedPathStr + '.' + item++));
+ }
+ } else {
+ throw new IllegalStateException("Bad value class: " + valueClass);
+ }
+ }
+ if (seedVectors.isEmpty()) {
+ throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
+ }
+ log.info("Seed Vectors size: {}", seedVectors.size());
+ return seedVectors;
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
new file mode 100644
index 0000000..c45d55a
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
@@ -0,0 +1,71 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.List;
+
+/**
+ * Similar to {@link VectorDistanceMapper}, except it outputs
+ * <input, Vector>, where the vector is a dense vector contain one entry for every seed vector
+ */
+public final class VectorDistanceInvertedMapper
+ extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
+
+ private DistanceMeasure measure;
+ private List<NamedVector> seedVectors;
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ String keyName;
+ Vector valVec = value.get();
+ if (valVec instanceof NamedVector) {
+ keyName = ((NamedVector) valVec).getName();
+ } else {
+ keyName = key.toString();
+ }
+ Vector outVec = new DenseVector(new double[seedVectors.size()]);
+ int i = 0;
+ for (NamedVector seedVector : seedVectors) {
+ outVec.setQuick(i++, measure.distance(seedVector, valVec));
+ }
+ context.write(new Text(keyName), new VectorWritable(outVec));
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+ measure =
+ ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY), DistanceMeasure.class);
+ measure.configure(conf);
+ seedVectors = SeedVectorUtil.loadSeedVectors(conf);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
new file mode 100644
index 0000000..9fccd8e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.List;
+
+public final class VectorDistanceMapper
+ extends Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable> {
+
+ private DistanceMeasure measure;
+ private List<NamedVector> seedVectors;
+ private boolean usesThreshold = false;
+ private double maxDistance;
+
+ @Override
+ protected void map(WritableComparable<?> key, VectorWritable value, Context context)
+ throws IOException, InterruptedException {
+ String keyName;
+ Vector valVec = value.get();
+ if (valVec instanceof NamedVector) {
+ keyName = ((NamedVector) valVec).getName();
+ } else {
+ keyName = key.toString();
+ }
+
+ for (NamedVector seedVector : seedVectors) {
+ double distance = measure.distance(seedVector, valVec);
+ if (!usesThreshold || distance <= maxDistance) {
+ StringTuple outKey = new StringTuple();
+ outKey.add(seedVector.getName());
+ outKey.add(keyName);
+ context.write(outKey, new DoubleWritable(distance));
+ }
+ }
+ }
+
+ @Override
+ protected void setup(Context context) throws IOException, InterruptedException {
+ super.setup(context);
+ Configuration conf = context.getConfiguration();
+
+ String maxDistanceParam = conf.get(VectorDistanceSimilarityJob.MAX_DISTANCE);
+ if (maxDistanceParam != null) {
+ usesThreshold = true;
+ maxDistance = Double.parseDouble(maxDistanceParam);
+ }
+
+ measure = ClassUtils.instantiateAs(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY),
+ DistanceMeasure.class);
+ measure.configure(conf);
+ seedVectors = SeedVectorUtil.loadSeedVectors(conf);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
new file mode 100644
index 0000000..9f58f1e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
@@ -0,0 +1,153 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.StringTuple;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.math.VectorWritable;
+
+import com.google.common.base.Preconditions;
+
+import java.io.IOException;
+
+/**
+ * This class does a Map-side join between seed vectors (the map side can also be a Cluster) and a list of other vectors
+ * and emits the a tuple of seed id, other id, distance. It is a more generic version of KMean's mapper
+ */
+public class VectorDistanceSimilarityJob extends AbstractJob {
+
+ public static final String SEEDS = "seeds";
+ public static final String SEEDS_PATH_KEY = "seedsPath";
+ public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
+ public static final String OUT_TYPE_KEY = "outType";
+ public static final String MAX_DISTANCE = "maxDistance";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+ }
+
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.distanceMeasureOption().create());
+ addOption(SEEDS, "s", "The set of vectors to compute distances against. Must fit in memory on the mapper");
+ addOption(MAX_DISTANCE, "mx", "set an upper-bound on distance (double) such that any pair of vectors with a"
+ + " distance greater than this value is ignored in the output. Ignored for non pairwise output!");
+ addOption(DefaultOptionCreator.overwriteOption().create());
+ addOption(OUT_TYPE_KEY, "ot", "[pw|v] -- Define the output style: pairwise, the default, (pw) or vector (v). "
+ + "Pairwise is a tuple of <seed, other, distance>, vector is <other, <Vector of size the number of seeds>>.",
+ "pw");
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+ Path seeds = new Path(getOption(SEEDS));
+ String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION);
+ if (measureClass == null) {
+ measureClass = SquaredEuclideanDistanceMeasure.class.getName();
+ }
+ if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+ HadoopUtil.delete(getConf(), output);
+ }
+ DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class);
+ String outType = getOption(OUT_TYPE_KEY, "pw");
+
+ Double maxDistance = null;
+
+ if ("pw".equals(outType)) {
+ String maxDistanceArg = getOption(MAX_DISTANCE);
+ if (maxDistanceArg != null) {
+ maxDistance = Double.parseDouble(maxDistanceArg);
+ Preconditions.checkArgument(maxDistance > 0.0d, "value for " + MAX_DISTANCE + " must be greater than zero");
+ }
+ }
+
+ run(getConf(), input, seeds, output, measure, outType, maxDistance);
+ return 0;
+ }
+
+ public static void run(Configuration conf,
+ Path input,
+ Path seeds,
+ Path output,
+ DistanceMeasure measure, String outType)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ run(conf, input, seeds, output, measure, outType, null);
+ }
+
+ public static void run(Configuration conf,
+ Path input,
+ Path seeds,
+ Path output,
+ DistanceMeasure measure, String outType, Double maxDistance)
+ throws IOException, ClassNotFoundException, InterruptedException {
+ if (maxDistance != null) {
+ conf.set(MAX_DISTANCE, String.valueOf(maxDistance));
+ }
+ conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
+ conf.set(SEEDS_PATH_KEY, seeds.toString());
+ Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + " input: " + input);
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+ if ("pw".equalsIgnoreCase(outType)) {
+ job.setMapOutputKeyClass(StringTuple.class);
+ job.setOutputKeyClass(StringTuple.class);
+ job.setMapOutputValueClass(DoubleWritable.class);
+ job.setOutputValueClass(DoubleWritable.class);
+ job.setMapperClass(VectorDistanceMapper.class);
+ } else if ("v".equalsIgnoreCase(outType)) {
+ job.setMapOutputKeyClass(Text.class);
+ job.setOutputKeyClass(Text.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setMapperClass(VectorDistanceInvertedMapper.class);
+ } else {
+ throw new IllegalArgumentException("Invalid outType specified: " + outType);
+ }
+
+ job.setNumReduceTasks(0);
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+
+ job.setJarByClass(VectorDistanceSimilarityJob.class);
+ HadoopUtil.delete(conf, output);
+ if (!job.waitForCompletion(true)) {
+ throw new IllegalStateException("VectorDistance Similarity failed processing " + seeds);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java
new file mode 100644
index 0000000..ecd0d94
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/MutableElement.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import org.apache.mahout.math.Vector;
+
+public class MutableElement implements Vector.Element {
+
+ private int index;
+ private double value;
+
+ MutableElement(int index, double value) {
+ this.index = index;
+ this.value = value;
+ }
+
+ @Override
+ public double get() {
+ return value;
+ }
+
+ @Override
+ public int index() {
+ return index;
+ }
+
+ public void setIndex(int index) {
+ this.index = index;
+ }
+
+ @Override
+ public void set(double value) {
+ this.value = value;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java
new file mode 100644
index 0000000..fb28821
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/RowSimilarityJob.java
@@ -0,0 +1,562 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import com.google.common.base.Preconditions;
+import com.google.common.primitives.Ints;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.NullWritable;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.common.AbstractJob;
+import org.apache.mahout.common.ClassUtils;
+import org.apache.mahout.common.HadoopUtil;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.common.mapreduce.VectorSumCombiner;
+import org.apache.mahout.common.mapreduce.VectorSumReducer;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasures;
+import org.apache.mahout.math.hadoop.similarity.cooccurrence.measures.VectorSimilarityMeasure;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
/**
 * Computes the pairwise similarities of the rows of an input matrix ({@code IntWritable} row id to
 * {@link VectorWritable} row). Runs as a chain of MapReduce jobs:
 * (0) count observations per column,
 * (1) normalize and down-sample each row, emit it transposed (column-wise) while recording
 *     per-row norms, non-zero counts and max values,
 * (2) compute co-occurrence dot products per column and reduce them into similarity values,
 * (3) keep only the top-k similarities per row and symmetrize the result matrix.
 */
public class RowSimilarityJob extends AbstractJob {

  // NOTE(review): Double.MIN_VALUE is the smallest POSITIVE double, so the comparison
  // "similarityValue >= treshold" in SimilarityReducer still discards zero and negative
  // similarities even when no threshold was requested — confirm this is intended.
  public static final double NO_THRESHOLD = Double.MIN_VALUE;
  // Marker meaning "no fixed seed requested"; a time-based seed is used instead.
  public static final long NO_FIXED_RANDOM_SEED = Long.MIN_VALUE;

  // Configuration keys. RowSimilarityJob.class is concatenated via Class.toString(), so each key
  // is prefixed with "class org.apache...RowSimilarityJob"; harmless because the identical
  // expression is used for both set and get.
  private static final String SIMILARITY_CLASSNAME = RowSimilarityJob.class + ".distributedSimilarityClassname";
  private static final String NUMBER_OF_COLUMNS = RowSimilarityJob.class + ".numberOfColumns";
  private static final String MAX_SIMILARITIES_PER_ROW = RowSimilarityJob.class + ".maxSimilaritiesPerRow";
  private static final String EXCLUDE_SELF_SIMILARITY = RowSimilarityJob.class + ".excludeSelfSimilarity";

  private static final String THRESHOLD = RowSimilarityJob.class + ".threshold";
  private static final String NORMS_PATH = RowSimilarityJob.class + ".normsPath";
  private static final String MAXVALUES_PATH = RowSimilarityJob.class + ".maxWeightsPath";

  private static final String NUM_NON_ZERO_ENTRIES_PATH = RowSimilarityJob.class + ".nonZeroEntriesPath";
  private static final int DEFAULT_MAX_SIMILARITIES_PER_ROW = 100;

  private static final String OBSERVATIONS_PER_COLUMN_PATH = RowSimilarityJob.class + ".observationsPerColumnPath";

  private static final String MAX_OBSERVATIONS_PER_ROW = RowSimilarityJob.class + ".maxObservationsPerRow";
  private static final String MAX_OBSERVATIONS_PER_COLUMN = RowSimilarityJob.class + ".maxObservationsPerColumn";
  private static final String RANDOM_SEED = RowSimilarityJob.class + ".randomSeed";

  private static final int DEFAULT_MAX_OBSERVATIONS_PER_ROW = 500;
  private static final int DEFAULT_MAX_OBSERVATIONS_PER_COLUMN = 500;

  // Special row keys used to smuggle the norms / non-zero-count / max-value vectors through the
  // same (IntWritable, VectorWritable) shuffle as regular transposed columns. Real column indices
  // are assumed to be non-negative, so values near Integer.MIN_VALUE cannot collide.
  private static final int NORM_VECTOR_MARKER = Integer.MIN_VALUE;
  private static final int MAXVALUE_VECTOR_MARKER = Integer.MIN_VALUE + 1;
  private static final int NUM_NON_ZERO_ENTRIES_VECTOR_MARKER = Integer.MIN_VALUE + 2;

  // Hadoop counters reported by the mappers for observability of sampling and pruning.
  enum Counters { ROWS, USED_OBSERVATIONS, NEGLECTED_OBSERVATIONS, COOCCURRENCES, PRUNED_COOCCURRENCES }

  public static void main(String[] args) throws Exception {
    ToolRunner.run(new RowSimilarityJob(), args);
  }

  /**
   * Parses the CLI options and runs the phase chain. Returns 0 on success, -1 if argument
   * parsing fails or any phase's job fails.
   */
  @Override
  public int run(String[] args) throws Exception {

    addInputOption();
    addOutputOption();
    addOption("numberOfColumns", "r", "Number of columns in the input matrix", false);
    addOption("similarityClassname", "s", "Name of distributed similarity class to instantiate, alternatively use "
        + "one of the predefined similarities (" + VectorSimilarityMeasures.list() + ')');
    addOption("maxSimilaritiesPerRow", "m", "Number of maximum similarities per row (default: "
        + DEFAULT_MAX_SIMILARITIES_PER_ROW + ')', String.valueOf(DEFAULT_MAX_SIMILARITIES_PER_ROW));
    addOption("excludeSelfSimilarity", "ess", "compute similarity of rows to themselves?", String.valueOf(false));
    addOption("threshold", "tr", "discard row pairs with a similarity value below this", false);
    addOption("maxObservationsPerRow", null, "sample rows down to this number of entries",
        String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_ROW));
    addOption("maxObservationsPerColumn", null, "sample columns down to this number of entries",
        String.valueOf(DEFAULT_MAX_OBSERVATIONS_PER_COLUMN));
    addOption("randomSeed", null, "use this seed for sampling", false);
    addOption(DefaultOptionCreator.overwriteOption().create());

    Map<String,List<String>> parsedArgs = parseArguments(args);
    if (parsedArgs == null) {
      return -1;
    }

    int numberOfColumns;

    if (hasOption("numberOfColumns")) {
      // Number of columns explicitly specified via CLI
      numberOfColumns = Integer.parseInt(getOption("numberOfColumns"));
    } else {
      // else get the number of columns by determining the cardinality of a vector in the input matrix
      numberOfColumns = getDimensions(getInputPath());
    }

    // Accept either one of the predefined VectorSimilarityMeasures enum names or a fully
    // qualified class name; the enum lookup throwing IAE means it is a custom class name.
    String similarityClassnameArg = getOption("similarityClassname");
    String similarityClassname;
    try {
      similarityClassname = VectorSimilarityMeasures.valueOf(similarityClassnameArg).getClassname();
    } catch (IllegalArgumentException iae) {
      similarityClassname = similarityClassnameArg;
    }

    // Clear the output and temp paths if the overwrite option has been set
    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
      // Clear the temp path
      HadoopUtil.delete(getConf(), getTempPath());
      // Clear the output path
      HadoopUtil.delete(getConf(), getOutputPath());
    }

    int maxSimilaritiesPerRow = Integer.parseInt(getOption("maxSimilaritiesPerRow"));
    boolean excludeSelfSimilarity = Boolean.parseBoolean(getOption("excludeSelfSimilarity"));
    double threshold = hasOption("threshold")
        ? Double.parseDouble(getOption("threshold")) : NO_THRESHOLD;
    long randomSeed = hasOption("randomSeed")
        ? Long.parseLong(getOption("randomSeed")) : NO_FIXED_RANDOM_SEED;

    int maxObservationsPerRow = Integer.parseInt(getOption("maxObservationsPerRow"));
    int maxObservationsPerColumn = Integer.parseInt(getOption("maxObservationsPerColumn"));

    // Intermediate artifacts shared between phases via the temp directory.
    Path weightsPath = getTempPath("weights");
    Path normsPath = getTempPath("norms.bin");
    Path numNonZeroEntriesPath = getTempPath("numNonZeroEntries.bin");
    Path maxValuesPath = getTempPath("maxValues.bin");
    Path pairwiseSimilarityPath = getTempPath("pairwiseSimilarity");

    Path observationsPerColumnPath = getTempPath("observationsPerColumn.bin");

    AtomicInteger currentPhase = new AtomicInteger();

    // Phase 0: count the number of observations per column; the reducer writes the count vector
    // to observationsPerColumnPath as a side effect, so the job's regular output is unused.
    // NOTE(review): this job runs unconditionally, i.e. it is NOT guarded by
    // shouldRunNextPhase like the later phases — confirm that is intended.
    Job countObservations = prepareJob(getInputPath(), getTempPath("notUsed"), CountObservationsMapper.class,
        NullWritable.class, VectorWritable.class, SumObservationsReducer.class, NullWritable.class,
        VectorWritable.class);
    countObservations.setCombinerClass(VectorSumCombiner.class);
    countObservations.getConfiguration().set(OBSERVATIONS_PER_COLUMN_PATH, observationsPerColumnPath.toString());
    // single reducer so that all partial count vectors are summed into one file
    countObservations.setNumReduceTasks(1);
    countObservations.waitForCompletion(true);

    // Phase 1: normalize, down-sample, and transpose the matrix; also writes the norms,
    // non-zero-entry counts and max values needed by the pruning logic of phase 2.
    if (shouldRunNextPhase(parsedArgs, currentPhase)) {
      Job normsAndTranspose = prepareJob(getInputPath(), weightsPath, VectorNormMapper.class, IntWritable.class,
          VectorWritable.class, MergeVectorsReducer.class, IntWritable.class, VectorWritable.class);
      normsAndTranspose.setCombinerClass(MergeVectorsCombiner.class);
      Configuration normsAndTransposeConf = normsAndTranspose.getConfiguration();
      normsAndTransposeConf.set(THRESHOLD, String.valueOf(threshold));
      normsAndTransposeConf.set(NORMS_PATH, normsPath.toString());
      normsAndTransposeConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
      normsAndTransposeConf.set(MAXVALUES_PATH, maxValuesPath.toString());
      normsAndTransposeConf.set(SIMILARITY_CLASSNAME, similarityClassname);
      normsAndTransposeConf.set(OBSERVATIONS_PER_COLUMN_PATH, observationsPerColumnPath.toString());
      normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_ROW, String.valueOf(maxObservationsPerRow));
      normsAndTransposeConf.set(MAX_OBSERVATIONS_PER_COLUMN, String.valueOf(maxObservationsPerColumn));
      normsAndTransposeConf.set(RANDOM_SEED, String.valueOf(randomSeed));

      boolean succeeded = normsAndTranspose.waitForCompletion(true);
      if (!succeeded) {
        return -1;
      }
    }

    // Phase 2: per column, form all co-occurring pairs and aggregate the partial dot products
    // into one similarity value per row pair.
    if (shouldRunNextPhase(parsedArgs, currentPhase)) {
      Job pairwiseSimilarity = prepareJob(weightsPath, pairwiseSimilarityPath, CooccurrencesMapper.class,
          IntWritable.class, VectorWritable.class, SimilarityReducer.class, IntWritable.class, VectorWritable.class);
      pairwiseSimilarity.setCombinerClass(VectorSumReducer.class);
      Configuration pairwiseConf = pairwiseSimilarity.getConfiguration();
      pairwiseConf.set(THRESHOLD, String.valueOf(threshold));
      pairwiseConf.set(NORMS_PATH, normsPath.toString());
      pairwiseConf.set(NUM_NON_ZERO_ENTRIES_PATH, numNonZeroEntriesPath.toString());
      pairwiseConf.set(MAXVALUES_PATH, maxValuesPath.toString());
      pairwiseConf.set(SIMILARITY_CLASSNAME, similarityClassname);
      pairwiseConf.setInt(NUMBER_OF_COLUMNS, numberOfColumns);
      pairwiseConf.setBoolean(EXCLUDE_SELF_SIMILARITY, excludeSelfSimilarity);
      boolean succeeded = pairwiseSimilarity.waitForCompletion(true);
      if (!succeeded) {
        return -1;
      }
    }

    // Phase 3: mirror the upper-triangular similarity matrix and keep only the top-k
    // similarities per row.
    if (shouldRunNextPhase(parsedArgs, currentPhase)) {
      Job asMatrix = prepareJob(pairwiseSimilarityPath, getOutputPath(), UnsymmetrifyMapper.class,
          IntWritable.class, VectorWritable.class, MergeToTopKSimilaritiesReducer.class, IntWritable.class,
          VectorWritable.class);
      asMatrix.setCombinerClass(MergeToTopKSimilaritiesReducer.class);
      asMatrix.getConfiguration().setInt(MAX_SIMILARITIES_PER_ROW, maxSimilaritiesPerRow);
      boolean succeeded = asMatrix.waitForCompletion(true);
      if (!succeeded) {
        return -1;
      }
    }

    return 0;
  }

  /**
   * Accumulates, per column index, how many rows have a non-zero entry in that column.
   * Emits a single partial count vector from cleanup() to minimize shuffle volume.
   */
  public static class CountObservationsMapper extends Mapper<IntWritable,VectorWritable,NullWritable,VectorWritable> {

    // in-memory accumulator; one increment per non-zero entry seen
    private Vector columnCounts = new RandomAccessSparseVector(Integer.MAX_VALUE);

    @Override
    protected void map(IntWritable rowIndex, VectorWritable rowVectorWritable, Context ctx)
      throws IOException, InterruptedException {

      Vector row = rowVectorWritable.get();
      for (Vector.Element elem : row.nonZeroes()) {
        columnCounts.setQuick(elem.index(), columnCounts.getQuick(elem.index()) + 1);
      }
    }

    @Override
    protected void cleanup(Context ctx) throws IOException, InterruptedException {
      // emit the accumulated counts once, after all input splits of this mapper are consumed
      ctx.write(NullWritable.get(), new VectorWritable(columnCounts));
    }
  }

  /**
   * Sums the partial per-column count vectors and writes the total to the side-channel file at
   * OBSERVATIONS_PER_COLUMN_PATH (the job's regular output is not used).
   */
  public static class SumObservationsReducer extends Reducer<NullWritable,VectorWritable,NullWritable,VectorWritable> {
    @Override
    protected void reduce(NullWritable nullWritable, Iterable<VectorWritable> partialVectors, Context ctx)
      throws IOException, InterruptedException {
      Vector counts = Vectors.sum(partialVectors.iterator());
      Vectors.write(counts, new Path(ctx.getConfiguration().get(OBSERVATIONS_PER_COLUMN_PATH)), ctx.getConfiguration());
    }
  }

  /**
   * Normalizes each row with the configured similarity measure, down-samples overly dense rows
   * and columns, and emits the row transposed as single-entry partial column vectors. Also
   * records each row's norm, non-zero count and maximum value, which are emitted under marker
   * keys in cleanup().
   */
  public static class VectorNormMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {

    private VectorSimilarityMeasure similarity;
    private Vector norms;
    private Vector nonZeroEntries;
    private Vector maxValues;
    private double threshold;

    // column index -> total observation count, produced by phase 0
    private OpenIntIntHashMap observationsPerColumn;
    private int maxObservationsPerRow;
    private int maxObservationsPerColumn;

    private Random random;

    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {

      Configuration conf = ctx.getConfiguration();

      similarity = ClassUtils.instantiateAs(conf.get(SIMILARITY_CLASSNAME), VectorSimilarityMeasure.class);
      norms = new RandomAccessSparseVector(Integer.MAX_VALUE);
      nonZeroEntries = new RandomAccessSparseVector(Integer.MAX_VALUE);
      maxValues = new RandomAccessSparseVector(Integer.MAX_VALUE);
      threshold = Double.parseDouble(conf.get(THRESHOLD));

      observationsPerColumn = Vectors.readAsIntMap(new Path(conf.get(OBSERVATIONS_PER_COLUMN_PATH)), conf);
      maxObservationsPerRow = conf.getInt(MAX_OBSERVATIONS_PER_ROW, DEFAULT_MAX_OBSERVATIONS_PER_ROW);
      maxObservationsPerColumn = conf.getInt(MAX_OBSERVATIONS_PER_COLUMN, DEFAULT_MAX_OBSERVATIONS_PER_COLUMN);

      long seed = Long.parseLong(conf.get(RANDOM_SEED));
      if (seed == NO_FIXED_RANDOM_SEED) {
        random = RandomUtils.getRandom();
      } else {
        random = RandomUtils.getRandom(seed);
      }
    }

    /**
     * Randomly drops entries so that, in expectation, no row contributes more than
     * maxObservationsPerRow entries and no column more than maxObservationsPerColumn.
     * An entry survives with probability min(rowSampleRate, columnSampleRate).
     */
    private Vector sampleDown(Vector rowVector, Context ctx) {

      int observationsPerRow = rowVector.getNumNondefaultElements();
      // rate is 1.0 when the row is already below the cap
      double rowSampleRate = (double) Math.min(maxObservationsPerRow, observationsPerRow) / (double) observationsPerRow;

      Vector downsampledRow = rowVector.like();
      long usedObservations = 0;
      long neglectedObservations = 0;

      for (Vector.Element elem : rowVector.nonZeroes()) {

        int columnCount = observationsPerColumn.get(elem.index());
        double columnSampleRate = (double) Math.min(maxObservationsPerColumn, columnCount) / (double) columnCount;

        if (random.nextDouble() <= Math.min(rowSampleRate, columnSampleRate)) {
          downsampledRow.setQuick(elem.index(), elem.get());
          usedObservations++;
        } else {
          neglectedObservations++;
        }

      }

      ctx.getCounter(Counters.USED_OBSERVATIONS).increment(usedObservations);
      ctx.getCounter(Counters.NEGLECTED_OBSERVATIONS).increment(neglectedObservations);

      return downsampledRow;
    }

    @Override
    protected void map(IntWritable row, VectorWritable vectorWritable, Context ctx)
      throws IOException, InterruptedException {

      Vector sampledRowVector = sampleDown(vectorWritable.get(), ctx);

      Vector rowVector = similarity.normalize(sampledRowVector);

      int numNonZeroEntries = 0;
      // NOTE(review): Double.MIN_VALUE is the smallest positive double, not the most negative
      // value; if normalized rows can contain only negative entries, maxValue stays at this
      // initial value — confirm entries are assumed non-negative here.
      double maxValue = Double.MIN_VALUE;

      for (Vector.Element element : rowVector.nonZeroes()) {
        // transpose: emit (columnIndex -> sparse vector with a single entry at this row)
        RandomAccessSparseVector partialColumnVector = new RandomAccessSparseVector(Integer.MAX_VALUE);
        partialColumnVector.setQuick(row.get(), element.get());
        ctx.write(new IntWritable(element.index()), new VectorWritable(partialColumnVector));

        numNonZeroEntries++;
        if (maxValue < element.get()) {
          maxValue = element.get();
        }
      }

      // non-zero counts and max values are only needed for threshold-based pruning
      if (threshold != NO_THRESHOLD) {
        nonZeroEntries.setQuick(row.get(), numNonZeroEntries);
        maxValues.setQuick(row.get(), maxValue);
      }
      norms.setQuick(row.get(), similarity.norm(rowVector));

      ctx.getCounter(Counters.ROWS).increment(1);
    }

    @Override
    protected void cleanup(Context ctx) throws IOException, InterruptedException {
      // emit the bookkeeping vectors under reserved marker keys; MergeVectorsReducer routes
      // them to their side-channel files instead of the regular output
      ctx.write(new IntWritable(NORM_VECTOR_MARKER), new VectorWritable(norms));
      ctx.write(new IntWritable(NUM_NON_ZERO_ENTRIES_VECTOR_MARKER), new VectorWritable(nonZeroEntries));
      ctx.write(new IntWritable(MAXVALUE_VECTOR_MARKER), new VectorWritable(maxValues));
    }
  }

  /** Combiner that merges partial column vectors early to reduce shuffle volume. */
  private static class MergeVectorsCombiner extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
    @Override
    protected void reduce(IntWritable row, Iterable<VectorWritable> partialVectors, Context ctx)
      throws IOException, InterruptedException {
      ctx.write(row, new VectorWritable(Vectors.merge(partialVectors)));
    }
  }

  /**
   * Merges partial column vectors into complete columns. Vectors keyed with one of the three
   * reserved marker indices are written to their configured side-channel files instead of the
   * regular job output.
   */
  public static class MergeVectorsReducer extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {

    private Path normsPath;
    private Path numNonZeroEntriesPath;
    private Path maxValuesPath;

    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {
      normsPath = new Path(ctx.getConfiguration().get(NORMS_PATH));
      numNonZeroEntriesPath = new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH));
      maxValuesPath = new Path(ctx.getConfiguration().get(MAXVALUES_PATH));
    }

    @Override
    protected void reduce(IntWritable row, Iterable<VectorWritable> partialVectors, Context ctx)
      throws IOException, InterruptedException {
      Vector partialVector = Vectors.merge(partialVectors);

      if (row.get() == NORM_VECTOR_MARKER) {
        Vectors.write(partialVector, normsPath, ctx.getConfiguration());
      } else if (row.get() == MAXVALUE_VECTOR_MARKER) {
        Vectors.write(partialVector, maxValuesPath, ctx.getConfiguration());
      } else if (row.get() == NUM_NON_ZERO_ENTRIES_VECTOR_MARKER) {
        // counts are ints, so lax precision (float) storage is sufficient
        Vectors.write(partialVector, numNonZeroEntriesPath, ctx.getConfiguration(), true);
      } else {
        ctx.write(row, new VectorWritable(partialVector));
      }
    }
  }


  /**
   * For each column, forms all co-occurring row pairs (upper triangle only, including the
   * diagonal) and emits partial dot-product contributions keyed by the smaller row index.
   * When a threshold is configured, pairs that provably cannot reach it are pruned early.
   */
  public static class CooccurrencesMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {

    private VectorSimilarityMeasure similarity;

    private OpenIntIntHashMap numNonZeroEntries;
    private Vector maxValues;
    private double threshold;

    // sorts column entries by row index so the inner loop only produces pairs with
    // occurrenceA.index() <= occurrenceB.index()
    private static final Comparator<Vector.Element> BY_INDEX = new Comparator<Vector.Element>() {
      @Override
      public int compare(Vector.Element one, Vector.Element two) {
        return Ints.compare(one.index(), two.index());
      }
    };

    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {
      similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
          VectorSimilarityMeasure.class);
      numNonZeroEntries = Vectors.readAsIntMap(new Path(ctx.getConfiguration().get(NUM_NON_ZERO_ENTRIES_PATH)),
          ctx.getConfiguration());
      maxValues = Vectors.read(new Path(ctx.getConfiguration().get(MAXVALUES_PATH)), ctx.getConfiguration());
      threshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
    }

    /** Delegates to the measure's upper-bound check to decide whether a pair can reach the threshold. */
    private boolean consider(Vector.Element occurrenceA, Vector.Element occurrenceB) {
      int numNonZeroEntriesA = numNonZeroEntries.get(occurrenceA.index());
      int numNonZeroEntriesB = numNonZeroEntries.get(occurrenceB.index());

      double maxValueA = maxValues.get(occurrenceA.index());
      double maxValueB = maxValues.get(occurrenceB.index());

      return similarity.consider(numNonZeroEntriesA, numNonZeroEntriesB, maxValueA, maxValueB, threshold);
    }

    @Override
    protected void map(IntWritable column, VectorWritable occurrenceVector, Context ctx)
      throws IOException, InterruptedException {
      Vector.Element[] occurrences = Vectors.toArray(occurrenceVector);
      Arrays.sort(occurrences, BY_INDEX);

      int cooccurrences = 0;
      int prunedCooccurrences = 0;
      for (int n = 0; n < occurrences.length; n++) {
        Vector.Element occurrenceA = occurrences[n];
        Vector dots = new RandomAccessSparseVector(Integer.MAX_VALUE);
        // m starts at n: upper triangle including the self-pair (n, n)
        for (int m = n; m < occurrences.length; m++) {
          Vector.Element occurrenceB = occurrences[m];
          if (threshold == NO_THRESHOLD || consider(occurrenceA, occurrenceB)) {
            dots.setQuick(occurrenceB.index(), similarity.aggregate(occurrenceA.get(), occurrenceB.get()));
            cooccurrences++;
          } else {
            prunedCooccurrences++;
          }
        }
        ctx.write(new IntWritable(occurrenceA.index()), new VectorWritable(dots));
      }
      ctx.getCounter(Counters.COOCCURRENCES).increment(cooccurrences);
      ctx.getCounter(Counters.PRUNED_COOCCURRENCES).increment(prunedCooccurrences);
    }
  }


  /**
   * Sums the partial dot-product vectors per row and converts each aggregated dot product into
   * a final similarity value using the row norms, keeping only values at or above the threshold.
   */
  public static class SimilarityReducer extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {

    private VectorSimilarityMeasure similarity;
    private int numberOfColumns;
    private boolean excludeSelfSimilarity;
    private Vector norms;
    // NOTE(review): field name is a misspelling of "threshold"; consider renaming (private,
    // so safe) in a follow-up code change.
    private double treshold;

    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {
      similarity = ClassUtils.instantiateAs(ctx.getConfiguration().get(SIMILARITY_CLASSNAME),
          VectorSimilarityMeasure.class);
      numberOfColumns = ctx.getConfiguration().getInt(NUMBER_OF_COLUMNS, -1);
      // NOTE(review): "greater then" in this message should read "greater than"
      Preconditions.checkArgument(numberOfColumns > 0, "Number of columns must be greater then 0! But numberOfColumns = " + numberOfColumns);
      excludeSelfSimilarity = ctx.getConfiguration().getBoolean(EXCLUDE_SELF_SIMILARITY, false);
      norms = Vectors.read(new Path(ctx.getConfiguration().get(NORMS_PATH)), ctx.getConfiguration());
      treshold = Double.parseDouble(ctx.getConfiguration().get(THRESHOLD));
    }

    @Override
    protected void reduce(IntWritable row, Iterable<VectorWritable> partialDots, Context ctx)
      throws IOException, InterruptedException {
      // element-wise sum of all partial dot-product vectors, accumulated into the first one
      Iterator<VectorWritable> partialDotsIterator = partialDots.iterator();
      Vector dots = partialDotsIterator.next().get();
      while (partialDotsIterator.hasNext()) {
        Vector toAdd = partialDotsIterator.next().get();
        for (Element nonZeroElement : toAdd.nonZeroes()) {
          dots.setQuick(nonZeroElement.index(), dots.getQuick(nonZeroElement.index()) + nonZeroElement.get());
        }
      }

      Vector similarities = dots.like();
      double normA = norms.getQuick(row.get());
      for (Element b : dots.nonZeroes()) {
        double similarityValue = similarity.similarity(b.get(), normA, norms.getQuick(b.index()), numberOfColumns);
        // with NO_THRESHOLD (= Double.MIN_VALUE) this still drops zero/negative similarities,
        // see the NOTE on NO_THRESHOLD above
        if (similarityValue >= treshold) {
          similarities.set(b.index(), similarityValue);
        }
      }
      if (excludeSelfSimilarity) {
        // zero out the diagonal entry (similarity of the row to itself)
        similarities.setQuick(row.get(), 0);
      }
      ctx.write(row, new VectorWritable(similarities));
    }
  }

  /**
   * Mirrors the upper-triangular similarity matrix: emits each row's top-k entries under the
   * row's own key, and every entry transposed under its column's key, so the reducer can build
   * the full symmetric rows.
   */
  public static class UnsymmetrifyMapper extends Mapper<IntWritable,VectorWritable,IntWritable,VectorWritable> {

    private int maxSimilaritiesPerRow;

    @Override
    protected void setup(Mapper.Context ctx) throws IOException, InterruptedException {
      maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
      Preconditions.checkArgument(maxSimilaritiesPerRow > 0, "Maximum number of similarities per row must be greater then 0!");
    }

    @Override
    protected void map(IntWritable row, VectorWritable similaritiesWritable, Context ctx)
      throws IOException, InterruptedException {
      Vector similarities = similaritiesWritable.get();
      // For performance, the creation of transposedPartial is moved out of the while loop and it is reused inside
      Vector transposedPartial = new RandomAccessSparseVector(similarities.size(), 1);
      TopElementsQueue topKQueue = new TopElementsQueue(maxSimilaritiesPerRow);
      for (Element nonZeroElement : similarities.nonZeroes()) {
        MutableElement top = topKQueue.top();
        double candidateValue = nonZeroElement.get();
        if (candidateValue > top.get()) {
          top.setIndex(nonZeroElement.index());
          top.set(candidateValue);
          topKQueue.updateTop();
        }

        // every entry (not just the top-k) is also emitted transposed; the entry is cleared
        // again immediately so the single vector instance can be reused
        transposedPartial.setQuick(row.get(), candidateValue);
        ctx.write(new IntWritable(nonZeroElement.index()), new VectorWritable(transposedPartial));
        transposedPartial.setQuick(row.get(), 0.0);
      }
      Vector topKSimilarities = new RandomAccessSparseVector(similarities.size(), maxSimilaritiesPerRow);
      for (Vector.Element topKSimilarity : topKQueue.getTopElements()) {
        topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
      }
      ctx.write(row, new VectorWritable(topKSimilarities));
    }
  }

  /**
   * Merges the direct and transposed partial rows and keeps only the top-k similarities per
   * row. Also used as the combiner for phase 3 (top-k of a subset is safe to pre-compute).
   */
  public static class MergeToTopKSimilaritiesReducer
      extends Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {

    private int maxSimilaritiesPerRow;

    @Override
    protected void setup(Context ctx) throws IOException, InterruptedException {
      maxSimilaritiesPerRow = ctx.getConfiguration().getInt(MAX_SIMILARITIES_PER_ROW, 0);
      Preconditions.checkArgument(maxSimilaritiesPerRow > 0, "Maximum number of similarities per row must be greater then 0!");
    }

    @Override
    protected void reduce(IntWritable row, Iterable<VectorWritable> partials, Context ctx)
      throws IOException, InterruptedException {
      Vector allSimilarities = Vectors.merge(partials);
      Vector topKSimilarities = Vectors.topKElements(maxSimilaritiesPerRow, allSimilarities);
      ctx.write(row, new VectorWritable(topKSimilarities));
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java
new file mode 100644
index 0000000..34135ac
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/TopElementsQueue.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import com.google.common.collect.Lists;
+import org.apache.lucene.util.PriorityQueue;
+
+import java.util.Collections;
+import java.util.List;
+
+public class TopElementsQueue extends PriorityQueue<MutableElement> {
+
+ private final int maxSize;
+
+ private static final int SENTINEL_INDEX = Integer.MIN_VALUE;
+
+ public TopElementsQueue(int maxSize) {
+ super(maxSize);
+ this.maxSize = maxSize;
+ }
+
+ public List<MutableElement> getTopElements() {
+ List<MutableElement> topElements = Lists.newArrayListWithCapacity(maxSize);
+ while (size() > 0) {
+ MutableElement top = pop();
+ // filter out "sentinel" objects necessary for maintaining an efficient priority queue
+ if (top.index() != SENTINEL_INDEX) {
+ topElements.add(top);
+ }
+ }
+ Collections.reverse(topElements);
+ return topElements;
+ }
+
+ @Override
+ protected MutableElement getSentinelObject() {
+ return new MutableElement(SENTINEL_INDEX, Double.MIN_VALUE);
+ }
+
+ @Override
+ protected boolean lessThan(MutableElement e1, MutableElement e2) {
+ return e1.get() < e2.get();
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java
new file mode 100644
index 0000000..66fb0ae
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/Vectors.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence;
+
+import java.io.DataInput;
+import java.io.IOException;
+import java.util.Iterator;
+
+import com.google.common.base.Preconditions;
+import com.google.common.io.Closeables;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataInputStream;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.common.iterator.FixedSizeSamplingIterator;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.math.map.OpenIntIntHashMap;
+
/**
 * Static helper routines for sampling, merging, summing and (de)serializing Mahout vectors
 * used by the co-occurrence based similarity jobs.
 */
public final class Vectors {

  // utility class, not instantiable
  private Vectors() {}

  /**
   * Returns the vector itself if it has at most {@code sampleSize} non-zero entries, otherwise
   * a new sparse vector holding a uniform random sample of {@code sampleSize} of its entries.
   */
  public static Vector maybeSample(Vector original, int sampleSize) {
    if (original.getNumNondefaultElements() <= sampleSize) {
      return original;
    }
    Vector sample = new RandomAccessSparseVector(original.size(), sampleSize);
    Iterator<Element> sampledElements =
        new FixedSizeSamplingIterator<>(sampleSize, original.nonZeroes().iterator());
    while (sampledElements.hasNext()) {
      Element elem = sampledElements.next();
      sample.setQuick(elem.index(), elem.get());
    }
    return sample;
  }

  /**
   * Returns the vector itself if it has at most {@code k} non-zero entries, otherwise a new
   * sparse vector containing only its {@code k} largest entries (by value).
   */
  public static Vector topKElements(int k, Vector original) {
    if (original.getNumNondefaultElements() <= k) {
      return original;
    }

    TopElementsQueue topKQueue = new TopElementsQueue(k);
    for (Element nonZeroElement : original.nonZeroes()) {
      MutableElement top = topKQueue.top();
      double candidateValue = nonZeroElement.get();
      // only candidates beating the current smallest retained value enter the queue
      if (candidateValue > top.get()) {
        top.setIndex(nonZeroElement.index());
        top.set(candidateValue);
        topKQueue.updateTop();
      }
    }

    Vector topKSimilarities = new RandomAccessSparseVector(original.size(), k);
    for (Vector.Element topKSimilarity : topKQueue.getTopElements()) {
      topKSimilarities.setQuick(topKSimilarity.index(), topKSimilarity.get());
    }
    return topKSimilarities;
  }

  /**
   * Merges partial vectors into the first one by copying each non-zero entry with setQuick.
   * NOTE(review): overlapping indices are overwritten (last writer wins), not summed —
   * callers are presumably expected to pass partials with disjoint index sets; confirm.
   */
  public static Vector merge(Iterable<VectorWritable> partialVectors) {
    Iterator<VectorWritable> vectors = partialVectors.iterator();
    Vector accumulator = vectors.next().get();
    while (vectors.hasNext()) {
      VectorWritable v = vectors.next();
      if (v != null) {
        for (Element nonZeroElement : v.get().nonZeroes()) {
          accumulator.setQuick(nonZeroElement.index(), nonZeroElement.get());
        }
      }
    }
    return accumulator;
  }

  /**
   * Element-wise sum of all vectors in the iterator, accumulated into the first one.
   * Requires a non-empty iterator.
   */
  public static Vector sum(Iterator<VectorWritable> vectors) {
    Vector sum = vectors.next().get();
    while (vectors.hasNext()) {
      sum.assign(vectors.next().get(), Functions.PLUS);
    }
    return sum;
  }

  /**
   * A detached, mutable-value snapshot of a vector element, used by {@link #toArray} so the
   * elements outlive the iteration that produced them.
   */
  static class TemporaryElement implements Vector.Element {

    private final int index;
    private double value;

    TemporaryElement(int index, double value) {
      this.index = index;
      this.value = value;
    }

    TemporaryElement(Vector.Element toClone) {
      this(toClone.index(), toClone.get());
    }

    @Override
    public double get() {
      return value;
    }

    @Override
    public int index() {
      return index;
    }

    @Override
    public void set(double value) {
      this.value = value;
    }
  }

  /** Copies the non-zero entries of the wrapped vector into an array of detached elements. */
  public static Vector.Element[] toArray(VectorWritable vectorWritable) {
    Vector.Element[] elements = new Vector.Element[vectorWritable.get().getNumNondefaultElements()];
    int k = 0;
    for (Element nonZeroElement : vectorWritable.get().nonZeroes()) {
      elements[k++] = new TemporaryElement(nonZeroElement.index(), nonZeroElement.get());
    }
    return elements;
  }

  /** Writes a vector to {@code path} in full (double) precision. */
  public static void write(Vector vector, Path path, Configuration conf) throws IOException {
    write(vector, path, conf, false);
  }

  /**
   * Writes a vector to {@code path} using {@link VectorWritable} serialization; with
   * {@code laxPrecision} the values are stored as floats instead of doubles.
   */
  public static void write(Vector vector, Path path, Configuration conf, boolean laxPrecision) throws IOException {
    FileSystem fs = FileSystem.get(path.toUri(), conf);
    FSDataOutputStream out = fs.create(path);
    try {
      VectorWritable vectorWritable = new VectorWritable(vector);
      vectorWritable.setWritesLaxPrecision(laxPrecision);
      vectorWritable.write(out);
    } finally {
      Closeables.close(out, false);
    }
  }

  /** Reads a serialized sparse vector from {@code path} as an int-to-int map; see below. */
  public static OpenIntIntHashMap readAsIntMap(Path path, Configuration conf) throws IOException {
    FileSystem fs = FileSystem.get(path.toUri(), conf);
    FSDataInputStream in = fs.open(path);
    try {
      return readAsIntMap(in);
    } finally {
      Closeables.close(in, true);
    }
  }

  /* ugly optimization for loading sparse vectors containing ints only */
  private static OpenIntIntHashMap readAsIntMap(DataInput in) throws IOException {
    // hand-rolled reader for the VectorWritable wire format: flags byte first
    int flags = in.readByte();
    Preconditions.checkArgument(flags >> VectorWritable.NUM_FLAGS == 0,
        "Unknown flags set: %d", Integer.toString(flags, 2));
    boolean dense = (flags & VectorWritable.FLAG_DENSE) != 0;
    boolean sequential = (flags & VectorWritable.FLAG_SEQUENTIAL) != 0;
    boolean laxPrecision = (flags & VectorWritable.FLAG_LAX_PRECISION) != 0;
    Preconditions.checkState(!dense && !sequential, "Only for reading sparse vectors!");

    // read and discard one varint (presumably the vector cardinality — not needed for the map)
    Varint.readUnsignedVarInt(in);

    OpenIntIntHashMap values = new OpenIntIntHashMap();
    int numNonDefaultElements = Varint.readUnsignedVarInt(in);
    for (int i = 0; i < numNonDefaultElements; i++) {
      int index = Varint.readUnsignedVarInt(in);
      // values are stored as floating point but known to hold integer counts; truncate
      double value = laxPrecision ? in.readFloat() : in.readDouble();
      values.put(index, (int) value);
    }
    return values;
  }

  /** Reads a single serialized vector from {@code path}. */
  public static Vector read(Path path, Configuration conf) throws IOException {
    FileSystem fs = FileSystem.get(path.toUri(), conf);
    FSDataInputStream in = fs.open(path);
    try {
      return VectorWritable.readVector(in);
    } finally {
      Closeables.close(in, true);
    }
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java
new file mode 100644
index 0000000..0435d84
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CityBlockSimilarity.java
@@ -0,0 +1,26 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+public class CityBlockSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ return 1.0 / (1.0 + normA + normB - 2 * dots);
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java
new file mode 100644
index 0000000..61d071f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CooccurrenceCountSimilarity.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+public class CooccurrenceCountSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ return dots;
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return numNonZeroEntriesA >= threshold && numNonZeroEntriesB >= threshold;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java
new file mode 100644
index 0000000..3f4946b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CosineSimilarity.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public class CosineSimilarity implements VectorSimilarityMeasure {
+
+ @Override
+ public Vector normalize(Vector vector) {
+ return vector.normalize();
+ }
+
+ @Override
+ public double norm(Vector vector) {
+ return VectorSimilarityMeasure.NO_NORM;
+ }
+
+ @Override
+ public double aggregate(double valueA, double nonZeroValueB) {
+ return valueA * nonZeroValueB;
+ }
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ return dots;
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return numNonZeroEntriesB >= threshold / maxValueA
+ && numNonZeroEntriesA >= threshold / maxValueB;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java
new file mode 100644
index 0000000..105df2b
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/CountbasedMeasure.java
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public abstract class CountbasedMeasure implements VectorSimilarityMeasure {
+
+ @Override
+ public Vector normalize(Vector vector) {
+ return vector;
+ }
+
+ @Override
+ public double norm(Vector vector) {
+ return vector.norm(0);
+ }
+
+ @Override
+ public double aggregate(double valueA, double nonZeroValueB) {
+ return 1;
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java
new file mode 100644
index 0000000..e61c3eb
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/EuclideanDistanceSimilarity.java
@@ -0,0 +1,57 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
+public class EuclideanDistanceSimilarity implements VectorSimilarityMeasure {
+
+ @Override
+ public Vector normalize(Vector vector) {
+ return vector;
+ }
+
+ @Override
+ public double norm(Vector vector) {
+ double norm = 0;
+ for (Vector.Element e : vector.nonZeroes()) {
+ double value = e.get();
+ norm += value * value;
+ }
+ return norm;
+ }
+
+ @Override
+ public double aggregate(double valueA, double nonZeroValueB) {
+ return valueA * nonZeroValueB;
+ }
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ // Arg can't be negative in theory, but can in practice due to rounding, so cap it.
+ // Also note that normA / normB are actually the squares of the norms.
+ double euclideanDistance = Math.sqrt(Math.max(0.0, normA - 2 * dots + normB));
+ return 1.0 / (1.0 + euclideanDistance);
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return true;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java
new file mode 100644
index 0000000..7544b5d
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/LoglikelihoodSimilarity.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.stats.LogLikelihood;
+
+public class LoglikelihoodSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double summedAggregations, double normA, double normB, int numberOfColumns) {
+ double logLikelihood =
+ LogLikelihood.logLikelihoodRatio((long) summedAggregations,
+ (long) (normB - summedAggregations),
+ (long) (normA - summedAggregations),
+ (long) (numberOfColumns - normA - normB + summedAggregations));
+ return 1.0 - 1.0 / (1.0 + logLikelihood);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java
new file mode 100644
index 0000000..c650d8f
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/PearsonCorrelationSimilarity.java
@@ -0,0 +1,37 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
/**
 * Pearson correlation computed as the cosine of mean-centered vectors.
 *
 * NOTE(review): the "average" below is derived from norm(1), the sum of
 * ABSOLUTE values — it only equals the true mean when all entries are
 * non-negative (e.g. ratings); confirm that assumption holds for callers.
 */
public class PearsonCorrelationSimilarity extends CosineSimilarity {

  @Override
  public Vector normalize(Vector vector) {
    if (vector.getNumNondefaultElements() == 0) {
      return vector;
    }

    // center non-zero elements
    // NOTE(review): guards on getNumNondefaultElements() but divides by
    // getNumNonZeroElements(); a vector whose stored entries are all zero
    // would divide by zero here — verify that case cannot occur upstream.
    double average = vector.norm(1) / vector.getNumNonZeroElements();
    for (Vector.Element e : vector.nonZeroes()) {
      e.set(e.get() - average);  // mutates the vector in place
    }
    // then unit-length normalize, as for cosine similarity
    return super.normalize(vector);
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java
new file mode 100644
index 0000000..e000579
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/TanimotoCoefficientSimilarity.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+public class TanimotoCoefficientSimilarity extends CountbasedMeasure {
+
+ @Override
+ public double similarity(double dots, double normA, double normB, int numberOfColumns) {
+ // Return 0 even when dots == 0 since this will cause it to be ignored -- not NaN
+ return dots / (normA + normB - dots);
+ }
+
+ @Override
+ public boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
+ double threshold) {
+ return numNonZeroEntriesA >= numNonZeroEntriesB * threshold
+ && numNonZeroEntriesB >= numNonZeroEntriesA * threshold;
+ }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java
new file mode 100644
index 0000000..77125c2
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasure.java
@@ -0,0 +1,32 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import org.apache.mahout.math.Vector;
+
/**
 * Contract for pairwise vector similarity measures. Implementations first
 * {@link #normalize(Vector)} each vector, report a per-vector {@link #norm(Vector)},
 * accumulate per-element {@link #aggregate(double, double)} contributions, and
 * finally combine them into a {@link #similarity(double, double, double, int)} score.
 */
public interface VectorSimilarityMeasure {

  /** Sentinel norm for measures that fold normalization into normalize(). */
  double NO_NORM = 0.0;

  /** Preprocesses a vector once before similarity computation. */
  Vector normalize(Vector vector);
  /** Per-vector norm to be combined in similarity(). */
  double norm(Vector vector);
  /** Contribution of a pair of co-occurring element values. */
  double aggregate(double nonZeroValueA, double nonZeroValueB);
  /** Combines summed aggregations and norms into the final score. */
  double similarity(double summedAggregations, double normA, double normB, int numberOfColumns);
  /** Cheap pre-filter: whether a vector pair could reach the threshold at all. */
  boolean consider(int numNonZeroEntriesA, int numNonZeroEntriesB, double maxValueA, double maxValueB,
      double threshold);
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java
new file mode 100644
index 0000000..9d1160e
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/math/hadoop/similarity/cooccurrence/measures/VectorSimilarityMeasures.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.hadoop.similarity.cooccurrence.measures;
+
+import java.util.Arrays;
+
+public enum VectorSimilarityMeasures {
+
+ SIMILARITY_COOCCURRENCE(CooccurrenceCountSimilarity.class),
+ SIMILARITY_LOGLIKELIHOOD(LoglikelihoodSimilarity.class),
+ SIMILARITY_TANIMOTO_COEFFICIENT(TanimotoCoefficientSimilarity.class),
+ SIMILARITY_CITY_BLOCK(CityBlockSimilarity.class),
+ SIMILARITY_COSINE(CosineSimilarity.class),
+ SIMILARITY_PEARSON_CORRELATION(PearsonCorrelationSimilarity.class),
+ SIMILARITY_EUCLIDEAN_DISTANCE(EuclideanDistanceSimilarity.class);
+
+ private final Class<? extends VectorSimilarityMeasure> implementingClass;
+
+ VectorSimilarityMeasures(Class<? extends VectorSimilarityMeasure> impl) {
+ this.implementingClass = impl;
+ }
+
+ public String getClassname() {
+ return implementingClass.getName();
+ }
+
+ public static String list() {
+ return Arrays.toString(values());
+ }
+
+}