You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2009/06/24 20:41:38 UTC

svn commit: r788116 - in /lucene/mahout/trunk: core/src/main/java/org/apache/mahout/clustering/dirichlet/ core/src/main/java/org/apache/mahout/clustering/dirichlet/models/ core/src/test/java/org/apache/mahout/clustering/ core/src/test/java/org/apache/m...

Author: jeastman
Date: Wed Jun 24 18:41:37 2009
New Revision: 788116

URL: http://svn.apache.org/viewvc?rev=788116&view=rev
Log:
MAHOUT-137
- modified Dirichlet clustering to use Writable Vectors
- modified to use Writable DirichletClusters and Models
- removed 2d assertion in NormalModel and removed NormalScModel from examples
- updated unit tests; all run

Removed:
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModel.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
    lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java Wed Jun 24 18:41:37 2009
@@ -16,8 +16,12 @@
  */
 package org.apache.mahout.clustering.dirichlet;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
 import java.lang.reflect.Type;
 
+import org.apache.hadoop.io.Writable;
 import org.apache.mahout.clustering.dirichlet.models.Model;
 import org.apache.mahout.matrix.JsonVectorAdapter;
 import org.apache.mahout.matrix.Vector;
@@ -26,7 +30,20 @@
 import com.google.gson.GsonBuilder;
 import com.google.gson.reflect.TypeToken;
 
-public class DirichletCluster<Observation> {
+public class DirichletCluster<Observation> implements Writable {
+
+  @SuppressWarnings("unchecked")
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.totalCount = in.readDouble();
+    this.model = readModel(in);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeDouble(totalCount);
+    writeModel(out, model);
+  }
 
   public Model<Observation> model; // the model for this iteration
 
@@ -58,6 +75,44 @@
     return gson.toJson(this, typeOfModel);
   }
 
+  /**
+   * Reads a typed Model instance from the input stream
+   * 
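+   * @param in the DataInput to read the model class name and model fields from
+   * @return the Model instance of the named class, populated via readFields
+   * @throws IOException if reading from the input fails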
+   */
+  @SuppressWarnings("unchecked")
+  public static Model readModel(DataInput in) throws IOException {
+    String modelClassName = in.readUTF();
+    Model model;
+    try {
+      model = Class.forName(modelClassName).asSubclass(Model.class)
+          .newInstance();
+    } catch (ClassNotFoundException e) {
+      throw new RuntimeException(e);
+    } catch (IllegalAccessException e) {
+      throw new RuntimeException(e);
+    } catch (InstantiationException e) {
+      throw new RuntimeException(e);
+    }
+    model.readFields(in);
+    return model;
+  }
+
+  /**
+   * Writes a typed Model instance to the output stream
+   * 
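+   * @param out the DataOutput to write to
+   * @param model the Model whose class name and fields are written
+   * @throws IOException if writing to the output fails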
+   */
+  @SuppressWarnings("unchecked")
+  public static void writeModel(DataOutput out, Model model) throws IOException {
+    out.writeUTF(model.getClass().getName());
+    model.write(out);
+  }
+
   public static DirichletCluster<Vector> fromFormatString(String formatString) {
     GsonBuilder builder = new GsonBuilder();
     builder.registerTypeAdapter(Vector.class, new JsonVectorAdapter());

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Wed Jun 24 18:41:37 2009
@@ -27,9 +27,11 @@
 import org.apache.hadoop.mapred.FileOutputFormat;
 import org.apache.hadoop.mapred.JobClient;
 import org.apache.hadoop.mapred.JobConf;
+import org.apache.hadoop.mapred.SequenceFileInputFormat;
 import org.apache.hadoop.mapred.SequenceFileOutputFormat;
 import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
 import org.apache.mahout.clustering.kmeans.KMeansDriver;
+import org.apache.mahout.matrix.SparseVector;
 import org.apache.mahout.matrix.Vector;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -109,9 +111,8 @@
     for (int i = 0; i < numModels; i++) {
       Path path = new Path(stateIn + "/part-" + i);
       SequenceFile.Writer writer = new SequenceFile.Writer(fs, job, path,
-          Text.class, Text.class);
-      String stateString = state.clusters.get(i).asFormatString();
-      writer.append(new Text(Integer.toString(i)), new Text(stateString));
+          Text.class, DirichletCluster.class);
+      writer.append(new Text(Integer.toString(i)), state.clusters.get(i));
       writer.close();
     }
   }
@@ -146,7 +147,9 @@
     JobConf conf = new JobConf(DirichletDriver.class);
 
     conf.setOutputKeyClass(Text.class);
-    conf.setOutputValueClass(Text.class);
+    conf.setOutputValueClass(DirichletCluster.class);
+    conf.setMapOutputKeyClass(Text.class);
+    conf.setMapOutputValueClass(SparseVector.class);
 
     FileInputFormat.setInputPaths(conf, new Path(input));
     Path outPath = new Path(stateOut);
@@ -155,6 +158,7 @@
     conf.setMapperClass(DirichletMapper.class);
     conf.setReducerClass(DirichletReducer.class);
     conf.setNumReduceTasks(numReducers);
+    conf.setInputFormat(SequenceFileInputFormat.class);
     conf.setOutputFormat(SequenceFileOutputFormat.class);
     conf.set(STATE_IN_KEY, stateIn);
     conf.set(MODEL_FACTORY_KEY, modelFactory);

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java Wed Jun 24 18:41:37 2009
@@ -30,26 +30,24 @@
 import org.apache.hadoop.mapred.Mapper;
 import org.apache.hadoop.mapred.OutputCollector;
 import org.apache.hadoop.mapred.Reporter;
-import org.apache.mahout.matrix.AbstractVector;
 import org.apache.mahout.matrix.DenseVector;
 import org.apache.mahout.matrix.TimesFunction;
 import org.apache.mahout.matrix.Vector;
 
 public class DirichletMapper extends MapReduceBase implements
-    Mapper<WritableComparable<?>, Text, Text, Text> {
+    Mapper<WritableComparable<?>, Vector, Text, Vector> {
 
   private DirichletState<Vector> state;
 
   @Override
-  public void map(WritableComparable<?> key, Text values,
-      OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
-    Vector v = AbstractVector.decodeVector(values.toString());
+  public void map(WritableComparable<?> key, Vector v,
+      OutputCollector<Text, Vector> output, Reporter reporter) throws IOException {
     // compute a normalized vector of probabilities that v is described by each model
     Vector pi = normalizedProbabilities(state, v);
     // then pick one model by sampling a Multinomial distribution based upon them
     // see: http://en.wikipedia.org/wiki/Multinomial_distribution
     int k = UncommonDistributions.rMultinom(pi);
-    output.collect(new Text(String.valueOf(k)), values);
+    output.collect(new Text(String.valueOf(k)), v);
   }
 
   public void configure(DirichletState<Vector> state) {
@@ -62,6 +60,7 @@
     state = getDirichletState(job);
   }
 
+  @SuppressWarnings("unchecked")
   public static DirichletState<Vector> getDirichletState(JobConf job) {
     String statePath = job.get(DirichletDriver.STATE_IN_KEY);
     String modelFactory = job.get(DirichletDriver.MODEL_FACTORY_KEY);
@@ -79,13 +78,11 @@
             job);
         try {
           Text key = new Text();
-          Text value = new Text();
-          while (reader.next(key, value)) {
+          DirichletCluster cluster = new DirichletCluster();
+          while (reader.next(key, cluster)) {
             int index = Integer.parseInt(key.toString());
-            String formatString = value.toString();
-            DirichletCluster<Vector> cluster = DirichletCluster
-                .fromFormatString(formatString);
             state.clusters.set(index, cluster);
+            cluster = new DirichletCluster();
           }
         } finally {
           reader.close();

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletReducer.java Wed Jun 24 18:41:37 2009
@@ -27,29 +27,29 @@
 import org.apache.hadoop.mapred.Reducer;
 import org.apache.hadoop.mapred.Reporter;
 import org.apache.mahout.clustering.dirichlet.models.Model;
-import org.apache.mahout.matrix.AbstractVector;
 import org.apache.mahout.matrix.Vector;
 
 public class DirichletReducer extends MapReduceBase implements
-    Reducer<Text, Text, Text, Text> {
+    Reducer<Text, Vector, Text, DirichletCluster<Vector>> {
 
   DirichletState<Vector> state;
 
   public Model<Vector>[] newModels;
 
   @Override
-  public void reduce(Text key, Iterator<Text> values,
-      OutputCollector<Text, Text> output, Reporter reporter) throws IOException {
+  public void reduce(Text key, Iterator<Vector> values,
+      OutputCollector<Text, DirichletCluster<Vector>> output, Reporter reporter)
+      throws IOException {
     int k = Integer.parseInt(key.toString());
     Model<Vector> model = newModels[k];
     while (values.hasNext()) {
-      Vector v = AbstractVector.decodeVector(values.next().toString());
+      Vector v = values.next();
       model.observe(v);
     }
     model.computeParameters();
     DirichletCluster<Vector> cluster = state.clusters.get(k);
     cluster.setModel(model);
-    output.collect(key, new Text(cluster.asFormatString()));
+    output.collect(key, cluster);
   }
 
   public void configure(DirichletState<Vector> state) {

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Wed Jun 24 18:41:37 2009
@@ -16,6 +16,11 @@
  */
 package org.apache.mahout.clustering.dirichlet.models;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.matrix.AbstractVector;
 import org.apache.mahout.matrix.SquareRootFunction;
 import org.apache.mahout.matrix.Vector;
 
@@ -50,6 +55,7 @@
 
   /**
    * Return an instance with the same parameters
+   * 
    * @return an AsymmetricSampledNormalModel
    */
   AsymmetricSampledNormalModel sample() {
@@ -60,7 +66,7 @@
   public void observe(Vector x) {
     s0++;
     if (s1 == null)
-      s1 = x.like();
+      s1 = x.clone();
     else
       s1 = s1.plus(x);
     if (s2 == null)
@@ -86,10 +92,10 @@
   /**
    * Calculate a pdf using the supplied sample and sd
    * 
-  * @param x a Vector sample
-  * @param sd a double std deviation
-  * @return
-  */
+   * @param x a Vector sample
+   * @param sd a double std deviation
+   * @return
+   */
   private double pdf(Vector x, double sd) {
     assert x.getNumNondefaultElements() == 2;
     double sd2 = sd * sd;
@@ -104,8 +110,8 @@
     assert x.getNumNondefaultElements() == 2;
     double pdf0 = pdf(x, sd.get(0));
     double pdf1 = pdf(x, sd.get(1));
-    //if (pdf0 < 0 || pdf0 > 1 || pdf1 < 0 || pdf1 > 1)
-    //  System.out.print("");
+    // if (pdf0 < 0 || pdf0 > 1 || pdf1 < 0 || pdf1 > 1)
+    // System.out.print("");
     return pdf0 * pdf1;
   }
 
@@ -128,4 +134,22 @@
     buf.append("]}");
     return buf.toString();
   }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.mean = AbstractVector.readVector(in);
+    this.sd = AbstractVector.readVector(in);
+    this.s0 = in.readInt();
+    this.s1 = AbstractVector.readVector(in);
+    this.s2 = AbstractVector.readVector(in);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    AbstractVector.writeVector(out, mean);
+    AbstractVector.writeVector(out, sd);
+    out.writeInt(s0);
+    AbstractVector.writeVector(out, s1);
+    AbstractVector.writeVector(out, s2);
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/Model.java Wed Jun 24 18:41:37 2009
@@ -1,5 +1,7 @@
 package org.apache.mahout.clustering.dirichlet.models;
 
+import org.apache.hadoop.io.Writable;
+
 /**
  * Licensed to the Apache Software Foundation (ASF) under one or more
  * contributor license agreements.  See the NOTICE file distributed with
@@ -21,7 +23,7 @@
  * A model is a probability distribution over observed data points and allows 
  * the probability of any data point to be computed.
  */
-public interface Model<Observation> {
+public interface Model<Observation> extends Writable {
 
   /**
    * Observe the given observation, retaining information about it

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Wed Jun 24 18:41:37 2009
@@ -16,6 +16,11 @@
  */
 package org.apache.mahout.clustering.dirichlet.models;
 
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.mahout.matrix.AbstractVector;
 import org.apache.mahout.matrix.SquareRootFunction;
 import org.apache.mahout.matrix.Vector;
 
@@ -47,10 +52,12 @@
   }
 
   /**
-   * Return an instance with the same parameters
+   * TODO: Return a proper sample from the posterior. For now, return an 
+   * instance with the same parameters
+   * 
    * @return an NormalModel
    */
-  NormalModel sample() {
+  public NormalModel sample() {
     return new NormalModel(mean, sd);
   }
 
@@ -58,7 +65,7 @@
   public void observe(Vector x) {
     s0++;
     if (s1 == null)
-      s1 = x;
+      s1 = x.clone();
     else
       s1 = s1.plus(x);
     if (s2 == null)
@@ -83,7 +90,6 @@
 
   @Override
   public double pdf(Vector x) {
-    assert x.getNumNondefaultElements() == 2;
     double sd2 = sd * sd;
     double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
     double ex = Math.exp(exp);
@@ -105,4 +111,22 @@
     buf.append("] sd=").append(String.format("%.2f", sd)).append('}');
     return buf.toString();
   }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    this.mean = AbstractVector.readVector(in);
+    this.sd = in.readDouble();
+    this.s0 = in.readInt();
+    this.s1 = AbstractVector.readVector(in);
+    this.s2 = AbstractVector.readVector(in);
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    AbstractVector.writeVector(out, mean);
+    out.writeDouble(sd);
+    out.writeInt(s0);
+    AbstractVector.writeVector(out, s1);
+    AbstractVector.writeVector(out, s2);    
+  }
 }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java Wed Jun 24 18:41:37 2009
@@ -44,7 +44,7 @@
    * @return an SampledNormalModel
    */
   @Override
-  NormalModel sample() {
+  public NormalModel sample() {
     return new SampledNormalModel(mean, sd);
   }
 }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java Wed Jun 24 18:41:37 2009
@@ -9,6 +9,7 @@
 import org.apache.hadoop.io.LongWritable;
 
 import java.util.List;
+import java.io.File;
 import java.io.IOException;
 
 /**
@@ -28,4 +29,16 @@
     }
     writer.close();
   }
+
+  public static void rmr(String path) throws Exception {
+    File f = new File(path);
+    if (f.exists()) {
+      if (f.isDirectory()) {
+        String[] contents = f.list();
+        for (int i = 0; i < contents.length; i++)
+          rmr(f.toString() + File.separator + contents[i]);
+      }
+      f.delete();
+    }
+  }
 }

Modified: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java Wed Jun 24 18:41:37 2009
@@ -16,20 +16,19 @@
  */
 package org.apache.mahout.clustering.dirichlet;
 
-import java.io.BufferedWriter;
 import java.io.File;
-import java.io.FileOutputStream;
 import java.io.IOException;
-import java.io.OutputStreamWriter;
-import java.nio.charset.Charset;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
 import junit.framework.TestCase;
 
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.io.Text;
 import org.apache.hadoop.mapred.JobConf;
+import org.apache.mahout.clustering.ClusteringTestUtils;
 import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
 import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
 import org.apache.mahout.clustering.dirichlet.models.Model;
@@ -40,6 +39,7 @@
 import org.apache.mahout.clustering.kmeans.KMeansDriver;
 import org.apache.mahout.matrix.DenseVector;
 import org.apache.mahout.matrix.JsonVectorAdapter;
+import org.apache.mahout.matrix.SparseVector;
 import org.apache.mahout.matrix.Vector;
 import org.apache.mahout.utils.DummyOutputCollector;
 
@@ -50,8 +50,40 @@
 
   private List<Vector> sampleData = new ArrayList<Vector>();
 
+  FileSystem fs;
+
+  Configuration conf;
+
   /**
    * Generate random samples and add them to the sampleData
+   * 
+   * @param num int number of samples to generate
+   * @param mx double x-value of the sample mean
+   * @param my double y-value of the sample mean
+   * @param sdx double x-standard deviation of the samples
+   * @param sdy double y-standard deviation of the samples
+   */
+  private void generateSamples(int num, double mx, double my, double sdx,
+      double sdy) {
+    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
+        + "] sd=[" + sdx + ", " + sdy + "]");
+    for (int i = 0; i < num; i++) {
+      addSample(new double[] {
+          UncommonDistributions.rNorm(mx, sdx),
+          UncommonDistributions.rNorm(my, sdy) });
+    }
+  }
+
+  private void addSample(double[] values) {
+    Vector v = new SparseVector(2);
+    for (int j = 0; j < values.length; j++)
+      v.setQuick(j, values[j]);
+    sampleData.add(v);
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData
+   * 
    * @param num int number of samples to generate
    * @param mx double x-value of the sample mean
    * @param my double y-value of the sample mean
@@ -60,36 +92,27 @@
   private void generateSamples(int num, double mx, double my, double sd) {
     System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
         + "] sd=" + sd);
-    for (int i = 0; i < num; i++)
-      sampleData.add(new DenseVector(new double[] {
-          UncommonDistributions.rNorm(mx, sd),
-          UncommonDistributions.rNorm(my, sd) }));
-  }
-
-  public static void writePointsToFileWithPayload(List<Vector> points,
-      String fileName, String payload) throws IOException {
-    BufferedWriter output = new BufferedWriter(new OutputStreamWriter(
-        new FileOutputStream(fileName), Charset.forName("UTF-8")));
-    for (Vector point : points) {
-      output.write(point.asFormatString());
-      output.write(payload);
-      output.write('\n');
+    for (int i = 0; i < num; i++) {
+      addSample(new double[] { UncommonDistributions.rNorm(mx, sd),
+          UncommonDistributions.rNorm(my, sd) });
     }
-    output.flush();
-    output.close();
   }
 
   @Override
   protected void setUp() throws Exception {
     super.setUp();
     UncommonDistributions.init("Mahout=Hadoop+ML".getBytes());
+    ClusteringTestUtils.rmr("output");
+    ClusteringTestUtils.rmr("input");
+    conf = new Configuration();
+    fs = FileSystem.get(conf);
     File f = new File("input");
-    if (!f.exists())
-      f.mkdir();
+    f.mkdir();
   }
 
   /**
-   * Test the basic Mapper 
+   * Test the basic Mapper
+   * 
    * @throws Exception
    */
   public void testMapper() throws Exception {
@@ -99,16 +122,17 @@
     DirichletMapper mapper = new DirichletMapper();
     mapper.configure(state);
 
-    DummyOutputCollector<Text, Text> collector = new DummyOutputCollector<Text, Text>();
+    DummyOutputCollector<Text, Vector> collector = new DummyOutputCollector<Text, Vector>();
     for (Vector v : sampleData)
-      mapper.map(null, new Text(v.asFormatString()), collector, null);
-    Map<String, List<Text>> data = collector.getData();
+      mapper.map(null, v, collector, null);
+    Map<String, List<Vector>> data = collector.getData();
     // this seed happens to produce two partitions, but they work
     assertEquals("output size", 3, data.size());
   }
 
   /**
-   * Test the basic Reducer 
+   * Test the basic Reducer
+   * 
    * @throws Exception
    */
   public void testReducer() throws Exception {
@@ -121,16 +145,16 @@
     DirichletMapper mapper = new DirichletMapper();
     mapper.configure(state);
 
-    DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
+    DummyOutputCollector<Text, Vector> mapCollector = new DummyOutputCollector<Text, Vector>();
     for (Vector v : sampleData)
-      mapper.map(null, new Text(v.asFormatString()), mapCollector, null);
-    Map<String, List<Text>> data = mapCollector.getData();
+      mapper.map(null, v, mapCollector, null);
+    Map<String, List<Vector>> data = mapCollector.getData();
     // this seed happens to produce three partitions, but they work
     assertEquals("output size", 7, data.size());
 
     DirichletReducer reducer = new DirichletReducer();
     reducer.configure(state);
-    DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
+    DummyOutputCollector<Text, DirichletCluster<Vector>> reduceCollector = new DummyOutputCollector<Text, DirichletCluster<Vector>>();
     for (String key : mapCollector.getKeys())
       reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
           reduceCollector, null);
@@ -155,7 +179,8 @@
   }
 
   /**
-   * Test the Mapper and Reducer in an iteration loop 
+   * Test the Mapper and Reducer in an iteration loop
+   * 
    * @throws Exception
    */
   public void testMRIterations() throws Exception {
@@ -171,13 +196,13 @@
     for (int iteration = 0; iteration < 10; iteration++) {
       DirichletMapper mapper = new DirichletMapper();
       mapper.configure(state);
-      DummyOutputCollector<Text, Text> mapCollector = new DummyOutputCollector<Text, Text>();
+      DummyOutputCollector<Text, Vector> mapCollector = new DummyOutputCollector<Text, Vector>();
       for (Vector v : sampleData)
-        mapper.map(null, new Text(v.asFormatString()), mapCollector, null);
+        mapper.map(null, v, mapCollector, null);
 
       DirichletReducer reducer = new DirichletReducer();
       reducer.configure(state);
-      DummyOutputCollector<Text, Text> reduceCollector = new DummyOutputCollector<Text, Text>();
+      DummyOutputCollector<Text, DirichletCluster<Vector>> reduceCollector = new DummyOutputCollector<Text, DirichletCluster<Vector>>();
       for (String key : mapCollector.getKeys())
         reducer.reduce(new Text(key), mapCollector.getValue(key).iterator(),
             reduceCollector, null);
@@ -322,13 +347,13 @@
     assertEquals("modelFactory", state.modelFactory.getClass().getName(),
         state2.modelFactory.getClass().getName());
     assertEquals("clusters", state.clusters.size(), state2.clusters.size());
-    assertEquals("mixture", state.mixture.size(), state2.mixture
-        .size());
+    assertEquals("mixture", state.mixture.size(), state2.mixture.size());
     assertEquals("dirichlet", state.offset, state2.offset);
   }
 
   /**
-   * Test the Mapper and Reducer using the Driver 
+   * Test the Mapper and Reducer using the Driver
+   * 
    * @throws Exception
    */
   public void testDriverMRIterations() throws Exception {
@@ -339,16 +364,21 @@
     generateSamples(100, 2, 0, 0.2);
     generateSamples(100, 0, 2, 0.3);
     generateSamples(100, 2, 2, 1);
-    writePointsToFileWithPayload(sampleData, "input/data.txt", "");
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data.txt", fs,
+        conf);
     // Now run the driver
-    DirichletDriver.runJob("input", "output",
-        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
-        10, 1.0, 1);
+    DirichletDriver
+        .runJob(
+            "input",
+            "output",
+            "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution",
+            20, 10, 1.0, 1);
     // and inspect results
     List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
     JobConf conf = new JobConf(KMeansDriver.class);
-    conf.set(DirichletDriver.MODEL_FACTORY_KEY,
-        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+    conf
+        .set(DirichletDriver.MODEL_FACTORY_KEY,
+            "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
     conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
     conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
     for (int i = 0; i < 11; i++) {
@@ -358,8 +388,8 @@
     printResults(clusters, 0);
   }
 
-  private static void printResults(List<List<DirichletCluster<Vector>>> clusters,
-      int significant) {
+  private static void printResults(
+      List<List<DirichletCluster<Vector>>> clusters, int significant) {
     int row = 0;
     for (List<DirichletCluster<Vector>> r : clusters) {
       System.out.print("sample[" + row++ + "]= ");
@@ -377,33 +407,28 @@
   }
 
   /**
-   * Test the Mapper and Reducer using the Driver 
+   * Test the Mapper and Reducer using the Driver
+   * 
    * @throws Exception
    */
   public void testDriverMnRIterations() throws Exception {
     File f = new File("input");
     for (File g : f.listFiles())
       g.delete();
-    generateSamples(500, 0, 0, 0.5);
-    writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
-    sampleData = new ArrayList<Vector>();
-    generateSamples(500, 2, 0, 0.2);
-    writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
-    sampleData = new ArrayList<Vector>();
-    generateSamples(500, 0, 2, 0.3);
-    writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
-    sampleData = new ArrayList<Vector>();
-    generateSamples(500, 2, 2, 1);
-    writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+    generate4Datasets();
     // Now run the driver
-    DirichletDriver.runJob("input", "output",
-        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
-        15, 1.0, 1);
+    DirichletDriver
+        .runJob(
+            "input",
+            "output",
+            "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution",
+            20, 15, 1.0, 1);
     // and inspect results
     List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
     JobConf conf = new JobConf(KMeansDriver.class);
-    conf.set(DirichletDriver.MODEL_FACTORY_KEY,
-        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+    conf
+        .set(DirichletDriver.MODEL_FACTORY_KEY,
+            "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
     conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
     conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
     for (int i = 0; i < 11; i++) {
@@ -413,34 +438,47 @@
     printResults(clusters, 0);
   }
 
-  /**
-   * Test the Mapper and Reducer using the Driver 
-   * @throws Exception
-   */
-  public void testDriverMnRnIterations() throws Exception {
-    File f = new File("input");
-    for (File g : f.listFiles())
-      g.delete();
+  private void generate4Datasets() throws IOException {
     generateSamples(500, 0, 0, 0.5);
-    writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data1.txt", fs,
+        conf);
     sampleData = new ArrayList<Vector>();
     generateSamples(500, 2, 0, 0.2);
-    writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data2.txt", fs,
+        conf);
     sampleData = new ArrayList<Vector>();
     generateSamples(500, 0, 2, 0.3);
-    writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data3.txt", fs,
+        conf);
     sampleData = new ArrayList<Vector>();
     generateSamples(500, 2, 2, 1);
-    writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data4.txt", fs,
+        conf);
+  }
+
+  /**
+   * Test the Mapper and Reducer using the Driver
+   * 
+   * @throws Exception
+   */
+  public void testDriverMnRnIterations() throws Exception {
+    File f = new File("input");
+    for (File g : f.listFiles())
+      g.delete();
+    generate4Datasets();
     // Now run the driver
-    DirichletDriver.runJob("input", "output",
-        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution", 20,
-        15, 1.0, 2);
+    DirichletDriver
+        .runJob(
+            "input",
+            "output",
+            "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution",
+            20, 15, 1.0, 2);
     // and inspect results
     List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
     JobConf conf = new JobConf(KMeansDriver.class);
-    conf.set(DirichletDriver.MODEL_FACTORY_KEY,
-        "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
+    conf
+        .set(DirichletDriver.MODEL_FACTORY_KEY,
+            "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution");
     conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
     conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));
     for (int i = 0; i < 11; i++) {
@@ -451,25 +489,8 @@
   }
 
   /**
-   * Generate random samples and add them to the sampleData
-   * @param num int number of samples to generate
-   * @param mx double x-value of the sample mean
-   * @param my double y-value of the sample mean
-   * @param sdx double x-standard deviation of the samples
-   * @param sdy double y-standard deviation of the samples
-   */
-  private void generateSamples(int num, double mx, double my, double sdx,
-      double sdy) {
-    System.out.println("Generating " + num + " samples m=[" + mx + ", " + my
-        + "] sd=[" + sdx + ", " + sdy + "]");
-    for (int i = 0; i < num; i++)
-      sampleData.add(new DenseVector(new double[] {
-          UncommonDistributions.rNorm(mx, sdx),
-          UncommonDistributions.rNorm(my, sdy) }));
-  }
-
-  /**
-   * Test the Mapper and Reducer using the Driver 
+   * Test the Mapper and Reducer using the Driver
+   * 
    * @throws Exception
    */
   public void testDriverMnRnIterationsAsymmetric() throws Exception {
@@ -477,16 +498,20 @@
     for (File g : f.listFiles())
       g.delete();
     generateSamples(500, 0, 0, 0.5, 1.0);
-    writePointsToFileWithPayload(sampleData, "input/data1.txt", "");
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data1.txt", fs,
+        conf);
     sampleData = new ArrayList<Vector>();
-    generateSamples(500, 2, 0, 0.2, 0.1);
-    writePointsToFileWithPayload(sampleData, "input/data2.txt", "");
+    generateSamples(500, 2, 0, 0.2);
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data2.txt", fs,
+        conf);
     sampleData = new ArrayList<Vector>();
-    generateSamples(500, 0, 2, 0.3, 0.5);
-    writePointsToFileWithPayload(sampleData, "input/data3.txt", "");
+    generateSamples(500, 0, 2, 0.3);
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data3.txt", fs,
+        conf);
     sampleData = new ArrayList<Vector>();
-    generateSamples(500, 2, 2, 1, 0.5);
-    writePointsToFileWithPayload(sampleData, "input/data4.txt", "");
+    generateSamples(500, 2, 2, 1);
+    ClusteringTestUtils.writePointsToFile(sampleData, "input/data4.txt", fs,
+        conf);
     // Now run the driver
     DirichletDriver
         .runJob(
@@ -498,7 +523,8 @@
     List<List<DirichletCluster<Vector>>> clusters = new ArrayList<List<DirichletCluster<Vector>>>();
     JobConf conf = new JobConf(KMeansDriver.class);
     conf
-        .set(DirichletDriver.MODEL_FACTORY_KEY,
+        .set(
+            DirichletDriver.MODEL_FACTORY_KEY,
             "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution");
     conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20));
     conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0));

Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java?rev=788116&r1=788115&r2=788116&view=diff
==============================================================================
--- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java (original)
+++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java Wed Jun 24 18:41:37 2009
@@ -20,6 +20,7 @@
 import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
 import org.apache.mahout.clustering.dirichlet.models.Model;
 import org.apache.mahout.clustering.dirichlet.models.ModelDistribution;
+import org.apache.mahout.clustering.dirichlet.models.NormalModel;
 import org.apache.mahout.matrix.DenseVector;
 import org.apache.mahout.matrix.Vector;
 
@@ -31,21 +32,21 @@
 
   @Override
   public Model<Vector>[] sampleFromPrior(int howMany) {
-    Model<Vector>[] result = new NormalScModel[howMany];
+    Model<Vector>[] result = new NormalModel[howMany];
     for (int i = 0; i < howMany; i++) {
       DenseVector mean = new DenseVector(60);
       for (int j = 0; j < 60; j++)
         mean.set(j, UncommonDistributions.rNorm(30, 0.5));
-      result[i] = new NormalScModel(mean, 1);
+      result[i] = new NormalModel(mean, 1);
     }
     return result;
   }
 
   @Override
   public Model<Vector>[] sampleFromPosterior(Model<Vector>[] posterior) {
-    Model<Vector>[] result = new NormalScModel[posterior.length];
+    Model<Vector>[] result = new NormalModel[posterior.length];
     for (int i = 0; i < posterior.length; i++) {
-      NormalScModel m = (NormalScModel) posterior[i];
+      NormalModel m = (NormalModel) posterior[i];
       result[i] = m.sample();
     }
     return result;