You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/01/20 18:03:03 UTC

svn commit: r901279 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/clustering/dirichlet/models/ test/java/org/apache/mahout/clustering/dirichlet/

Author: jeastman
Date: Wed Jan 20 17:03:02 2010
New Revision: 901279

URL: http://svn.apache.org/viewvc?rev=901279&view=rev
Log:
MAHOUT-251

- Removed SparseNormalModelDistribution
- Removed TestSparseModelClustering

- Added L1Model
- Added L1ModelDistribution
- Added TestL1ModelClustering, which builds a Lucene index and vectorizes it with the Mahout utils TFIDF weighting

- Fixed a defect in NormalModel.toString: iterateNonZero does not iterate in index order, so toString now scans the indices sequentially

The unit test appears to find 5 models for the 8 documents (4 of which are very similar), and the models have the correct non-zero terms.

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
      - copied, changed from r900615, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java
      - copied, changed from r900615, lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
Removed:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=901279&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Wed Jan 20 17:03:02 2010
@@ -0,0 +1,139 @@
+package org.apache.mahout.clustering.dirichlet.models;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
+
+/**
+ * A Dirichlet clustering model centered on a coefficient vector. Its (unnormalized) density
+ * decays exponentially with the L1 (Manhattan) distance from that vector:
+ * {@code pdf(x) = exp(-||x - coefficients||_1)}.
+ */
+public class L1Model implements Model<VectorWritable> {
+
+  /** Distance measure used by {@link #pdf(VectorWritable)}. */
+  private static final DistanceMeasure MEASURE = new ManhattanDistanceMeasure();
+
+  /** Current model parameters: the mean of the observations once computeParameters() has run. */
+  private Vector coefficients;
+
+  /** Number of observations made since construction or the last readFields. */
+  private int count = 0;
+
+  /** Running sum of the observed vectors. */
+  private Vector observed;
+
+  /**
+   * Creates an empty model with no coefficients. Presumably used for Writable deserialization
+   * via {@link #readFields(DataInput)}; confirm against callers.
+   */
+  public L1Model() {
+    super();
+  }
+
+  /**
+   * Creates a model centered on the given vector.
+   *
+   * @param v the initial coefficients (stored by reference); {@code v.like()} seeds the
+   *          observation accumulator
+   */
+  public L1Model(Vector v) {
+    observed = v.like();
+    coefficients = v;
+  }
+
+  /**
+   * Sets the coefficients to the mean of the observed vectors. Skipped when nothing has been
+   * observed: dividing by a zero count is meaningless and would discard the current
+   * coefficients.
+   */
+  @Override
+  public void computeParameters() {
+    if (count == 0) {
+      return;
+    }
+    coefficients = observed.divide(count);
+  }
+
+  @Override
+  public int count() {
+    return count;
+  }
+
+  /** Adds one observation to the running sum and increments the count. */
+  @Override
+  public void observe(VectorWritable x) {
+    count++;
+    x.get().addTo(observed);
+  }
+
+  /**
+   * Returns {@code exp(-d)}, where d is the Manhattan distance between x and the coefficients:
+   * 1 when x equals the coefficients, approaching 0 as they diverge.
+   */
+  @Override
+  public double pdf(VectorWritable x) {
+    return Math.exp(-MEASURE.distance(x.get(), coefficients));
+  }
+
+  /**
+   * Reads the coefficients written by {@link #write(DataOutput)} and resets the observation
+   * state, since neither the count nor the accumulator is serialized. Without the reset, an
+   * instance built with the no-arg constructor would hand observe() a null accumulator.
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    coefficients = VectorWritable.readVector(in);
+    observed = coefficients.like();
+    count = 0;
+  }
+
+  /** Writes only the coefficients; the count and observation accumulator are not serialized. */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    VectorWritable.writeVector(out, coefficients);
+  }
+
+  /**
+   * Returns a new model whose coefficients are a clone of this model's, with an empty
+   * accumulator and a zero count.
+   */
+  public L1Model sample() {
+    return new L1Model(coefficients.clone());
+  }
+
+  /**
+   * Formats the model as {@code l1m{n=<count> c=[...]}}, listing only non-zero coefficients to
+   * two decimals; a gap of skipped zero entries is marked with {@code ..{index}=}.
+   */
+  @Override
+  public String toString() {
+    StringBuilder buf = new StringBuilder();
+    buf.append("l1m{n=").append(count).append(" c=[");
+    int nextIx = 0;
+    if (coefficients != null) {
+      // walk indices in order (iterateNonZero is not in index order) so ..{i}= markers are right
+      for (int i = 0; i < coefficients.size(); i++) {
+        double elem = coefficients.get(i);
+        if (elem == 0.0) {
+          continue;
+        }
+        if (i > nextIx) {
+          buf.append("..{").append(i).append("}=");
+        }
+        buf.append(String.format("%.2f", elem)).append(", ");
+        nextIx = i + 1;
+      }
+    }
+    buf.append("]}");
+    return buf.toString();
+  }
+
+}

Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java (from r900615, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java&r1=900615&r2=901279&rev=901279&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java Wed Jan 20 17:03:02 2010
@@ -22,38 +22,33 @@
 
 /**
  * An implementation of the ModelDistribution interface suitable for testing the DirichletCluster algorithm. Uses a
- * Normal Distribution
+ * model based on L1 (Manhattan) distance (see L1Model)
  */
-public class SparseNormalModelDistribution extends VectorModelDistribution {
+public class L1ModelDistribution extends VectorModelDistribution {
 
-  public SparseNormalModelDistribution(VectorWritable modelPrototype) {
+  public L1ModelDistribution(VectorWritable modelPrototype) {
     super(modelPrototype);
   }
 
-  public SparseNormalModelDistribution() {
+  public L1ModelDistribution() {
     super();
   }
 
   @Override
   public Model<VectorWritable>[] sampleFromPrior(int howMany) {
-    Model<VectorWritable>[] result = new NormalModel[howMany];
+    Model<VectorWritable>[] result = new L1Model[howMany];
     for (int i = 0; i < howMany; i++) {
       Vector prototype = getModelPrototype().get();
-      result[i] = new NormalModel(prototype.like(), 1);
+      result[i] = new L1Model(prototype.like());
     }
     return result;
   }
 
   @Override
   public Model<VectorWritable>[] sampleFromPosterior(Model<VectorWritable>[] posterior) {
-    Model<VectorWritable>[] result = new NormalModel[posterior.length];
+    Model<VectorWritable>[] result = new L1Model[posterior.length];
     for (int i = 0; i < posterior.length; i++) {
-      NormalModel m = ((NormalModel) posterior[i]).sample();
-      // trim insignificant mean elements from the posterior model to save sparse vector space
-      for (int j = 0; j < m.getMean().size(); j++)
-        if (Math.abs(m.getMean().get(j)) < 0.001)
-          m.getMean().set(j, 0.0d);
-      result[i] = m;
+      result[i] = ((L1Model) posterior[i]).sample();
     }
     return result;
   }

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=901279&r1=901278&r2=901279&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Wed Jan 20 17:03:02 2010
@@ -127,12 +127,14 @@
     int nextIx = 0;
     if (mean != null) {
       // handle sparse Vectors gracefully, suppressing zero values
-      for (Iterator<Element> nzElems = mean.iterateNonZero(); nzElems.hasNext();) {
-        Element elem = nzElems.next();
-        if (elem.index() > nextIx)
-          buf.append("..{").append(elem.index()).append("}=");
-        buf.append(String.format("%.2f", mean.get(elem.index()))).append(", ");
-        nextIx = elem.index() + 1;
+      for (int i = 0; i < mean.size(); i++) {
+        double elem = mean.get(i);
+        if (elem == 0.0)
+          continue;
+        if (i > nextIx)
+          buf.append("..{").append(i).append("}=");
+        buf.append(String.format("%.2f", elem)).append(", ");
+        nextIx = i + 1;
       }
     }
     buf.append("] sd=").append(String.format("%.2f", stdDev)).append('}');

Copied: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java (from r900615, lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java?p2=lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java&p1=lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java&r1=900615&r2=901279&rev=901279&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java Wed Jan 20 17:03:02 2010
@@ -21,70 +21,95 @@
 import java.util.List;
 import java.util.Random;
 
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.store.RAMDirectory;
+import org.apache.lucene.util.Version;
+import org.apache.mahout.clustering.dirichlet.models.L1ModelDistribution;
 import org.apache.mahout.clustering.dirichlet.models.Model;
-import org.apache.mahout.clustering.dirichlet.models.SparseNormalModelDistribution;
 import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.SequentialAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.utils.vectors.TFIDF;
+import org.apache.mahout.utils.vectors.TermInfo;
+import org.apache.mahout.utils.vectors.Weight;
+import org.apache.mahout.utils.vectors.lucene.CachedTermInfo;
+import org.apache.mahout.utils.vectors.lucene.LuceneIterable;
+import org.apache.mahout.utils.vectors.lucene.TFDFMapper;
+import org.apache.mahout.utils.vectors.lucene.VectorMapper;
 import org.junit.After;
 import org.junit.Before;
 
-public class TestSparseModelClustering extends MahoutTestCase {
+public class TestL1ModelClustering extends MahoutTestCase {
+
+  private static final String[] DOCS = { "The quick red fox jumped over the lazy brown dogs.",
+      "The quick brown fox jumped over the lazy red dogs.", "The quick red cat jumped over the lazy brown dogs.",
+      "The quick brown cat jumped over the lazy red dogs.", "Mary had a little lamb whose fleece was white as snow.",
+      "Moby Dick is a story of a whale and a man obsessed.", "The robber wore a black fleece jacket and a baseball cap.",
+      "The English Springer Spaniel is the best of all dogs." };
 
   private List<VectorWritable> sampleData;
 
   Random random;
 
+  private RAMDirectory directory;
+
   @Before
   protected void setUp() throws Exception {
     super.setUp();
     random = RandomUtils.getRandom();
     sampleData = new ArrayList<VectorWritable>();
-  }
 
-  /**
-   * Generate random samples and add them to the sampleData
-   *
-   * @param num int number of samples to generate
-   * @param mx  double value of the sample mean
-   * @param sd  double standard deviation of the samples
-   * @param card int cardinality of the generated sample vectors
-   * @param pNz double probability a sample element is non-zero
-   */
-  private void generateSamples(int num, double mx, double sd, int card, double pNz) {
-    Vector sparse = new SequentialAccessSparseVector(card);
-    for (int i = 0; i < card; i++)
-      if (random.nextDouble() < pNz)
-        sparse.set(i, mx);
-    System.out.println("Generating " + num + printSampleParameters(sparse, sd) + " pNz=" + pNz);
-    for (int i = 0; i < num; i++) {
-      SequentialAccessSparseVector v = new SequentialAccessSparseVector(card);
-      for (int j = 0; j < card; j++) {
-        if (sparse.get(j) > 0.0)
-          v.set(j, UncommonDistributions.rNorm(mx, sd));
-      }
-      sampleData.add(new VectorWritable(v));
+    directory = new RAMDirectory();
+    IndexWriter writer = new IndexWriter(directory, new StandardAnalyzer(Version.LUCENE_CURRENT), true,
+        IndexWriter.MaxFieldLength.UNLIMITED);
+    for (int i = 0; i < DOCS.length; i++) {
+      Document doc = new Document();
+      Field id = new Field("id", "doc_" + i, Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS);
+      doc.add(id);
+      // Store term vectors (Field.TermVector.YES: no positions or offsets)
+      Field text = new Field("content", DOCS[i], Field.Store.NO, Field.Index.ANALYZED, Field.TermVector.YES);
+      doc.add(text);
+      writer.addDocument(doc);
     }
+    writer.close();
   }
 
   @After
-  public void tearDown() throws Exception {
+  protected void tearDown() throws Exception {
   }
 
-  public void testDirichletCluster100s() {
-    System.out.println("testDirichletCluster100s");
-    generateSamples(40, 5, 3, 50, 0.1);
-    generateSamples(30, 3, 1, 50, 0.1);
-    generateSamples(30, 1, 0.1, 50, 0.1);
+  private static String formatVector(Vector v) {
+    StringBuilder buf = new StringBuilder();
+    int nzero = 0;
+    Iterator<Element> iterateNonZero = v.iterateNonZero();
+    while (iterateNonZero.hasNext()) {
+      iterateNonZero.next();
+      nzero++;
+    }
+    buf.append("(").append(nzero);
+    buf.append("nz) [");
+    int nextIx = 0;
+    if (v != null) {
+      // handle sparse Vectors gracefully, suppressing zero values
+      for (int i = 0; i < v.size(); i++) {
+        double elem = v.get(i);
+        if (elem == 0.0)
+          continue;
+        if (i > nextIx)
+          buf.append("..{").append(i).append("}=");
+        buf.append(String.format("%.2f", elem)).append(", ");
+        nextIx = i + 1;
+      }
+    }
+    buf.append("]");
+    return buf.toString();
 
-    DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData, new SparseNormalModelDistribution(
-        sampleData.get(0)), 1.0, 10, 1, 0);
-    List<Model<VectorWritable>[]> result = dc.cluster(10);
-    printResults(result, 1);
-    assertNotNull(result);
   }
 
   private static void printResults(List<Model<VectorWritable>[]> result, int significant) {
@@ -107,22 +132,23 @@
     System.out.println();
   }
 
-  private static String printSampleParameters(Vector v, double stdDev) {
-    StringBuilder buf = new StringBuilder();
-    buf.append(" m=[");
-    int nextIx = 0;
-    if (v != null) {
-      // handle sparse Vectors gracefully, suppressing zero values
-      for (Iterator<Element> nzElems = v.iterateNonZero(); nzElems.hasNext();) {
-        Element elem = nzElems.next();
-        if (elem.index() > nextIx)
-          buf.append("..{").append(elem.index()).append("}=");
-        buf.append(String.format("%.2f", v.get(elem.index()))).append(", ");
-        nextIx = elem.index() + 1;
-      }
+  public void testDocs() throws Exception {
+    IndexReader reader = IndexReader.open(directory, true);
+    Weight weight = new TFIDF();
+    TermInfo termInfo = new CachedTermInfo(reader, "content", 1, 100);
+    VectorMapper mapper = new TFDFMapper(reader, weight, termInfo);
+    LuceneIterable iterable = new LuceneIterable(reader, "id", "content", mapper);
+
+    for (Vector vector : iterable) {
+      assertNotNull(vector);
+      System.out.println("Vector=" + formatVector(vector));
+      sampleData.add(new VectorWritable(vector));
     }
-    buf.append("] sd=").append(String.format("%.2f", stdDev)).append('}');
-    return buf.toString();
+    DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData, new L1ModelDistribution(sampleData
+        .get(0)), 1.0, 15, 1, 0);
+    List<Model<VectorWritable>[]> result = dc.cluster(10);
+    printResults(result, 0);
+    assertNotNull(result);
 
   }