You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by jm...@apache.org on 2010/03/06 06:08:37 UTC
svn commit: r919702 - in /lucene/mahout/trunk/core/src:
main/java/org/apache/mahout/math/hadoop/
test/java/org/apache/mahout/math/hadoop/
Author: jmannix
Date: Sat Mar 6 05:08:36 2010
New Revision: 919702
URL: http://svn.apache.org/viewvc?rev=919702&view=rev
Log:
Implements MAHOUT-312, MAHOUT-313, and MAHOUT-314
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java?rev=919702&r1=919701&r2=919702&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/DistributedRowMatrix.java Sat Mar 6 05:08:36 2010
@@ -17,18 +17,25 @@
package org.apache.mahout.math.hadoop;
+import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
+import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.JobConfigurable;
+import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.MatrixSlice;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorIterable;
import org.apache.mahout.math.VectorWritable;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import java.io.DataInput;
+import java.io.DataOutput;
import java.io.IOException;
import java.util.Iterator;
import java.util.NoSuchElementException;
@@ -54,8 +61,10 @@
*/
public class DistributedRowMatrix implements VectorIterable, JobConfigurable {
+ private static final Logger log = LoggerFactory.getLogger(DistributedRowMatrix.class);
+
private final String inputPathString;
- private final String outputTmpPathString;
+ private String outputTmpPathString;
private JobConf conf;
private Path rowPath;
private Path outputTmpBasePath;
@@ -91,12 +100,20 @@
return outputTmpBasePath;
}
+ public void setOutputTempPathString(String outPathString) {
+ try {
+ outputTmpBasePath = FileSystem.get(conf).makeQualified(new Path(outPathString));
+ } catch (IOException ioe) {
+ log.warn("Unable to set outputBasePath to {}, leaving as {}",
+ outPathString, outputTmpBasePath.toString());
+ }
+ }
+
@Override
public Iterator<MatrixSlice> iterateAll() {
try {
FileSystem fs = FileSystem.get(conf);
- SequenceFile.Reader reader = new SequenceFile.Reader(fs, rowPath, conf);
- return new DistributedMatrixIterator(reader);
+ return new DistributedMatrixIterator(fs, rowPath, conf);
} catch (IOException ioe) {
throw new IllegalStateException(ioe);
}
@@ -117,10 +134,49 @@
return numCols;
}
+ public DistributedRowMatrix times(DistributedRowMatrix other) {
+ if(numRows != other.numRows()) {
+ throw new CardinalityException(numRows, other.numRows());
+ }
+ Path outPath = new Path(outputTmpBasePath.getParent(), "productWith");
+ JobConf conf = MatrixMultiplicationJob.createMatrixMultiplyJobConf(rowPath, other.rowPath, outPath, other.numCols);
+ try {
+ JobClient.runJob(conf);
+ DistributedRowMatrix out = new DistributedRowMatrix(outPath.toString(),
+ outputTmpPathString, numRows, other.numCols());
+ out.configure(conf);
+ return out;
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ }
+
+ public DistributedRowMatrix transpose() {
+ Path outputPath = new Path(rowPath.getParent(), "transpose-" + (byte)System.nanoTime());
+ try {
+ JobConf conf = TransposeJob.buildTransposeJobConf(rowPath, outputPath, numRows);
+ JobClient.runJob(conf);
+ DistributedRowMatrix m = new DistributedRowMatrix(outputPath.toString(), outputTmpPathString, numCols, numRows);
+ m.configure(this.conf);
+ return m;
+ } catch (IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
+ }
+
@Override
public Vector times(Vector v) {
- // TODO: times(Vector) is easy, works pretty much like timesSquared.
- throw new UnsupportedOperationException("DistributedRowMatrix methods other than timesSquared not supported yet");
+ try {
+ JobConf conf = TimesSquaredJob.createTimesJobConf(v,
+ numRows,
+ rowPath,
+ new Path(outputTmpPathString,
+ new Path(Long.toString(System.nanoTime()))));
+ JobClient.runJob(conf);
+ return TimesSquaredJob.retrieveTimesSquaredOutputVector(conf);
+ } catch(IOException ioe) {
+ throw new RuntimeException(ioe);
+ }
}
@Override
@@ -143,14 +199,21 @@
}
public static class DistributedMatrixIterator implements Iterator<MatrixSlice> {
- private final SequenceFile.Reader reader;
+ private SequenceFile.Reader reader;
+ private FileStatus[] statuses;
private boolean hasBuffered = false;
private boolean hasNext = false;
+ private int statusIndex = 0;
+ private final FileSystem fs;
+ private final JobConf conf;
private final IntWritable i = new IntWritable();
private final VectorWritable v = new VectorWritable();
- public DistributedMatrixIterator(SequenceFile.Reader reader) {
- this.reader = reader;
+ public DistributedMatrixIterator(FileSystem fs, Path rowPath, JobConf conf) throws IOException {
+ this.fs = fs;
+ this.conf = conf;
+ statuses = fs.globStatus(new Path(rowPath, "*"));
+ reader = new SequenceFile.Reader(fs, statuses[statusIndex].getPath(), conf);
}
@Override
@@ -158,6 +221,11 @@
try {
if(!hasBuffered) {
hasNext = reader.next(i, v);
+ if(!hasNext && statusIndex < statuses.length - 1) {
+ statusIndex++;
+ reader = new SequenceFile.Reader(fs, statuses[statusIndex].getPath(), conf);
+ hasNext = reader.next(i, v);
+ }
hasBuffered = true;
}
} catch (IOException ioe) {
@@ -186,4 +254,65 @@
}
}
+ public static class MatrixEntryWritable implements WritableComparable<MatrixEntryWritable> {
+ private int row;
+ private int col;
+ private double val;
+
+ public int getRow() {
+ return row;
+ }
+
+ public void setRow(int row) {
+ this.row = row;
+ }
+
+ public int getCol() {
+ return col;
+ }
+
+ public void setCol(int col) {
+ this.col = col;
+ }
+
+ public double getVal() {
+ return val;
+ }
+
+ public void setVal(double val) {
+ this.val = val;
+ }
+
+ @Override
+ public int compareTo(MatrixEntryWritable o) {
+ if(row > o.row) {
+ return 1;
+ } else if(row < o.row) {
+ return -1;
+ } else {
+ if(col > o.col) {
+ return 1;
+ } else if(col < o.col) {
+ return -1;
+ } else {
+ return 0;
+ }
+ }
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(row);
+ out.writeInt(col);
+ out.writeDouble(val);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ row = in.readInt();
+ col = in.readInt();
+ val = in.readDouble();
+ }
+ }
+
}
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java?rev=919702&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/MatrixMultiplicationJob.java Sat Mar 6 05:08:36 2010
@@ -0,0 +1,157 @@
+package org.apache.mahout.math.hadoop;
+
+import org.apache.commons.cli2.Option;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.FileOutputFormat;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.mapred.join.CompositeInputFormat;
+import org.apache.hadoop.mapred.join.TupleWritable;
+import org.apache.hadoop.mapred.lib.MultipleInputs;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.AbstractJob;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+
+public class MatrixMultiplicationJob extends AbstractJob {
+
+ private static final String OUT_CARD = "output.vector.cardinality";
+
+ private Map<String,String> argMap;
+
+ public static JobConf createMatrixMultiplyJobConf(Path aPath, Path bPath, Path outPath, int outCardinality) {
+ JobConf conf = new JobConf(MatrixMultiplicationJob.class);
+ conf.setInputFormat(CompositeInputFormat.class);
+ conf.set("mapred.join.expr", CompositeInputFormat.compose(
+ "inner", SequenceFileInputFormat.class, new Path[] {aPath, bPath}));
+ conf.setInt(OUT_CARD, outCardinality);
+ conf.setOutputFormat(SequenceFileOutputFormat.class);
+ FileOutputFormat.setOutputPath(conf, outPath);
+ conf.setMapperClass(MatrixMultiplyMapper.class);
+ conf.setCombinerClass(MatrixMultiplicationReducer.class);
+ conf.setReducerClass(MatrixMultiplicationReducer.class);
+ conf.setMapOutputKeyClass(IntWritable.class);
+ conf.setMapOutputValueClass(VectorWritable.class);
+ conf.setOutputKeyClass(IntWritable.class);
+ conf.setOutputValueClass(VectorWritable.class);
+ return conf;
+ }
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new MatrixMultiplicationJob(), args);
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ Option numRowsAOpt = buildOption("numRowsA",
+ "nra",
+ "Number of rows of the first input matrix");
+ Option numColsAOpt = buildOption("numColsA",
+ "nca",
+ "Number of columns of the first input matrix");
+ Option numRowsBOpt = buildOption("numRowsB",
+ "nrb",
+ "Number of rows of the second input matrix");
+
+ Option numColsBOpt = buildOption("numColsB",
+ "ncb",
+ "Number of columns of the second input matrix");
+ Option inputPathA = buildOption("inputPathA",
+ "ia",
+ "Path to the first input matrix");
+ Option inputPathB = buildOption("inputPathB",
+ "ib",
+ "Path to the second input matrix");
+
+ argMap = parseArguments(strings,
+ numRowsAOpt,
+ numRowsBOpt,
+ numColsAOpt,
+ numColsBOpt,
+ inputPathA,
+ inputPathB);
+
+ DistributedRowMatrix a = new DistributedRowMatrix(argMap.get("--inputPathA"),
+ argMap.get("--tempDir"),
+ Integer.parseInt(argMap.get("--numRowsA")),
+ Integer.parseInt(argMap.get("--numColsA")));
+ DistributedRowMatrix b = new DistributedRowMatrix(argMap.get("--inputPathB"),
+ argMap.get("--tempDir"),
+ Integer.parseInt(argMap.get("--numRowsB")),
+ Integer.parseInt(argMap.get("--numColsB")));
+
+ a.configure(new JobConf(getConf()));
+ b.configure(new JobConf(getConf()));
+
+ DistributedRowMatrix c = a.times(b);
+
+ return 0;
+ }
+
+ public static class MatrixMultiplyMapper extends MapReduceBase
+ implements Mapper<IntWritable,TupleWritable,IntWritable,VectorWritable> {
+
+ private int outCardinality;
+ private final IntWritable row = new IntWritable();
+ private final VectorWritable outVector = new VectorWritable();
+
+ public void configure(JobConf conf) {
+ outCardinality = conf.getInt(OUT_CARD, Integer.MAX_VALUE);
+ }
+
+ @Override
+ public void map(IntWritable index,
+ TupleWritable v,
+ OutputCollector<IntWritable,VectorWritable> out,
+ Reporter reporter) throws IOException {
+ boolean firstIsOutFrag = ((VectorWritable)v.get(0)).get().size() == outCardinality;
+ Vector outFrag = firstIsOutFrag ? ((VectorWritable)v.get(0)).get() : ((VectorWritable)v.get(1)).get();
+ Vector multiplier = firstIsOutFrag ? ((VectorWritable)v.get(1)).get() : ((VectorWritable)v.get(0)).get();
+
+ Iterator<Vector.Element> it = multiplier.iterateNonZero();
+ while(it.hasNext()) {
+ Vector.Element e = it.next();
+ row.set(e.index());
+ outVector.set(outFrag.times(e.get()));
+ out.collect(row, outVector);
+ }
+ }
+ }
+
+ public static class MatrixMultiplicationReducer extends MapReduceBase
+ implements Reducer<IntWritable,VectorWritable,IntWritable,VectorWritable> {
+
+ @Override
+ public void reduce(IntWritable rowNum,
+ Iterator<VectorWritable> it,
+ OutputCollector<IntWritable,VectorWritable> out,
+ Reporter reporter) throws IOException {
+ Vector accumulator;
+ Vector row;
+ if(it.hasNext()) {
+ accumulator = new RandomAccessSparseVector(it.next().get());
+ } else {
+ return;
+ }
+ while(it.hasNext()) {
+ row = it.next().get();
+ row.addTo(accumulator);
+ }
+ out.collect(rowNum, new VectorWritable(new SequentialAccessSparseVector(accumulator)));
+ }
+ }
+
+}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java?rev=919702&r1=919701&r2=919702&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TimesSquaredJob.java Sat Mar 6 05:08:36 2010
@@ -20,6 +20,7 @@
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.WritableComparable;
@@ -50,11 +51,11 @@
private static final Logger log = LoggerFactory.getLogger(TimesSquaredJob.class);
- public static final String INPUT_VECTOR = "timesSquared.inputVector";
- public static final String IS_SPARSE_OUTPUT = "timesSquared.outputVector.sparse";
- public static final String OUTPUT_VECTOR_DIMENSION = "timesSquared.output.dimension";
+ public static final String INPUT_VECTOR = "DistributedMatrix.times.inputVector";
+ public static final String IS_SPARSE_OUTPUT = "DistributedMatrix.times.outputVector.sparse";
+ public static final String OUTPUT_VECTOR_DIMENSION = "DistributedMatrix.times.output.dimension";
- public static final String OUTPUT_VECTOR_FILENAME = "timesSquaredOutputVector";
+ public static final String OUTPUT_VECTOR_FILENAME = "DistributedMatrix.times.outputVector";
private TimesSquaredJob() {}
@@ -68,7 +69,29 @@
VectorSummingReducer.class);
}
+ public static JobConf createTimesJobConf(Vector v,
+ int outDim,
+ Path matrixInputPath,
+ Path outputVectorPath) throws IOException {
+ return createTimesSquaredJobConf(v,
+ outDim,
+ matrixInputPath,
+ outputVectorPath,
+ TimesMapper.class,
+ VectorSummingReducer.class);
+ }
+
+
+ public static JobConf createTimesSquaredJobConf(Vector v,
+ Path matrixInputPath,
+ Path outputVectorPathBase,
+ Class<? extends TimesSquaredMapper> mapClass,
+ Class<? extends VectorSummingReducer> redClass) throws IOException {
+ return createTimesSquaredJobConf(v, v.size(), matrixInputPath, outputVectorPathBase, mapClass, redClass);
+ }
+
public static JobConf createTimesSquaredJobConf(Vector v,
+ int outputVectorDim,
Path matrixInputPath,
Path outputVectorPathBase,
Class<? extends TimesSquaredMapper> mapClass,
@@ -92,7 +115,7 @@
conf.set(INPUT_VECTOR, ivpURI.toString());
conf.setBoolean(IS_SPARSE_OUTPUT, !(v instanceof DenseVector));
- conf.setInt(OUTPUT_VECTOR_DIMENSION, v.size());
+ conf.setInt(OUTPUT_VECTOR_DIMENSION, outputVectorDim);
FileInputFormat.addInputPath(conf, matrixInputPath);
conf.setInputFormat(SequenceFileInputFormat.class);
FileOutputFormat.setOutputPath(conf, new Path(outputVectorPathBase, OUTPUT_VECTOR_FILENAME));
@@ -121,12 +144,12 @@
return vector;
}
- public static class TimesSquaredMapper extends MapReduceBase
- implements Mapper<WritableComparable<?>,VectorWritable, NullWritable,VectorWritable> {
+ public static class TimesSquaredMapper<T extends WritableComparable> extends MapReduceBase
+ implements Mapper<T,VectorWritable, NullWritable,VectorWritable> {
private Vector inputVector;
- private Vector outputVector;
- private OutputCollector<NullWritable,VectorWritable> out;
+ protected Vector outputVector;
+ protected OutputCollector<NullWritable,VectorWritable> out;
@Override
public void configure(JobConf conf) {
@@ -150,16 +173,17 @@
if(!(inputVector instanceof SequentialAccessSparseVector || inputVector instanceof DenseVector)) {
inputVector = new SequentialAccessSparseVector(inputVector);
}
+ int outDim = conf.getInt(OUTPUT_VECTOR_DIMENSION, Integer.MAX_VALUE);
outputVector = conf.getBoolean(IS_SPARSE_OUTPUT, false)
- ? new RandomAccessSparseVector(inputVector.size(), 10)
- : new DenseVector(inputVector.size());
+ ? new RandomAccessSparseVector(outDim, 10)
+ : new DenseVector(outDim);
} catch (IOException ioe) {
throw new IllegalStateException(ioe);
}
}
@Override
- public void map(WritableComparable<?> rowNum,
+ public void map(T rowNum,
VectorWritable v,
OutputCollector<NullWritable,VectorWritable> out,
Reporter rep) throws IOException {
@@ -183,6 +207,20 @@
}
  /**
   * Mapper for plain matrix-times-vector: each input row i contributes the
   * single scalar scale(v) at position i of the accumulated output vector,
   * instead of a scaled copy of the row as the timesSquared mapper emits.
   *
   * NOTE(review): nothing is collected here — emission presumably happens in
   * the parent TimesSquaredMapper's close() via the captured {@code out};
   * that code is outside this hunk, so confirm close() flushes outputVector.
   * Likewise scale(v) looks like dot(row, inputVector) — verify in the parent.
   */
  public static class TimesMapper extends TimesSquaredMapper<IntWritable> {
    @Override
    public void map(IntWritable rowNum,
                    VectorWritable v,
                    OutputCollector<NullWritable,VectorWritable> out,
                    Reporter rep) {
      // capture the collector for the parent's close()-time emission
      this.out = out;
      double d = scale(v);
      if(d != 0.0) {
        // sparse accumulation: zero results are simply left unset
        outputVector.setQuick(rowNum.get(), d);
      }
    }
  }
+
public static class VectorSummingReducer extends MapReduceBase
implements Reducer<NullWritable,VectorWritable,NullWritable,VectorWritable> {
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java?rev=919702&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/math/hadoop/TransposeJob.java Sat Mar 6 05:08:36 2010
@@ -0,0 +1,131 @@
+package org.apache.mahout.math.hadoop;
+
+import org.apache.commons.cli2.Option;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapred.FileInputFormat;
+import org.apache.hadoop.mapred.FileOutputFormat;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.MapReduceBase;
+import org.apache.hadoop.mapred.Mapper;
+import org.apache.hadoop.mapred.OutputCollector;
+import org.apache.hadoop.mapred.Reducer;
+import org.apache.hadoop.mapred.Reporter;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
+import org.apache.hadoop.mapred.SequenceFileOutputFormat;
+import org.apache.hadoop.util.ToolRunner;
+import org.apache.mahout.cf.taste.hadoop.AbstractJob;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * TODO: rewrite to use helpful combiner.
+ */
+public class TransposeJob extends AbstractJob {
+ public static final String NUM_ROWS_KEY = "SparseRowMatrix.numRows";
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new TransposeJob(), args);
+ }
+
+ @Override
+ public int run(String[] strings) throws Exception {
+ Option numRowsOpt = buildOption("numRows",
+ "nr",
+ "Number of rows of the input matrix");
+ Option numColsOpt = buildOption("numCols",
+ "nc",
+ "Number of columns of the input matrix");
+ Map<String,String> parsedArgs = parseArguments(strings, numRowsOpt, numColsOpt);
+
+ String inputPathString = parsedArgs.get("--input");
+ String outputTmpPathString = parsedArgs.get("--tempDir");
+ int numRows = Integer.parseInt(parsedArgs.get("--numRows"));
+ int numCols = Integer.parseInt(parsedArgs.get("--numCols"));
+
+ DistributedRowMatrix matrix = new DistributedRowMatrix(inputPathString, outputTmpPathString, numRows, numCols);
+ matrix.configure(new JobConf(getConf()));
+ matrix.transpose();
+
+ return 0;
+ }
+
+ public static JobConf buildTransposeJobConf(Path matrixInputPath,
+ Path matrixOutputPath,
+ int numInputRows) throws IOException {
+ JobConf conf = new JobConf(TransposeJob.class);
+ conf.setJobName("TransposeJob: " + matrixInputPath + " transpose -> " + matrixOutputPath);
+ FileSystem fs = FileSystem.get(conf);
+ matrixInputPath = fs.makeQualified(matrixInputPath);
+ matrixOutputPath = fs.makeQualified(matrixOutputPath);
+ conf.setInt(NUM_ROWS_KEY, numInputRows);
+
+ FileInputFormat.addInputPath(conf, matrixInputPath);
+ conf.setInputFormat(SequenceFileInputFormat.class);
+ FileOutputFormat.setOutputPath(conf, matrixOutputPath);
+ conf.setMapperClass(TransposeMapper.class);
+ conf.setReducerClass(TransposeReducer.class);
+ conf.setMapOutputKeyClass(IntWritable.class);
+ conf.setMapOutputValueClass(DistributedRowMatrix.MatrixEntryWritable.class);
+ conf.setOutputFormat(SequenceFileOutputFormat.class);
+ conf.setOutputKeyClass(IntWritable.class);
+ conf.setOutputValueClass(VectorWritable.class);
+ return conf;
+ }
+
+ public static class TransposeMapper extends MapReduceBase
+ implements Mapper<IntWritable,VectorWritable,IntWritable,DistributedRowMatrix.MatrixEntryWritable> {
+
+ @Override
+ public void map(IntWritable r,
+ VectorWritable v,
+ OutputCollector<IntWritable, DistributedRowMatrix.MatrixEntryWritable> out,
+ Reporter reporter) throws IOException {
+ DistributedRowMatrix.MatrixEntryWritable entry = new DistributedRowMatrix.MatrixEntryWritable();
+ Iterator<Vector.Element> it = v.get().iterateNonZero();
+ int row = r.get();
+ entry.setCol(row);
+ entry.setRow(-1); // output "row" is captured in the key
+ while(it.hasNext()) {
+ Vector.Element e = it.next();
+ r.set(e.index());
+ entry.setVal(e.get());
+ out.collect(r, entry);
+ }
+ }
+ }
+
+ public static class TransposeReducer extends MapReduceBase
+ implements Reducer<IntWritable,DistributedRowMatrix.MatrixEntryWritable,IntWritable,VectorWritable> {
+
+ private JobConf conf;
+ private int newNumCols;
+
+ @Override
+ public void configure(JobConf conf) {
+ this.conf = conf;
+ newNumCols = conf.getInt(NUM_ROWS_KEY, Integer.MAX_VALUE);
+ }
+
+ @Override
+ public void reduce(IntWritable outRow,
+ Iterator<DistributedRowMatrix.MatrixEntryWritable> it,
+ OutputCollector<IntWritable, VectorWritable> out,
+ Reporter reporter) throws IOException {
+ RandomAccessSparseVector tmp = new RandomAccessSparseVector(newNumCols, 100);
+ while(it.hasNext()) {
+ DistributedRowMatrix.MatrixEntryWritable e = it.next();
+ tmp.setQuick(e.getCol(), e.getVal());
+ }
+ SequentialAccessSparseVector outVector = new SequentialAccessSparseVector(tmp);
+ out.collect(outRow, new VectorWritable(outVector));
+ }
+ }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java?rev=919702&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/math/hadoop/TestDistributedRowMatrix.java Sat Mar 6 05:08:36 2010
@@ -0,0 +1,167 @@
+package org.apache.mahout.math.hadoop;
+
+import junit.framework.TestCase;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.clustering.ClusteringTestUtils;
+import org.apache.mahout.clustering.canopy.TestCanopyCreation;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixSlice;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorIterable;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.decomposer.SolverTest;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+public class TestDistributedRowMatrix extends TestCase {
+
+ private static final String TESTDATA = "testdata";
+
+ public TestDistributedRowMatrix() {
+ super();
+ }
+
+ @Override
+ public void setUp() throws Exception {
+ File testData = new File(TESTDATA);
+ if (testData.exists()) {
+ TestCanopyCreation.rmr(TESTDATA);
+ }
+ testData.mkdir();
+ }
+
+ @Override
+ public void tearDown() throws Exception {
+ TestCanopyCreation.rmr(TESTDATA);
+ super.tearDown();
+ }
+
+ public static void assertEquals(double d1, double d2, double errorTolerance) {
+ assertTrue(Math.abs(d1-d2) < errorTolerance);
+ }
+
+ public static void assertEquals(VectorIterable m, VectorIterable mtt, double errorTolerance) {
+ Iterator<MatrixSlice> mIt = m.iterateAll();
+ Iterator<MatrixSlice> mttIt = mtt.iterateAll();
+ Map<Integer, Vector> mMap = new HashMap<Integer,Vector>();
+ Map<Integer, Vector> mttMap = new HashMap<Integer, Vector>();
+ while(mIt.hasNext() && mttIt.hasNext()) {
+ MatrixSlice ms = mIt.next();
+ mMap.put(ms.index(), ms.vector());
+ MatrixSlice mtts = mttIt.next();
+ mttMap.put(mtts.index(), mtts.vector());
+ }
+ for(Integer i : mMap.keySet()) {
+ if(mMap.get(i) == null || mttMap.get(i) == null) {
+ assertTrue(mMap.get(i) == null || mMap.get(i).norm(2) == 0);
+ assertTrue(mttMap.get(i) == null || mttMap.get(i).norm(2) == 0);
+ } else {
+ assertTrue(mMap.get(i).getDistanceSquared(mttMap.get(i)) < errorTolerance);
+ }
+ }
+ }
+
+ public void testTranspose() throws Exception {
+ DistributedRowMatrix m = randomDistributedMatrix(10, 9, 5, 4, 1.0, false);
+ DistributedRowMatrix mt = m.transpose();
+ mt.setOutputTempPathString(new Path(m.getOutputTempPath().getParent(), "/tmpOutTranspose").toString());
+ DistributedRowMatrix mtt = mt.transpose();
+ assertEquals(m, mtt, 1e-9);
+ }
+
+ public void testMatrixTimesVector() throws Exception {
+ Vector v = new RandomAccessSparseVector(50);
+ v.assign(1.0);
+ Matrix m = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
+ DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
+
+ Vector expected = m.times(v);
+ Vector actual = dm.times(v);
+ assertEquals(expected.getDistanceSquared(actual), 0.0, 1e-9);
+ }
+
+ public void testMatrixTimesSquaredVector() throws Exception {
+ Vector v = new RandomAccessSparseVector(50);
+ v.assign(1.0);
+ Matrix m = SolverTest.randomSequentialAccessSparseMatrix(100, 90, 50, 20, 1.0);
+ DistributedRowMatrix dm = randomDistributedMatrix(100, 90, 50, 20, 1.0, false);
+
+ Vector expected = m.timesSquared(v);
+ Vector actual = dm.timesSquared(v);
+ assertEquals(expected.getDistanceSquared(actual), 0.0, 1e-9);
+ }
+
+ public void testMatrixTimesMatrix() throws Exception {
+ Matrix inputA = SolverTest.randomSequentialAccessSparseMatrix(20, 19, 15, 5, 10.0);
+ Matrix inputB = SolverTest.randomSequentialAccessSparseMatrix(20, 13, 25, 10, 5.0);
+ Matrix expected = inputA.transpose().times(inputB);
+
+ DistributedRowMatrix distA = randomDistributedMatrix(20, 19, 15, 5, 10.0, false, "/distA");
+ DistributedRowMatrix distB = randomDistributedMatrix(20, 13, 25, 10, 5.0, false, "/distB");
+ DistributedRowMatrix product = distA.times(distB);
+
+ assertEquals(expected, product, 1e-9);
+ }
+
+ public static DistributedRowMatrix randomDistributedMatrix(int numRows,
+ int nonNullRows,
+ int numCols,
+ int entriesPerRow,
+ double entryMean,
+ boolean isSymmetric) throws Exception {
+ return randomDistributedMatrix(numRows, nonNullRows, numCols, entriesPerRow, entryMean, isSymmetric, "");
+ }
+
+ public static DistributedRowMatrix randomDistributedMatrix(int numRows,
+ int nonNullRows,
+ int numCols,
+ int entriesPerRow,
+ double entryMean,
+ boolean isSymmetric,
+ String baseTmpDir) throws IOException {
+ baseTmpDir = TESTDATA + baseTmpDir;
+ Matrix c = SolverTest.randomSequentialAccessSparseMatrix(numRows, nonNullRows, numCols, entriesPerRow, entryMean);
+ if(isSymmetric) {
+ c = c.times(c.transpose());
+ }
+ final Matrix m = c;
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(conf);
+
+ ClusteringTestUtils.writePointsToFile(new Iterable<VectorWritable>() {
+ @Override
+ public Iterator<VectorWritable> iterator() {
+ final Iterator<MatrixSlice> it = m.iterator();
+ final VectorWritable v = new VectorWritable();
+ return new Iterator<VectorWritable>() {
+ @Override
+ public boolean hasNext() { return it.hasNext(); }
+ @Override
+ public VectorWritable next() {
+ MatrixSlice slice = it.next();
+ v.set(slice.vector());
+ return v;
+ }
+ @Override
+ public void remove() { it.remove(); }
+ };
+ }
+ }, true, baseTmpDir + "/distMatrix/part-00000", fs, conf);
+
+ DistributedRowMatrix distMatrix = new DistributedRowMatrix(baseTmpDir + "/distMatrix",
+ baseTmpDir + "/tmpOut",
+ m.numRows(),
+ m.numCols());
+ distMatrix.configure(new JobConf(conf));
+
+ return distMatrix;
+ }
+}