You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2011/07/15 23:07:49 UTC

svn commit: r1147318 - in /mahout/trunk/core/src: main/java/org/apache/mahout/math/hadoop/similarity/ test/java/org/apache/mahout/math/hadoop/similarity/

Author: gsingers
Date: Fri Jul 15 21:07:48 2011
New Revision: 1147318

URL: http://svn.apache.org/viewvc?rev=1147318&view=rev
Log:
MAHOUT-763: add alternative output mapping

Added:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
      - copied, changed from r1147257, mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
    mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
    mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java

Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java?rev=1147318&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java Fri Jul 15 21:07:48 2011
@@ -0,0 +1,119 @@
+package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
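+/**
+ * Loads seed vectors from the part files under the path stored in the configuration key
+ * {@link VectorDistanceSimilarityJob#SEEDS_PATH_KEY}.
+ **/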
+class SeedVectorUtil {
+  private transient static Logger log = LoggerFactory.getLogger(SeedVectorUtil.class);
+
+  private SeedVectorUtil() {
+
+  }
+
+  public static void loadSeedVectors(Configuration conf, List<NamedVector> seedVectors) throws IOException {
+
+    String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
+    if (seedPathStr != null && seedPathStr.length() > 0) {
+
+      Path thePath = new Path(seedPathStr, "*");
+      Collection<Path> result = Lists.newArrayList();
+
+      // get all filtered file names in result list
+      FileSystem fs = thePath.getFileSystem(conf);
+      FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(thePath, PathFilters.partFilter())),
+              PathFilters.partFilter());
+
+      for (FileStatus match : matches) {
+        result.add(fs.makeQualified(match.getPath()));
+      }
+
+      long item = 0;
+      for (Path seedPath : result) {
+        for (Writable value : new SequenceFileValueIterable<Writable>(seedPath, conf)) {
+          Class<? extends Writable> valueClass = value.getClass();
+          if (valueClass.equals(Cluster.class)) {
+            // get the cluster info
+            Cluster cluster = (Cluster) value;
+            Vector vector = cluster.getCenter();
+            if (vector instanceof NamedVector) {
+              seedVectors.add((NamedVector) vector);
+            } else {
+              seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
+            }
+          } else if (valueClass.equals(Canopy.class)) {
+            // get the cluster info
+            Canopy canopy = (Canopy) value;
+            Vector vector = canopy.getCenter();
+            if (vector instanceof NamedVector) {
+              seedVectors.add((NamedVector) vector);
+            } else {
+              seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
+            }
+          } else if (valueClass.equals(Vector.class)) {
+            Vector vector = (Vector) value;
+            if (vector instanceof NamedVector) {
+              seedVectors.add((NamedVector) vector);
+            } else {
+              seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
+            }
+          } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class)) {
+            VectorWritable vw = (VectorWritable) value;
+            Vector vector = vw.get();
+            if (vector instanceof NamedVector) {
+              seedVectors.add((NamedVector) vector);
+            } else {
+              seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
+            }
+          } else {
+            throw new IllegalStateException("Bad value class: " + valueClass);
+          }
+        }
+      }
+      if (seedVectors.isEmpty()) {
+        throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
+      } else {
+        log.info("Seed Vectors size: " + seedVectors.size());
+      }
+    }
+  }
+
+}

Copied: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java (from r1147257, mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java&r1=1147257&r2=1147318&rev=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java Fri Jul 15 21:07:48 2011
@@ -1,22 +1,28 @@
 package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 
-import com.google.common.collect.Lists;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.FileUtil;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.clustering.canopy.Canopy;
-import org.apache.mahout.clustering.kmeans.Cluster;
-import org.apache.mahout.common.StringTuple;
 import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
@@ -25,15 +31,14 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.List;
 
 /**
- *
- *
- **/
-public class VectorDistanceMapper extends Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable> {
-  private transient static Logger log = LoggerFactory.getLogger(VectorDistanceMapper.class);
+ * Similar to {@link org.apache.mahout.math.hadoop.similarity.VectorDistanceMapper}, except it outputs
+ * &lt;input, Vector&gt;, where the vector is a dense vector containing one entry for every seed vector.
+ */
+public class VectorDistanceInvertedMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
+  private transient static Logger log = LoggerFactory.getLogger(VectorDistanceInvertedMapper.class);
   protected DistanceMeasure measure;
   protected List<NamedVector> seedVectors;
 
@@ -46,94 +51,30 @@ public class VectorDistanceMapper extend
     } else {
       keyName = key.toString();
     }
+    Vector outVec = new DenseVector(new double[seedVectors.size()]);
+    int i = 0;
     for (NamedVector seedVector : seedVectors) {
-      double distance = measure.distance(seedVector, valVec);
-      StringTuple outKey = new StringTuple();
-      outKey.add(seedVector.getName());
-      outKey.add(keyName);
-      context.write(outKey, new DoubleWritable(distance));
+      outVec.setQuick(i++, measure.distance(seedVector, valVec));
     }
+    context.write(new Text(keyName), new VectorWritable(outVec));
   }
 
   @Override
   protected void setup(Context context) throws IOException, InterruptedException {
     super.setup(context);
     Configuration conf = context.getConfiguration();
+    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
     try {
-      ClassLoader ccl = Thread.currentThread().getContextClassLoader();
       measure = ccl.loadClass(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY))
               .asSubclass(DistanceMeasure.class).newInstance();
       measure.configure(conf);
-
-
-      String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
-      if (seedPathStr != null && seedPathStr.length() > 0) {
-
-        Path thePath = new Path(seedPathStr, "*");
-        Collection<Path> result = Lists.newArrayList();
-
-        // get all filtered file names in result list
-        FileSystem fs = thePath.getFileSystem(conf);
-        FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(thePath, PathFilters.partFilter())),
-                PathFilters.partFilter());
-
-        for (FileStatus match : matches) {
-          result.add(fs.makeQualified(match.getPath()));
-        }
-        seedVectors = new ArrayList<NamedVector>(100);
-        long item = 0;
-        for (Path seedPath : result) {
-          for (Writable value : new SequenceFileValueIterable<Writable>(seedPath, conf)) {
-            Class<? extends Writable> valueClass = value.getClass();
-            if (valueClass.equals(Cluster.class)) {
-              // get the cluster info
-              Cluster cluster = (Cluster) value;
-              Vector vector = cluster.getCenter();
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
-              }
-            } else if (valueClass.equals(Canopy.class)) {
-              // get the cluster info
-              Canopy canopy = (Canopy) value;
-              Vector vector = canopy.getCenter();
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
-              }
-            } else if (valueClass.equals(Vector.class)) {
-              Vector vector = (Vector) value;
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
-              }
-            } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class)) {
-              VectorWritable vw = (VectorWritable) value;
-              Vector vector = vw.get();
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
-              }
-            } else {
-              throw new IllegalStateException("Bad value class: " + valueClass);
-            }
-          }
-        }
-        if (seedVectors.isEmpty()) {
-          throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
-        } else {
-          log.info("Seed Vectors size: " + seedVectors.size());
-        }
-      }
-    } catch (ClassNotFoundException e) {
+      seedVectors = new ArrayList<NamedVector>(1000);
+      SeedVectorUtil.loadSeedVectors(conf, seedVectors);
+    } catch (InstantiationException e) {
       throw new IllegalStateException(e);
     } catch (IllegalAccessException e) {
       throw new IllegalStateException(e);
-    } catch (InstantiationException e) {
+    } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
     }
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java?rev=1147318&r1=1147317&r2=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java Fri Jul 15 21:07:48 2011
@@ -1,22 +1,28 @@
 package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
 
 
-import com.google.common.collect.Lists;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.FileUtil;
-import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.clustering.canopy.Canopy;
-import org.apache.mahout.clustering.kmeans.Cluster;
 import org.apache.mahout.common.StringTuple;
 import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
 import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
@@ -25,7 +31,6 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.List;
 
 /**
@@ -59,81 +64,18 @@ public class VectorDistanceMapper extend
   protected void setup(Context context) throws IOException, InterruptedException {
     super.setup(context);
     Configuration conf = context.getConfiguration();
+    ClassLoader ccl = Thread.currentThread().getContextClassLoader();
     try {
-      ClassLoader ccl = Thread.currentThread().getContextClassLoader();
       measure = ccl.loadClass(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY))
               .asSubclass(DistanceMeasure.class).newInstance();
       measure.configure(conf);
-
-
-      String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
-      if (seedPathStr != null && seedPathStr.length() > 0) {
-
-        Path thePath = new Path(seedPathStr, "*");
-        Collection<Path> result = Lists.newArrayList();
-
-        // get all filtered file names in result list
-        FileSystem fs = thePath.getFileSystem(conf);
-        FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(thePath, PathFilters.partFilter())),
-                PathFilters.partFilter());
-
-        for (FileStatus match : matches) {
-          result.add(fs.makeQualified(match.getPath()));
-        }
-        seedVectors = new ArrayList<NamedVector>(100);
-        long item = 0;
-        for (Path seedPath : result) {
-          for (Writable value : new SequenceFileValueIterable<Writable>(seedPath, conf)) {
-            Class<? extends Writable> valueClass = value.getClass();
-            if (valueClass.equals(Cluster.class)) {
-              // get the cluster info
-              Cluster cluster = (Cluster) value;
-              Vector vector = cluster.getCenter();
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
-              }
-            } else if (valueClass.equals(Canopy.class)) {
-              // get the cluster info
-              Canopy canopy = (Canopy) value;
-              Vector vector = canopy.getCenter();
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
-              }
-            } else if (valueClass.equals(Vector.class)) {
-              Vector vector = (Vector) value;
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
-              }
-            } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class)) {
-              VectorWritable vw = (VectorWritable) value;
-              Vector vector = vw.get();
-              if (vector instanceof NamedVector) {
-                seedVectors.add((NamedVector) vector);
-              } else {
-                seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
-              }
-            } else {
-              throw new IllegalStateException("Bad value class: " + valueClass);
-            }
-          }
-        }
-        if (seedVectors.isEmpty()) {
-          throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
-        } else {
-          log.info("Seed Vectors size: " + seedVectors.size());
-        }
-      }
-    } catch (ClassNotFoundException e) {
+      seedVectors = new ArrayList<NamedVector>(1000);
+      SeedVectorUtil.loadSeedVectors(conf, seedVectors);
+    } catch (InstantiationException e) {
       throw new IllegalStateException(e);
     } catch (IllegalAccessException e) {
       throw new IllegalStateException(e);
-    } catch (InstantiationException e) {
+    } catch (ClassNotFoundException e) {
       throw new IllegalStateException(e);
     }
   }

Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java?rev=1147318&r1=1147317&r2=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java Fri Jul 15 21:07:48 2011
@@ -20,6 +20,7 @@ package org.apache.mahout.math.hadoop.si
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapreduce.Job;
 import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
 import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
@@ -32,6 +33,7 @@ import org.apache.mahout.common.StringTu
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.common.distance.DistanceMeasure;
 import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.math.VectorWritable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -46,6 +48,7 @@ public class VectorDistanceSimilarityJob
   public static final String SEEDS = "seeds";
   public static final String SEEDS_PATH_KEY = "seedsPath";
   public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
+  public static final String OUT_TYPE_KEY = "outType";
 
   public static void main(String[] args) throws Exception {
     ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
@@ -59,7 +62,7 @@ public class VectorDistanceSimilarityJob
     addOption(DefaultOptionCreator.distanceMeasureOption().create());
     addOption(SEEDS, "s", "The set of vectors to compute distances against.  Must fit in memory on the mapper");
     addOption(DefaultOptionCreator.overwriteOption().create());
-
+    addOption(OUT_TYPE_KEY, "ot", "[pw|v] -- Define the output style: pairwise, the default, (pw) or vector (v).  Pairwise is a tuple of <seed, other, distance>, vector is <other, <Vector of size the number of seeds>>.", "pw");
     if (parseArguments(args) == null) {
       return -1;
     }
@@ -79,7 +82,12 @@ public class VectorDistanceSimilarityJob
     if (getConf() == null) {
       setConf(new Configuration());
     }
-    run(getConf(), input, seeds, output, measure);
+    String outType = getOption(OUT_TYPE_KEY);
+    if (outType == null) {
+      outType = "pw";
+    }
+
+    run(getConf(), input, seeds, output, measure, outType);
     return 0;
   }
 
@@ -87,17 +95,28 @@ public class VectorDistanceSimilarityJob
                          Path input,
                          Path seeds,
                          Path output,
-                         DistanceMeasure measure) throws IOException, ClassNotFoundException, InterruptedException {
+                         DistanceMeasure measure, String outType) throws IOException, ClassNotFoundException, InterruptedException {
     conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
     conf.set(SEEDS_PATH_KEY, seeds.toString());
     Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + " input: " + input);
     job.setInputFormatClass(SequenceFileInputFormat.class);
     job.setOutputFormatClass(SequenceFileOutputFormat.class);
-    job.setMapOutputKeyClass(StringTuple.class);
-    job.setOutputKeyClass(StringTuple.class);
-    job.setMapOutputValueClass(DoubleWritable.class);
-    job.setOutputValueClass(DoubleWritable.class);
-    job.setMapperClass(VectorDistanceMapper.class);
+    if (outType.equalsIgnoreCase("pw")) {
+      job.setMapOutputKeyClass(StringTuple.class);
+      job.setOutputKeyClass(StringTuple.class);
+      job.setMapOutputValueClass(DoubleWritable.class);
+      job.setOutputValueClass(DoubleWritable.class);
+      job.setMapperClass(VectorDistanceMapper.class);
+    } else if (outType.equalsIgnoreCase("v")) {
+      job.setMapOutputKeyClass(Text.class);
+      job.setOutputKeyClass(Text.class);
+      job.setMapOutputValueClass(VectorWritable.class);
+      job.setOutputValueClass(VectorWritable.class);
+      job.setMapperClass(VectorDistanceInvertedMapper.class);
+    } else {
+      throw new InterruptedException("Invalid outType specified: " + outType);
+    }
+
 
     job.setNumReduceTasks(0);
     FileInputFormat.addInputPath(job, input);

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java?rev=1147318&r1=1147317&r2=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java Fri Jul 15 21:07:48 2011
@@ -23,15 +23,20 @@ import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DoubleWritable;
 import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.io.WritableComparable;
 import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.hadoop.util.ToolRunner;
 import org.apache.mahout.clustering.ClusteringTestUtils;
-import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.common.DummyOutputCollector;
 import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
 import org.apache.mahout.common.StringTuple;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
 import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.NamedVector;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
@@ -42,6 +47,7 @@ import org.junit.Test;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 
 /**
  *
@@ -98,6 +104,37 @@ public class TestVectorDistanceSimilarit
 
   }
 
+  @Test
+  public void testVectorDistanceInvertedMapper() throws Exception {
+     Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context =
+            EasyMock.createMock(Mapper.Context.class);
+    Vector expectVec = new DenseVector(new double[]{Math.sqrt(2.0), 1.0});
+    context.write(new Text("other"), new VectorWritable(expectVec));
+    EasyMock.replay(context);
+    Vector vector = new NamedVector(new RandomAccessSparseVector(2), "other");
+    vector.set(0, 2);
+    vector.set(1, 2);
+
+    VectorDistanceInvertedMapper mapper = new VectorDistanceInvertedMapper();
+    setField(mapper, "measure", new EuclideanDistanceMeasure());
+    List<NamedVector> seedVectors = new ArrayList<NamedVector>();
+    Vector seed1 = new RandomAccessSparseVector(2);
+    seed1.set(0, 1);
+    seed1.set(1, 1);
+    Vector seed2 = new RandomAccessSparseVector(2);
+    seed2.set(0, 2);
+    seed2.set(1, 1);
+
+    seedVectors.add(new NamedVector(seed1, "foo"));
+    seedVectors.add(new NamedVector(seed2, "foo2"));
+    setField(mapper, "seedVectors", seedVectors);
+
+    mapper.map(new IntWritable(123), new VectorWritable(vector), context);
+
+    EasyMock.verify(context);
+
+  }
+
   public static final double[][] REFERENCE = {
           {1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5}
   };
@@ -119,8 +156,49 @@ public class TestVectorDistanceSimilarit
     String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
             optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
             output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName()
-            };
+    };
     ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+    int expect = SEEDS.length * REFERENCE.length;
+    DummyOutputCollector<StringTuple, DoubleWritable> collector =
+            new DummyOutputCollector<StringTuple, DoubleWritable>();
+    //
+    for (Pair<StringTuple, DoubleWritable> record :
+            new SequenceFileIterable<StringTuple, DoubleWritable>(
+                    new Path(output, "part-m-00000"), conf)) {
+      collector.collect(record.getFirst(), record.getSecond());
+    }
+    assertEquals(expect, collector.getData().size());
+  }
+
+  @Test
+  public void testRunInverted() throws Exception {
+    Path input = getTestTempDirPath("input");
+    Path output = getTestTempDirPath("output");
+    Path seedsPath = getTestTempDirPath("seeds");
+    List<VectorWritable> points = getPointsWritable(REFERENCE);
+    List<VectorWritable> seeds = getPointsWritable(SEEDS);
+    Configuration conf = new Configuration();
+    ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
+    ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
+    String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+            optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+            output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName(),
+            optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v"
+    };
+    ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+
+    DummyOutputCollector<Text, VectorWritable> collector =
+            new DummyOutputCollector<Text, VectorWritable>();
+    //
+    for (Pair<Text, VectorWritable> record :
+            new SequenceFileIterable<Text, VectorWritable>(
+                    new Path(output, "part-m-00000"), conf)) {
+      collector.collect(record.getFirst(), record.getSecond());
+    }
+    assertEquals(REFERENCE.length, collector.getData().size());
+    for (Map.Entry<Text, List<VectorWritable>> entry : collector.getData().entrySet()) {
+      assertEquals(SEEDS.length, entry.getValue().iterator().next().get().size());
+    }
   }
 
   public static List<VectorWritable> getPointsWritable(double[][] raw) {