You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by gs...@apache.org on 2011/07/15 23:07:49 UTC
svn commit: r1147318 - in /mahout/trunk/core/src:
main/java/org/apache/mahout/math/hadoop/similarity/
test/java/org/apache/mahout/math/hadoop/similarity/
Author: gsingers
Date: Fri Jul 15 21:07:48 2011
New Revision: 1147318
URL: http://svn.apache.org/viewvc?rev=1147318&view=rev
Log:
MAHOUT-763: add alternative output mapping
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java
- copied, changed from r1147257, mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
Added: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java?rev=1147318&view=auto
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java (added)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/SeedVectorUtil.java Fri Jul 15 21:07:48 2011
@@ -0,0 +1,119 @@
+package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileUtil;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.canopy.Canopy;
+import org.apache.mahout.clustering.kmeans.Cluster;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Loads the "seed" vectors that {@link VectorDistanceMapper} and {@link VectorDistanceInvertedMapper}
+ * measure every input vector against.
+ * <p>
+ * Seeds are read from the part files under the directory stored in
+ * {@link VectorDistanceSimilarityJob#SEEDS_PATH_KEY}; each value may be a k-means {@link Cluster},
+ * a {@link Canopy} or a {@link VectorWritable}.
+ */
+final class SeedVectorUtil {
+  private static final Logger log = LoggerFactory.getLogger(SeedVectorUtil.class);
+
+  private SeedVectorUtil() {
+    // static utility class; not instantiable
+  }
+
+  /**
+   * Appends every seed vector found under the configured seeds path to {@code seedVectors}.
+   * Vectors that are not already {@link NamedVector}s are wrapped: cluster and canopy centers are
+   * named by the cluster's identifier, plain vectors by {@code <seed file path>.<running index>}.
+   * Nothing is loaded when no seeds path is configured.
+   *
+   * @param conf        configuration holding {@link VectorDistanceSimilarityJob#SEEDS_PATH_KEY}
+   * @param seedVectors list the loaded seeds are appended to
+   * @throws IOException           if the seed files cannot be listed or read
+   * @throws IllegalStateException if a seed file holds an unsupported value class or no seed was found
+   */
+  public static void loadSeedVectors(Configuration conf, List<NamedVector> seedVectors) throws IOException {
+    String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
+    if (seedPathStr == null || seedPathStr.length() == 0) {
+      return;
+    }
+
+    // collect the fully qualified paths of all part files below the seeds directory
+    Path thePath = new Path(seedPathStr, "*");
+    FileSystem fs = thePath.getFileSystem(conf);
+    Collection<Path> result = Lists.newArrayList();
+    // globStatus() may return null (e.g. for a missing path); treat that as "no matches"
+    FileStatus[] globbed = fs.globStatus(thePath, PathFilters.partFilter());
+    if (globbed != null) {
+      for (FileStatus match : fs.listStatus(FileUtil.stat2Paths(globbed), PathFilters.partFilter())) {
+        result.add(fs.makeQualified(match.getPath()));
+      }
+    }
+
+    long item = 0;
+    for (Path seedPath : result) {
+      for (Writable value : new SequenceFileValueIterable<Writable>(seedPath, conf)) {
+        Class<? extends Writable> valueClass = value.getClass();
+        if (valueClass.equals(Cluster.class)) {
+          Cluster cluster = (Cluster) value;
+          seedVectors.add(asNamed(cluster.getCenter(), cluster.getIdentifier()));
+        } else if (valueClass.equals(Canopy.class)) {
+          Canopy canopy = (Canopy) value;
+          seedVectors.add(asNamed(canopy.getCenter(), canopy.getIdentifier()));
+        } else if (value instanceof VectorWritable) {
+          // instanceof (not an exact class match) so VectorWritable subclasses are accepted too
+          Vector vector = ((VectorWritable) value).get();
+          // only unnamed vectors consume an index, keeping generated names consecutive
+          seedVectors.add(vector instanceof NamedVector
+              ? (NamedVector) vector
+              : new NamedVector(vector, seedPath + "." + item++));
+        } else {
+          throw new IllegalStateException("Bad value class: " + valueClass);
+        }
+      }
+    }
+
+    if (seedVectors.isEmpty()) {
+      throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
+    }
+    log.info("Seed Vectors size: {}", seedVectors.size());
+  }
+
+  /** Returns {@code vector} unchanged if it is already named, otherwise wraps it under {@code name}. */
+  private static NamedVector asNamed(Vector vector, String name) {
+    return vector instanceof NamedVector ? (NamedVector) vector : new NamedVector(vector, name);
+  }
+
+}
Copied: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java (from r1147257, mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java)
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java&r1=1147257&r2=1147318&rev=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceInvertedMapper.java Fri Jul 15 21:07:48 2011
@@ -1,22 +1,28 @@
package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
-import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.FileUtil;
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.clustering.canopy.Canopy;
-import org.apache.mahout.clustering.kmeans.Cluster;
-import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -25,15 +31,14 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.List;
/**
- *
- *
- **/
-public class VectorDistanceMapper extends Mapper<WritableComparable<?>, VectorWritable, StringTuple, DoubleWritable> {
- private transient static Logger log = LoggerFactory.getLogger(VectorDistanceMapper.class);
+ * Similar to {@link org.apache.mahout.math.hadoop.similarity.VectorDistanceMapper}, except it outputs
+ * {@code <input, Vector>}, where the vector is a dense vector containing one distance per seed vector
+ */
+public class VectorDistanceInvertedMapper extends Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable> {
+ private transient static Logger log = LoggerFactory.getLogger(VectorDistanceInvertedMapper.class);
protected DistanceMeasure measure;
protected List<NamedVector> seedVectors;
@@ -46,94 +51,30 @@ public class VectorDistanceMapper extend
} else {
keyName = key.toString();
}
+ Vector outVec = new DenseVector(new double[seedVectors.size()]);
+ int i = 0;
for (NamedVector seedVector : seedVectors) {
- double distance = measure.distance(seedVector, valVec);
- StringTuple outKey = new StringTuple();
- outKey.add(seedVector.getName());
- outKey.add(keyName);
- context.write(outKey, new DoubleWritable(distance));
+ outVec.setQuick(i++, measure.distance(seedVector, valVec));
}
+ context.write(new Text(keyName), new VectorWritable(outVec));
}
@Override
protected void setup(Context context) throws IOException, InterruptedException {
super.setup(context);
Configuration conf = context.getConfiguration();
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
try {
- ClassLoader ccl = Thread.currentThread().getContextClassLoader();
measure = ccl.loadClass(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY))
.asSubclass(DistanceMeasure.class).newInstance();
measure.configure(conf);
-
-
- String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
- if (seedPathStr != null && seedPathStr.length() > 0) {
-
- Path thePath = new Path(seedPathStr, "*");
- Collection<Path> result = Lists.newArrayList();
-
- // get all filtered file names in result list
- FileSystem fs = thePath.getFileSystem(conf);
- FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(thePath, PathFilters.partFilter())),
- PathFilters.partFilter());
-
- for (FileStatus match : matches) {
- result.add(fs.makeQualified(match.getPath()));
- }
- seedVectors = new ArrayList<NamedVector>(100);
- long item = 0;
- for (Path seedPath : result) {
- for (Writable value : new SequenceFileValueIterable<Writable>(seedPath, conf)) {
- Class<? extends Writable> valueClass = value.getClass();
- if (valueClass.equals(Cluster.class)) {
- // get the cluster info
- Cluster cluster = (Cluster) value;
- Vector vector = cluster.getCenter();
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
- }
- } else if (valueClass.equals(Canopy.class)) {
- // get the cluster info
- Canopy canopy = (Canopy) value;
- Vector vector = canopy.getCenter();
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
- }
- } else if (valueClass.equals(Vector.class)) {
- Vector vector = (Vector) value;
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
- }
- } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class)) {
- VectorWritable vw = (VectorWritable) value;
- Vector vector = vw.get();
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
- }
- } else {
- throw new IllegalStateException("Bad value class: " + valueClass);
- }
- }
- }
- if (seedVectors.isEmpty()) {
- throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
- } else {
- log.info("Seed Vectors size: " + seedVectors.size());
- }
- }
- } catch (ClassNotFoundException e) {
+ seedVectors = new ArrayList<NamedVector>(1000);
+ SeedVectorUtil.loadSeedVectors(conf, seedVectors);
+ } catch (InstantiationException e) {
throw new IllegalStateException(e);
} catch (IllegalAccessException e) {
throw new IllegalStateException(e);
- } catch (InstantiationException e) {
+ } catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java?rev=1147318&r1=1147317&r2=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceMapper.java Fri Jul 15 21:07:48 2011
@@ -1,22 +1,28 @@
package org.apache.mahout.math.hadoop.similarity;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
-import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.FileUtil;
-import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
-import org.apache.mahout.clustering.canopy.Canopy;
-import org.apache.mahout.clustering.kmeans.Cluster;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.distance.DistanceMeasure;
-import org.apache.mahout.common.iterator.sequencefile.PathFilters;
-import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -25,7 +31,6 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.List;
/**
@@ -59,81 +64,18 @@ public class VectorDistanceMapper extend
protected void setup(Context context) throws IOException, InterruptedException {
super.setup(context);
Configuration conf = context.getConfiguration();
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
try {
- ClassLoader ccl = Thread.currentThread().getContextClassLoader();
measure = ccl.loadClass(conf.get(VectorDistanceSimilarityJob.DISTANCE_MEASURE_KEY))
.asSubclass(DistanceMeasure.class).newInstance();
measure.configure(conf);
-
-
- String seedPathStr = conf.get(VectorDistanceSimilarityJob.SEEDS_PATH_KEY);
- if (seedPathStr != null && seedPathStr.length() > 0) {
-
- Path thePath = new Path(seedPathStr, "*");
- Collection<Path> result = Lists.newArrayList();
-
- // get all filtered file names in result list
- FileSystem fs = thePath.getFileSystem(conf);
- FileStatus[] matches = fs.listStatus(FileUtil.stat2Paths(fs.globStatus(thePath, PathFilters.partFilter())),
- PathFilters.partFilter());
-
- for (FileStatus match : matches) {
- result.add(fs.makeQualified(match.getPath()));
- }
- seedVectors = new ArrayList<NamedVector>(100);
- long item = 0;
- for (Path seedPath : result) {
- for (Writable value : new SequenceFileValueIterable<Writable>(seedPath, conf)) {
- Class<? extends Writable> valueClass = value.getClass();
- if (valueClass.equals(Cluster.class)) {
- // get the cluster info
- Cluster cluster = (Cluster) value;
- Vector vector = cluster.getCenter();
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, cluster.getIdentifier()));
- }
- } else if (valueClass.equals(Canopy.class)) {
- // get the cluster info
- Canopy canopy = (Canopy) value;
- Vector vector = canopy.getCenter();
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, canopy.getIdentifier()));
- }
- } else if (valueClass.equals(Vector.class)) {
- Vector vector = (Vector) value;
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
- }
- } else if (valueClass.equals(VectorWritable.class) || valueClass.isInstance(VectorWritable.class)) {
- VectorWritable vw = (VectorWritable) value;
- Vector vector = vw.get();
- if (vector instanceof NamedVector) {
- seedVectors.add((NamedVector) vector);
- } else {
- seedVectors.add(new NamedVector(vector, seedPath + "." + item++));
- }
- } else {
- throw new IllegalStateException("Bad value class: " + valueClass);
- }
- }
- }
- if (seedVectors.isEmpty()) {
- throw new IllegalStateException("No seeds found. Check your path: " + seedPathStr);
- } else {
- log.info("Seed Vectors size: " + seedVectors.size());
- }
- }
- } catch (ClassNotFoundException e) {
+ seedVectors = new ArrayList<NamedVector>(1000);
+ SeedVectorUtil.loadSeedVectors(conf, seedVectors);
+ } catch (InstantiationException e) {
throw new IllegalStateException(e);
} catch (IllegalAccessException e) {
throw new IllegalStateException(e);
- } catch (InstantiationException e) {
+ } catch (ClassNotFoundException e) {
throw new IllegalStateException(e);
}
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java?rev=1147318&r1=1147317&r2=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/similarity/VectorDistanceSimilarityJob.java Fri Jul 15 21:07:48 2011
@@ -20,6 +20,7 @@ package org.apache.mahout.math.hadoop.si
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
@@ -32,6 +33,7 @@ import org.apache.mahout.common.StringTu
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
+import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -46,6 +48,7 @@ public class VectorDistanceSimilarityJob
public static final String SEEDS = "seeds";
public static final String SEEDS_PATH_KEY = "seedsPath";
public static final String DISTANCE_MEASURE_KEY = "vectorDistSim.measure";
+ public static final String OUT_TYPE_KEY = "outType";
public static void main(String[] args) throws Exception {
ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
@@ -59,7 +62,7 @@ public class VectorDistanceSimilarityJob
addOption(DefaultOptionCreator.distanceMeasureOption().create());
addOption(SEEDS, "s", "The set of vectors to compute distances against. Must fit in memory on the mapper");
addOption(DefaultOptionCreator.overwriteOption().create());
-
+ addOption(OUT_TYPE_KEY, "ot", "[pw|v] -- Define the output style: pairwise, the default, (pw) or vector (v). Pairwise is a tuple of <seed, other, distance>, vector is <other, <Vector of size the number of seeds>>.", "pw");
if (parseArguments(args) == null) {
return -1;
}
@@ -79,7 +82,12 @@ public class VectorDistanceSimilarityJob
if (getConf() == null) {
setConf(new Configuration());
}
- run(getConf(), input, seeds, output, measure);
+ String outType = getOption(OUT_TYPE_KEY);
+ if (outType == null) {
+ outType = "pw";
+ }
+
+ run(getConf(), input, seeds, output, measure, outType);
return 0;
}
@@ -87,17 +95,28 @@ public class VectorDistanceSimilarityJob
Path input,
Path seeds,
Path output,
- DistanceMeasure measure) throws IOException, ClassNotFoundException, InterruptedException {
+ DistanceMeasure measure, String outType) throws IOException, ClassNotFoundException, InterruptedException {
conf.set(DISTANCE_MEASURE_KEY, measure.getClass().getName());
conf.set(SEEDS_PATH_KEY, seeds.toString());
Job job = new Job(conf, "Vector Distance Similarity: seeds: " + seeds + " input: " + input);
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
- job.setMapOutputKeyClass(StringTuple.class);
- job.setOutputKeyClass(StringTuple.class);
- job.setMapOutputValueClass(DoubleWritable.class);
- job.setOutputValueClass(DoubleWritable.class);
- job.setMapperClass(VectorDistanceMapper.class);
+ if (outType.equalsIgnoreCase("pw")) {
+ job.setMapOutputKeyClass(StringTuple.class);
+ job.setOutputKeyClass(StringTuple.class);
+ job.setMapOutputValueClass(DoubleWritable.class);
+ job.setOutputValueClass(DoubleWritable.class);
+ job.setMapperClass(VectorDistanceMapper.class);
+ } else if (outType.equalsIgnoreCase("v")) {
+ job.setMapOutputKeyClass(Text.class);
+ job.setOutputKeyClass(Text.class);
+ job.setMapOutputValueClass(VectorWritable.class);
+ job.setOutputValueClass(VectorWritable.class);
+ job.setMapperClass(VectorDistanceInvertedMapper.class);
+ } else {
+ throw new InterruptedException("Invalid outType specified: " + outType);
+ }
+
job.setNumReduceTasks(0);
FileInputFormat.addInputPath(job, input);
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java?rev=1147318&r1=1147317&r2=1147318&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/similarity/TestVectorDistanceSimilarityJob.java Fri Jul 15 21:07:48 2011
@@ -23,15 +23,20 @@ import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DoubleWritable;
import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.ClusteringTestUtils;
-import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.common.DummyOutputCollector;
import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.Pair;
import org.apache.mahout.common.StringTuple;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable;
+import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.NamedVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
@@ -42,6 +47,7 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
/**
*
@@ -98,6 +104,37 @@ public class TestVectorDistanceSimilarit
}
+ @Test
+ public void testVectorDistanceInvertedMapper() throws Exception {
+ Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context =
+ EasyMock.createMock(Mapper.Context.class);
+ Vector expectVec = new DenseVector(new double[]{Math.sqrt(2.0), 1.0});
+ context.write(new Text("other"), new VectorWritable(expectVec));
+ EasyMock.replay(context);
+ Vector vector = new NamedVector(new RandomAccessSparseVector(2), "other");
+ vector.set(0, 2);
+ vector.set(1, 2);
+
+ VectorDistanceInvertedMapper mapper = new VectorDistanceInvertedMapper();
+ setField(mapper, "measure", new EuclideanDistanceMeasure());
+ List<NamedVector> seedVectors = new ArrayList<NamedVector>();
+ Vector seed1 = new RandomAccessSparseVector(2);
+ seed1.set(0, 1);
+ seed1.set(1, 1);
+ Vector seed2 = new RandomAccessSparseVector(2);
+ seed2.set(0, 2);
+ seed2.set(1, 1);
+
+ seedVectors.add(new NamedVector(seed1, "foo"));
+ seedVectors.add(new NamedVector(seed2, "foo2"));
+ setField(mapper, "seedVectors", seedVectors);
+
+ mapper.map(new IntWritable(123), new VectorWritable(vector), context);
+
+ EasyMock.verify(context);
+
+ }
+
public static final double[][] REFERENCE = {
{1, 1}, {2, 1}, {1, 2}, {2, 2}, {3, 3}, {4, 4}, {5, 4}, {4, 5}, {5, 5}
};
@@ -119,8 +156,49 @@ public class TestVectorDistanceSimilarit
String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName()
- };
+ };
ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+ int expect = SEEDS.length * REFERENCE.length;
+ DummyOutputCollector<StringTuple, DoubleWritable> collector =
+ new DummyOutputCollector<StringTuple, DoubleWritable>();
+ //
+ for (Pair<StringTuple, DoubleWritable> record :
+ new SequenceFileIterable<StringTuple, DoubleWritable>(
+ new Path(output, "part-m-00000"), conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ assertEquals(expect, collector.getData().size());
+ }
+
+ @Test
+ public void testRunInverted() throws Exception {
+ Path input = getTestTempDirPath("input");
+ Path output = getTestTempDirPath("output");
+ Path seedsPath = getTestTempDirPath("seeds");
+ List<VectorWritable> points = getPointsWritable(REFERENCE);
+ List<VectorWritable> seeds = getPointsWritable(SEEDS);
+ Configuration conf = new Configuration();
+ ClusteringTestUtils.writePointsToFile(points, true, new Path(input, "file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(seeds, true, new Path(seedsPath, "part-seeds"), fs, conf);
+ String[] args = {optKey(DefaultOptionCreator.INPUT_OPTION), input.toString(),
+ optKey(VectorDistanceSimilarityJob.SEEDS), seedsPath.toString(), optKey(DefaultOptionCreator.OUTPUT_OPTION),
+ output.toString(), optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), EuclideanDistanceMeasure.class.getName(),
+ optKey(VectorDistanceSimilarityJob.OUT_TYPE_KEY), "v"
+ };
+ ToolRunner.run(new Configuration(), new VectorDistanceSimilarityJob(), args);
+
+ DummyOutputCollector<Text, VectorWritable> collector =
+ new DummyOutputCollector<Text, VectorWritable>();
+ //
+ for (Pair<Text, VectorWritable> record :
+ new SequenceFileIterable<Text, VectorWritable>(
+ new Path(output, "part-m-00000"), conf)) {
+ collector.collect(record.getFirst(), record.getSecond());
+ }
+ assertEquals(REFERENCE.length, collector.getData().size());
+ for (Map.Entry<Text, List<VectorWritable>> entry : collector.getData().entrySet()) {
+ assertEquals(SEEDS.length, entry.getValue().iterator().next().get().size());
+ }
}
public static List<VectorWritable> getPointsWritable(double[][] raw) {