You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2009/06/24 20:41:38 UTC
svn commit: r788116 - in /lucene/mahout/trunk:
core/src/main/java/org/apache/mahout/clustering/dirichlet/
core/src/main/java/org/apache/mahout/clustering/dirichlet/models/
core/src/test/java/org/apache/mahout/clustering/
core/src/test/java/org/apache/m...
Author: jeastman
Date: Wed Jun 24 18:41:37 2009
New Revision: 788116
URL: http://svn.apache.org/viewvc?rev=788116&view=rev
Log:
MAHOUT-137
- modified Dirichlet clustering to use Writable Vectors
- modified to use Writable DirichletClusters and Models
- removed 2d assertion in NormalModel and removed NormalScModel from examples
- updated unit tests; all run
Removed:
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java Wed Jun 24 18:41:37 2009
@@ -16,8 +16,12 @@
*/
package org.apache.mahout.clustering.dirichlet;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
import java.lang.reflect.Type;
+import org.apache.hadoop.io.Writable;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.matrix.JsonVectorAdapter;
import org.apache.mahout.matrix.Vector;
@@ -26,7 +30,20 @@
import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
-public class DirichletCluster<Observation> {
+public class DirichletCluster<Observation> implements Writable {
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.totalCount = in.readDouble();
+ this.model = readModel(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.writeDouble(totalCount);
+ writeModel(out, model);
+ }
public Model<Observation> model; // the model for this iteration
@@ -58,6 +75,44 @@
return gson.toJson(this, typeOfModel);
}
+ /**
+ * Reads a typed Model instance from the input stream
+ *
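+ * @param in the DataInput supplying the model class name followed by the model's fields
+ * @return the Model instantiated from that class name and populated via readFields
+ * @throws IOException if reading from the input fails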
+ */
+ @SuppressWarnings("unchecked")
+ public static Model readModel(DataInput in) throws IOException {
+ String modelClassName = in.readUTF();
+ Model model;
+ try {
+ model = Class.forName(modelClassName).asSubclass(Model.class)
+ .newInstance();
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException(e);
+ } catch (InstantiationException e) {
+ throw new RuntimeException(e);
+ }
+ model.readFields(in);
+ return model;
+ }
+
+ /**
+ * Writes a typed Model instance to the output stream
+ *
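+ * @param out the DataOutput that receives the model class name followed by the model's fields
+ * @param model the Model to write
+ * @throws IOException if writing to the output fails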
+ */
+ @SuppressWarnings("unchecked")
+ public static void writeModel(DataOutput out, Model model) throws IOException {
+ out.writeUTF(model.getClass().getName());
+ model.write(out);
+ }
+
public static DirichletCluster<Vector> fromFormatString(String formatString) {
GsonBuilder builder = new GsonBuilder();
builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Wed Jun 24 18:41:37 2009
@@ -27,9 +27,11 @@
import org.apache.hadoop.mapred.FileOutputFormat;
import org.apache.hadoop.mapred.JobClient;
import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
import org.apache.hadoop.mapred.SequenceFileOutputFormat;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.matrix.SparseVector;
import org.apache.mahout.matrix.Vector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -109,9 +111,8 @@
for (int i = 0; i < numModels; i++) {
Path path = new Path(stateIn + "/part-" + i);
SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
- Text.class, Text.class);
- String stateString = state.clusters.get(i).asFormatString();
- writer.append(new Text(Integer.toString(i)), new Text(stateString));
+ Text.class, DirichletCluster.class);
+ writer.append(new Text(Integer.toString(i)), state.clusters.get(i));
writer.close();
}
}
@@ -146,7 +147,9 @@
JobConf conf = new JobConf(DirichletDriver.class);
conf.setOutputKeyClass(Text.class);
- conf.setOutputValueClass(Text.class);
+ conf.setOutputValueClass(DirichletCluster.class);
+ conf.setMapOutputKeyClass(Text.class);
+ conf.setMapOutputValueClass(SparseVector.class);
FileInputFormat.setInputPaths(conf, new Path(input));
Path outPath = new Path(stateOut);
@@ -155,6 +158,7 @@
conf.setMapperClass(DirichletMapper.class);
conf.setReducerClass(DirichletReducer.class);
conf.setNumReduceTasks(numReducers);
+ conf.setInputFormat(SequenceFileInputFormat.class);
conf.setOutputFormat(SequenceFileOutputFormat.class);
conf.set(STATE_IN_KEY, stateIn);
conf.set(MODEL_FACTORY_KEY, modelFactory);
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java Wed Jun 24 18:41:37 2009
@@ -30,26 +30,24 @@
import org.apache.hadoop.mapred.Mapper;
import org.apache.hadoop.mapred.OutputCollector;
import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.TimesFunction;
import org.apache.mahout.matrix.Vector;
public class DirichletMapper extends MapReduceBase implements
- Mapper<WritableComparable<?>, Text, Text, Text> {
+ Mapper<WritableComparable<?>, Vector, Text, Vector> {
private DirichletState<Vector> state;
@Override
- public void map(WritableComparable<?> key, Text values,
- OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
- Vector v = AbstractVector.decodeVector(values.toString());
+ public void map(WritableComparable<?> key, Vector v,
+ OutputCollector<Text, Vector> output, Reporter reporter) throws IOException {
// compute a normalized vector of probabilities that v is described by each model
Vector pi = normalizedProbabilities(state, v);
// then pick one model by sampling a Multinomial distribution based upon them
// see: http://en.wikipedia.org/wiki/Multinomial_distribution
int k = UncommonDistributions.rMultinom(pi);
- output.collect(new Text(String.valueOf(k)), values);
+ output.collect(new Text(String.valueOf(k)), v);
}
public void configure(DirichletState<Vector> state) {
@@ -62,6 +60,7 @@
state = getDirichletState(job);
}
+ @SuppressWarnings("unchecked")
public static DirichletState<Vector> getDirichletState(JobConf job) {
String statePath = job.get(DirichletDriver.STATE_IN_KEY);
String modelFactory = job.get(DirichletDriver.MODEL_FACTORY_KEY);
@@ -79,13 +78,11 @@
job);
try {
Text key = new Text();
- Text value = new Text();
- while (reader.next(key, value)) {
+ DirichletCluster cluster = new DirichletCluster();
+ while (reader.next(key, cluster)) {
int index = Integer.parseInt(key.toString());
- String formatString = value.toString();
- DirichletCluster<Vector> cluster = DirichletCluster
- .fromFormatString(formatString);
state.clusters.set(index, cluster);
+ cluster = new DirichletCluster();
}
} finally {
reader.close();
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java Wed Jun 24 18:41:37 2009
@@ -27,29 +27,29 @@
import org.apache.hadoop.mapred.Reducer;
import org.apache.hadoop.mapred.Reporter;
import org.apache.mahout.clustering.dirichlet.models.Model;
-import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.Vector;
public class DirichletReducer extends MapReduceBase implements
- Reducer<Text, Text, Text, Text> {
+ Reducer<Text, Vector, Text, DirichletCluster<Vector>> {
DirichletState<Vector> state;
public Model<Vector>[] newModels;
@Override
- public void reduce(Text key, Iterator<Text> values,
- OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+ public void reduce(Text key, Iterator<Vector> values,
+ OutputCollector<Text, DirichletCluster<Vector>> output, Reporter reporter)
+ throws IOException {
int k = Integer.parseInt(key.toString());
Model<Vector> model = newModels[k];
while (values.hasNext()) {
- Vector v = AbstractVector.decodeVector(values.next().toString());
+ Vector v = values.next();
model.observe(v);
}
model.computeParameters();
DirichletCluster<Vector> cluster = state.clusters.get(k);
cluster.setModel(model);
- output.collect(key, new Text(cluster.asFormatString()));
+ output.collect(key, cluster);
}
public void configure(DirichletState<Vector> state) {
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Wed Jun 24 18:41:37 2009
@@ -16,6 +16,11 @@
*/
package org.apache.mahout.clustering.dirichlet.models;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.SquareRootFunction;
import org.apache.mahout.matrix.Vector;
@@ -50,6 +55,7 @@
/**
* Return an instance with the same parameters
+ *
* @return an AsymmetricSampledNormalModel
*/
AsymmetricSampledNormalModel sample() {
@@ -60,7 +66,7 @@
public void observe(Vector x) {
s0++;
if (s1 == null)
- s1 = x.like();
+ s1 = x.clone();
else
s1 = s1.plus(x);
if (s2 == null)
@@ -86,10 +92,10 @@
/**
* Calculate a pdf using the supplied sample and sd
*
- * @param x a Vector sample
- * @param sd a double std deviation
- * @return
- */
+ * @param x a Vector sample
+ * @param sd a double std deviation
+ * @return
+ */
private double pdf(Vector x, double sd) {
assert x.getNumNondefaultElements() == 2;
double sd2 = sd * sd;
@@ -104,8 +110,8 @@
assert x.getNumNondefaultElements() == 2;
double pdf0 = pdf(x, sd.get(0));
double pdf1 = pdf(x, sd.get(1));
- //if (pdf0 < 0 || pdf0 > 1 || pdf1 < 0 || pdf1 > 1)
- // System.out.print("");
+ // if (pdf0 < 0 || pdf0 > 1 || pdf1 < 0 || pdf1 > 1)
+ // System.out.print("");
return pdf0 * pdf1;
}
@@ -128,4 +134,22 @@
buf.append("]}");
return buf.toString();
}
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.mean = AbstractVector.readVector(in);
+ this.sd = AbstractVector.readVector(in);
+ this.s0 = in.readInt();
+ this.s1 = AbstractVector.readVector(in);
+ this.s2 = AbstractVector.readVector(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ AbstractVector.writeVector(out, mean);
+ AbstractVector.writeVector(out, sd);
+ out.writeInt(s0);
+ AbstractVector.writeVector(out, s1);
+ AbstractVector.writeVector(out, s2);
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java Wed Jun 24 18:41:37 2009
@@ -1,5 +1,7 @@
package org.apache.mahout.clustering.dirichlet.models;
+import org.apache.hadoop.io.Writable;
+
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
@@ -21,7 +23,7 @@
* A model is a probability distribution over observed data points and allows
* the probability of any data point to be computed.
*/
-public interface Model<Observation> {
+public interface Model<Observation> extends Writable {
/**
* Observe the given observation, retaining information about it
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Wed Jun 24 18:41:37 2009
@@ -16,6 +16,11 @@
*/
package org.apache.mahout.clustering.dirichlet.models;
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.matrix.AbstractVector;
import org.apache.mahout.matrix.SquareRootFunction;
import org.apache.mahout.matrix.Vector;
@@ -47,10 +52,12 @@
}
/**
- * Return an instance with the same parameters
+ * TODO: Return a proper sample from the posterior. For now, return an
+ * instance with the same parameters
+ *
* @return a NormalModel
*/
- NormalModel sample() {
+ public NormalModel sample() {
return new NormalModel(mean, sd);
}
@@ -58,7 +65,7 @@
public void observe(Vector x) {
s0++;
if (s1 == null)
- s1 = x;
+ s1 = x.clone();
else
s1 = s1.plus(x);
if (s2 == null)
@@ -83,7 +90,6 @@
@Override
public double pdf(Vector x) {
- assert x.getNumNondefaultElements() == 2;
double sd2 = sd * sd;
double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
double ex = Math.exp(exp);
@@ -105,4 +111,22 @@
buf.append("] sd=").append(String.format("%.2f", sd)).append('}');
return buf.toString();
}
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ this.mean = AbstractVector.readVector(in);
+ this.sd = in.readDouble();
+ this.s0 = in.readInt();
+ this.s1 = AbstractVector.readVector(in);
+ this.s2 = AbstractVector.readVector(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ AbstractVector.writeVector(out, mean);
+ out.writeDouble(sd);
+ out.writeInt(s0);
+ AbstractVector.writeVector(out, s1);
+ AbstractVector.writeVector(out, s2);
+ }
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java Wed Jun 24 18:41:37 2009
@@ -44,7 +44,7 @@
* @return a SampledNormalModel
*/
@Override
- NormalModel sample() {
+ public NormalModel sample() {
return new SampledNormalModel(mean, sd);
}
}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java Wed Jun 24 18:41:37 2009
@@ -9,6 +9,7 @@
import org.apache.hadoop.io.LongWritable;
import java.util.List;
+import java.io.File;
import java.io.IOException;
/**
@@ -28,4 +29,16 @@
}
writer.close();
}
+
+ public static void rmr(String path) throws Exception {
+ File f = new File(path);
+ if (f.exists()) {
+ if (f.isDirectory()) {
+ String[] contents = f.list();
+ for (int i = 0; i < contents.length; i++)
+ rmr(f.toString() + File.separator + contents[i]);
+ }
+ f.delete();
+ }
+ }
}
Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java Wed Jun 24 18:41:37 2009
@@ -16,20 +16,19 @@
*/
package org.apache.mahout.clustering.dirichlet;
-import java.io.BufferedWriter;
import java.io.File;
-import java.io.FileOutputStream;
import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import junit.framework.TestCase;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
import org.apache.mahout.clustering.dirichlet.models.Model;
@@ -40,6 +39,7 @@
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.JsonVectorAdapter;
+import org.apache.mahout.matrix.SparseVector;
import org.apache.mahout.matrix.Vector;
import org.apache.mahout.utils.DummyOutputCollector;
@@ -50,8 +50,40 @@
private List<Vector> sampleData = new ArrayList<Vector>();
+ FileSystem fs;
+
+ Configuration conf;
+
/**
* Generate random samples and add them to the sampleData
+ *
+ * @param num int number of samples to generate
+ * @param mx double x-value of the sample mean
+ * @param my double y-value of the sample mean
+ * @param sdx double x-standard deviation of the samples
+ * @param sdy double y-standard deviation of the samples
+ */
+ private void generateSamples(int num, double mx, double my, double sdx,
+ double sdy) {
+ System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+ + "] sd=[" + sdx + ", " + sdy + "]");
+ for (int i = 0; i < num; i++) {
+ addSample(new double[] {
+ UncommonDistributions.rNorm(mx, sdx),
+ UncommonDistributions.rNorm(my, sdy) });
+ }
+ }
+
+ private void addSample(double[] values) {
+ Vector v = new SparseVector(2);
+ for (int j = 0; j < values.length; j++)
+ v.setQuick(j, values[j]);
+ sampleData.add(v);
+ }
+
+ /**
+ * Generate random samples and add them to the sampleData
+ *
* @param num int number of samples to generate
* @param mx double x-value of the sample mean
* @param my double y-value of the sample mean
@@ -60,36 +92,27 @@
private void generateSamples(int num, double mx, double my, double sd) {
System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+ "] sd=" + sd);
- for (int i = 0; i < num; i++)
- sampleData.add(new DenseVector(new double[] {
- UncommonDistributions.rNorm(mx, sd),
- UncommonDistributions.rNorm(my, sd) }));
- }
-
- public static void writePointsToFileWithPayload(List<Vector> points,
- String fileName, String payload) throws IOException {
- BufferedWriter output = new BufferedWriter(new OutputStreamWriter(
- new FileOutputStream(fileName), Charset.forName("UTF-8")));
- for (Vector point : points) {
- output.write(point.asFormatString());
- output.write(payload);
- output.write('\n');
+ for (int i = 0; i < num; i++) {
+ addSample(new double[] { UncommonDistributions.rNorm(mx, sd),
+ UncommonDistributions.rNorm(my, sd) });
}
- output.flush();
- output.close();
}
@Override
protected void setUp() throws Exception {
super.setUp();
UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+ ClusteringTestUtils.rmr("output");
+ ClusteringTestUtils.rmr("input");
+ conf = new Configuration();
+ fs = FileSystem.get(conf);
File f = new File("input");
- if (!f.exists())
- f.mkdir();
+ f.mkdir();
}
/**
- * Test the basic Mapper
+ * Test the basic Mapper
+ *
* @throws Exception
*/
public void testMapper() throws Exception {
@@ -99,16 +122,17 @@
DirichletMapper mapper = new DirichletMapper();
mapper.configure(state);
- DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
+ DummyOutputCollector<Text, Vector> collector = new DummyOutputCollector<Text, Vector>();
for (Vector v : sampleData)
- mapper.map(null, new Text(v.asFormatString()), collector, null);
- Map<String, List<Text>> data = collector.getData();
+ mapper.map(null, v, collector, null);
+ Map<String, List<Vector>> data = collector.getData();
// this seed happens to produce three partitions, but they work
assertEquals("output size", 3, data.size());
}
/**
- * Test the basic Reducer
+ * Test the basic Reducer
+ *
* @throws Exception
*/
public void testReducer() throws Exception {
@@ -121,16 +145,16 @@
DirichletMapper mapper = new DirichletMapper();
mapper.configure(state);
- DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
+ DummyOutputCollector<Text, Vector> mapCollector = new DummyOutputCollector<Text, Vector>();
for (Vector v : sampleData)
- mapper.map(null, new Text(v.asFormatString()), mapCollector, null);
- Map<String, List<Text>> data = mapCollector.getData();
+ mapper.map(null, v, mapCollector, null);
+ Map<String, List<Vector>> data = mapCollector.getData();
// this seed happens to produce seven partitions, but they work
assertEquals("output size", 7, data.size());
DirichletReducer reducer = new DirichletReducer();
reducer.configure(state);
- DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
+ DummyOutputCollector<Text, DirichletCluster<Vector>> reduceCollector = new DummyOutputCollector<Text, DirichletCluster<Vector>>();
for (String key : mapCollector.getKeys())
reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
reduceCollector, null);
@@ -155,7 +179,8 @@
}
/**
- * Test the Mapper and Reducer in an iteration loop
+ * Test the Mapper and Reducer in an iteration loop
+ *
* @throws Exception
*/
public void testMRIterations() throws Exception {
@@ -171,13 +196,13 @@
for (int iteration = 0; iteration < 10; iteration++) {
DirichletMapper mapper = new DirichletMapper();
mapper.configure(state);
- DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
+ DummyOutputCollector<Text, Vector> mapCollector = new DummyOutputCollector<Text, Vector>();
for (Vector v : sampleData)
- mapper.map(null, new Text(v.asFormatString()), mapCollector, null);
+ mapper.map(null, v, mapCollector, null);
DirichletReducer reducer = new DirichletReducer();
reducer.configure(state);
- DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
+ DummyOutputCollector<Text, DirichletCluster<Vector>> reduceCollector = new DummyOutputCollector<Text, DirichletCluster<Vector>>();
for (String key : mapCollector.getKeys())
reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
reduceCollector, null);
@@ -322,13 +347,13 @@
assertEquals("modelFactory", state.modelFactory.getClass().getName(),
state2.modelFactory.getClass().getName());
assertEquals("clusters", state.clusters.size(), state2.clusters.size());
- assertEquals("mixture", state.mixture.size(), state2.mixture
- .size());
+ assertEquals("mixture", state.mixture.size(), state2.mixture.size());
assertEquals("dirichlet", state.offset, state2.offset);
}
/**
- * Test the Mapper and Reducer using the Driver
+ * Test the Mapper and Reducer using the Driver
+ *
* @throws Exception
*/
public void testDriverMRIterations() throws Exception {
@@ -339,16 +364,21 @@
generateSamples(100, 2, 0, 0.2);
generateSamples(100, 0, 2, 0.3);
generateSamples(100, 2, 2, 1);
- writePointsToFileWithPayload(sampleData, "input/data.txt", "");
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data.txt", fs,
+ conf);
// Now run the driver
- DirichletDriver.runJob("input", "output",
- "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
- 10, 1.0, 1);
+ DirichletDriver
+ .runJob(
+ "input",
+ "output",
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution",
+ 20, 10, 1.0, 1);
// and inspect results
List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
JobConf conf = new JobConf(KMeansDriver.class);
- conf.set(DirichletDriver.MODEL_FACTORY_KEY,
- "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+ conf
+ .set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
for (int i = 0; i < 11; i++) {
@@ -358,8 +388,8 @@
printResults(clusters, 0);
}
- private static void printResults(List<List<DirichletCluster<Vector>>> clusters,
- int significant) {
+ private static void printResults(
+ List<List<DirichletCluster<Vector>>> clusters, int significant) {
int row = 0;
for (List<DirichletCluster<Vector>> r : clusters) {
System.out.print("sample[" + row++ + "]= ");
@@ -377,33 +407,28 @@
}
/**
- * Test the Mapper and Reducer using the Driver
+ * Test the Mapper and Reducer using the Driver
+ *
* @throws Exception
*/
public void testDriverMnRIterations() throws Exception {
File f = new File("input");
for (File g : f.listFiles())
g.delete();
- generateSamples(500, 0, 0, 0.5);
- writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
- sampleData = new ArrayList<Vector>();
- generateSamples(500, 2, 0, 0.2);
- writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
- sampleData = new ArrayList<Vector>();
- generateSamples(500, 0, 2, 0.3);
- writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
- sampleData = new ArrayList<Vector>();
- generateSamples(500, 2, 2, 1);
- writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+ generate4Datasets();
// Now run the driver
- DirichletDriver.runJob("input", "output",
- "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
- 15, 1.0, 1);
+ DirichletDriver
+ .runJob(
+ "input",
+ "output",
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution",
+ 20, 15, 1.0, 1);
// and inspect results
List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
JobConf conf = new JobConf(KMeansDriver.class);
- conf.set(DirichletDriver.MODEL_FACTORY_KEY,
- "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+ conf
+ .set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
for (int i = 0; i < 11; i++) {
@@ -413,34 +438,47 @@
printResults(clusters, 0);
}
- /**
- * Test the Mapper and Reducer using the Driver
- * @throws Exception
- */
- public void testDriverMnRnIterations() throws Exception {
- File f = new File("input");
- for (File g : f.listFiles())
- g.delete();
+ private void generate4Datasets() throws IOException {
generateSamples(500, 0, 0, 0.5);
- writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data1.txt", fs,
+ conf);
sampleData = new ArrayList<Vector>();
generateSamples(500, 2, 0, 0.2);
- writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data2.txt", fs,
+ conf);
sampleData = new ArrayList<Vector>();
generateSamples(500, 0, 2, 0.3);
- writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data3.txt", fs,
+ conf);
sampleData = new ArrayList<Vector>();
generateSamples(500, 2, 2, 1);
- writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data4.txt", fs,
+ conf);
+ }
+
+ /**
+ * Test the Mapper and Reducer using the Driver
+ *
+ * @throws Exception
+ */
+ public void testDriverMnRnIterations() throws Exception {
+ File f = new File("input");
+ for (File g : f.listFiles())
+ g.delete();
+ generate4Datasets();
// Now run the driver
- DirichletDriver.runJob("input", "output",
- "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
- 15, 1.0, 2);
+ DirichletDriver
+ .runJob(
+ "input",
+ "output",
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution",
+ 20, 15, 1.0, 2);
// and inspect results
List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
JobConf conf = new JobConf(KMeansDriver.class);
- conf.set(DirichletDriver.MODEL_FACTORY_KEY,
- "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+ conf
+ .set(DirichletDriver.MODEL_FACTORY_KEY,
+ "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
for (int i = 0; i < 11; i++) {
@@ -451,25 +489,8 @@
}
/**
- * Generate random samples and add them to the sampleData
- * @param num int number of samples to generate
- * @param mx double x-value of the sample mean
- * @param my double y-value of the sample mean
- * @param sdx double x-standard deviation of the samples
- * @param sdy double y-standard deviation of the samples
- */
- private void generateSamples(int num, double mx, double my, double sdx,
- double sdy) {
- System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
- + "] sd=[" + sdx + ", " + sdy + "]");
- for (int i = 0; i < num; i++)
- sampleData.add(new DenseVector(new double[] {
- UncommonDistributions.rNorm(mx, sdx),
- UncommonDistributions.rNorm(my, sdy) }));
- }
-
- /**
- * Test the Mapper and Reducer using the Driver
+ * Test the Mapper and Reducer using the Driver
+ *
* @throws Exception
*/
public void testDriverMnRnIterationsAsymmetric() throws Exception {
@@ -477,16 +498,20 @@
for (File g : f.listFiles())
g.delete();
generateSamples(500, 0, 0, 0.5, 1.0);
- writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data1.txt", fs,
+ conf);
sampleData = new ArrayList<Vector>();
- generateSamples(500, 2, 0, 0.2, 0.1);
- writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
+ generateSamples(500, 2, 0, 0.2);
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data2.txt", fs,
+ conf);
sampleData = new ArrayList<Vector>();
- generateSamples(500, 0, 2, 0.3, 0.5);
- writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
+ generateSamples(500, 0, 2, 0.3);
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data3.txt", fs,
+ conf);
sampleData = new ArrayList<Vector>();
- generateSamples(500, 2, 2, 1, 0.5);
- writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+ generateSamples(500, 2, 2, 1);
+ ClusteringTestUtils.writePointsToFile(sampleData, "input/data4.txt", fs,
+ conf);
// Now run the driver
DirichletDriver
.runJob(
@@ -498,7 +523,8 @@
List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
JobConf conf = new JobConf(KMeansDriver.class);
conf
- .set(DirichletDriver.MODEL_FACTORY_KEY,
+ .set(
+ DirichletDriver.MODEL_FACTORY_KEY,
"org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution");
conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java Wed Jun 24 18:41:37 2009
@@ -20,6 +20,7 @@
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.clustering.dirichlet.models.Model;
import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.NormalModel;
import org.apache.mahout.matrix.DenseVector;
import org.apache.mahout.matrix.Vector;
@@ -31,21 +32,21 @@
@Override
public Model<Vector>[] sampleFromPrior(int howMany) {
- Model<Vector>[] result = new NormalScModel[howMany];
+ Model<Vector>[] result = new NormalModel[howMany];
for (int i = 0; i < howMany; i++) {
DenseVector mean = new DenseVector(60);
for (int j = 0; j < 60; j++)
mean.set(j, UncommonDistributions.rNorm(30, 0.5));
- result[i] = new NormalScModel(mean, 1);
+ result[i] = new NormalModel(mean, 1);
}
return result;
}
@Override
public Model<Vector>[] sampleFromPosterior(Model<Vector>[] posterior) {
- Model<Vector>[] result = new NormalScModel[posterior.length];
+ Model<Vector>[] result = new NormalModel[posterior.length];
for (int i = 0; i < posterior.length; i++) {
- NormalScModel m = (NormalScModel) posterior[i];
+ NormalModel m = (NormalModel) posterior[i];
result[i] = m.sample();
}
return result;