You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/01/20 18:03:03 UTC
svn commit: r901279 - in /lucene/mahout/trunk/core/src:
main/java/org/apache/mahout/clustering/dirichlet/models/
test/java/org/apache/mahout/clustering/dirichlet/
Author: jeastman
Date: Wed Jan 20 17:03:02 2010
New Revision: 901279
URL: http://svn.apache.org/viewvc?rev=901279&view=rev
Log:
MAHOUT-251
- Removed SparseNormalModelDistribution
- Removed TestSparseModelClustering
- Added L1Model
- Added L1ModelDistribution
- Added TestL1ModelClustering, which builds a Lucene index and uses the Mahout utils TFIDF weight to turn the documents into vectors
- Fixed defect in NormalModel.toString: iterateNonZero does not iterate in sequential index order, so toString now loops over every index and skips zero values
Unit test seems to find 5 models for the 8 documents, 4 of which are very similar, and models have the correct non-zero terms
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
- copied, changed from r900615, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java
- copied, changed from r900615, lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
Removed:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=901279&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java Wed Jan 20 17:03:02 2010
@@ -0,0 +1,89 @@
+package org.apache.mahout.clustering.dirichlet.models;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.mahout.common.distance.DistanceMeasure;
+import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
+
+public class L1Model implements Model<VectorWritable> {
+
+ private static DistanceMeasure measure = new ManhattanDistanceMeasure();
+
+ public L1Model() {
+ super();
+ }
+
+ public L1Model(Vector v) {
+ observed = v.like();
+ coefficients = v;
+ }
+
+ private Vector coefficients;
+
+ private int count = 0;
+
+ private Vector observed;
+
+ @Override
+ public void computeParameters() {
+ coefficients = observed.divide(count);
+ }
+
+ @Override
+ public int count() {
+ return count;
+ }
+
+ @Override
+ public void observe(VectorWritable x) {
+ count++;
+ x.get().addTo(observed);
+ }
+
+ @Override
+ public double pdf(VectorWritable x) {
+ return Math.exp(-measure.distance(x.get(), coefficients));
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ coefficients = VectorWritable.readVector(in);
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ VectorWritable.writeVector(out, coefficients);
+ }
+
+ public L1Model sample() {
+ return new L1Model(coefficients.clone());
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder buf = new StringBuilder();
+ buf.append("l1m{n=").append(count).append(" c=[");
+ int nextIx = 0;
+ if (coefficients != null) {
+ // handle sparse Vectors gracefully, suppressing zero values
+ for (int i = 0; i < coefficients.size(); i++) {
+ double elem = coefficients.get(i);
+ if (elem == 0.0)
+ continue;
+ if (i > nextIx)
+ buf.append("..{").append(i).append("}=");
+ buf.append(String.format("%.2f", elem)).append(", ");
+ nextIx = i + 1;
+ }
+ }
+ buf.append("]}");
+ return buf.toString();
+ }
+
+}
Copied: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java (from r900615, lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java?p2=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java&p1=lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java&r1=900615&r2=901279&rev=901279&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java Wed Jan 20 17:03:02 2010
@@ -22,38 +22,33 @@
/**
* An implementation of the ModelDistribution interface suitable for testing the DirichletCluster algorithm. Uses a
- * Normal Distribution
+ * L1Distribution
*/
-public class SparseNormalModelDistribution extends VectorModelDistribution {
+public class L1ModelDistribution extends VectorModelDistribution {
- public SparseNormalModelDistribution(VectorWritable modelPrototype) {
+ public L1ModelDistribution(VectorWritable modelPrototype) {
super(modelPrototype);
}
- public SparseNormalModelDistribution() {
+ public L1ModelDistribution() {
super();
}
@Override
public Model<VectorWritable>[] sampleFromPrior(int howMany) {
- Model<VectorWritable>[] result = new NormalModel[howMany];
+ Model<VectorWritable>[] result = new L1Model[howMany];
for (int i = 0; i < howMany; i++) {
Vector prototype = getModelPrototype().get();
- result[i] = new NormalModel(prototype.like(), 1);
+ result[i] = new L1Model(prototype.like());
}
return result;
}
@Override
public Model<VectorWritable>[] sampleFromPosterior(Model<VectorWritable>[] posterior) {
- Model<VectorWritable>[] result = new NormalModel[posterior.length];
+ Model<VectorWritable>[] result = new L1Model[posterior.length];
for (int i = 0; i < posterior.length; i++) {
- NormalModel m = ((NormalModel) posterior[i]).sample();
- // trim insignificant mean elements from the posterior model to save sparse vector space
- for (int j = 0; j < m.getMean().size(); j++)
- if (Math.abs(m.getMean().get(j)) < 0.001)
- m.getMean().set(j, 0.0d);
- result[i] = m;
+ result[i] = ((L1Model) posterior[i]).sample();
}
return result;
}
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=901279&r1=901278&r2=901279&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Wed Jan 20 17:03:02 2010
@@ -127,12 +127,14 @@
int nextIx = 0;
if (mean != null) {
// handle sparse Vectors gracefully, suppressing zero values
- for (Iterator<Element> nzElems = mean.iterateNonZero(); nzElems.hasNext();) {
- Element elem = nzElems.next();
- if (elem.index() > nextIx)
- buf.append("..{").append(elem.index()).append("}=");
- buf.append(String.format("%.2f", mean.get(elem.index()))).append(", ");
- nextIx = elem.index() + 1;
+ for (int i = 0; i < mean.size(); i++) {
+ double elem = mean.get(i);
+ if (elem == 0.0)
+ continue;
+ if (i > nextIx)
+ buf.append("..{").append(i).append("}=");
+ buf.append(String.format("%.2f", elem)).append(", ");
+ nextIx = i + 1;
}
}
buf.append("] sd=").append(String.format("%.2f", stdDev)).append('}');
Copied: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java (from r900615, lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java)
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java?p2=lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java&p1=lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java&r1=900615&r2=901279&rev=901279&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java (original)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java Wed Jan 20 17:03:02 2010
@@ -21,70 +21,95 @@
import java.util.List;
import java.util.Random;
+import org.apache.lucene.analysis.standard.StandardAnalyzer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.store.RAMDirectory;
+import org.apache.lucene.util.Version;
+import org.apache.mahout.clustering.dirichlet.models.L1ModelDistribution;
import org.apache.mahout.clustering.dirichlet.models.Model;
-import org.apache.mahout.clustering.dirichlet.models.SparseNormalModelDistribution;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.SequentialAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.utils.vectors.TFIDF;
+import org.apache.mahout.utils.vectors.TermInfo;
+import org.apache.mahout.utils.vectors.Weight;
+import org.apache.mahout.utils.vectors.lucene.CachedTermInfo;
+import org.apache.mahout.utils.vectors.lucene.LuceneIterable;
+import org.apache.mahout.utils.vectors.lucene.TFDFMapper;
+import org.apache.mahout.utils.vectors.lucene.VectorMapper;
import org.junit.After;
import org.junit.Before;
-public class TestSparseModelClustering extends MahoutTestCase {
+public class TestL1ModelClustering extends MahoutTestCase {
+
+ private static final String[] DOCS = { "The quick red fox jumped over the lazy brown dogs.",
+ "The quick brown fox jumped over the lazy red dogs.", "The quick red cat jumped over the lazy brown dogs.",
+ "The quick brown cat jumped over the lazy red dogs.", "Mary had a little lamb whose fleece was white as snow.",
+ "Moby Dick is a story of a whale and a man obsessed.", "The robber wore a black fleece jacket and a baseball cap.",
+ "The English Springer Spaniel is the best of all dogs." };
private List<VectorWritable> sampleData;
Random random;
+ private RAMDirectory directory;
+
@Before
protected void setUp() throws Exception {
super.setUp();
random = RandomUtils.getRandom();
sampleData = new ArrayList<VectorWritable>();
- }
- /**
- * Generate random samples and add them to the sampleData
- *
- * @param num int number of samples to generate
- * @param mx double value of the sample mean
- * @param sd double standard deviation of the samples
- * @param card int cardinality of the generated sample vectors
- * @param pNz double probability a sample element is non-zero
- */
- private void generateSamples(int num, double mx, double sd, int card, double pNz) {
- Vector sparse = new SequentialAccessSparseVector(card);
- for (int i = 0; i < card; i++)
- if (random.nextDouble() < pNz)
- sparse.set(i, mx);
- System.out.println("Generating " + num + printSampleParameters(sparse, sd) + " pNz=" + pNz);
- for (int i = 0; i < num; i++) {
- SequentialAccessSparseVector v = new SequentialAccessSparseVector(card);
- for (int j = 0; j < card; j++) {
- if (sparse.get(j) > 0.0)
- v.set(j, UncommonDistributions.rNorm(mx, sd));
- }
- sampleData.add(new VectorWritable(v));
+ directory = new RAMDirectory();
+ IndexWriter writer = new IndexWriter(directory, new StandardAnalyzer(Version.LUCENE_CURRENT), true,
+ IndexWriter.MaxFieldLength.UNLIMITED);
+ for (int i = 0; i < DOCS.length; i++) {
+ Document doc = new Document();
+ Field id = new Field("id", "doc_" + i, Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS);
+ doc.add(id);
+ //Store term vectors only (Field.TermVector.YES keeps terms and frequencies, not positions or offsets)
+ Field text = new Field("content", DOCS[i], Field.Store.NO, Field.Index.ANALYZED, Field.TermVector.YES);
+ doc.add(text);
+ writer.addDocument(doc);
}
+ writer.close();
}
@After
- public void tearDown() throws Exception {
+ protected void tearDown() throws Exception {
}
- public void testDirichletCluster100s() {
- System.out.println("testDirichletCluster100s");
- generateSamples(40, 5, 3, 50, 0.1);
- generateSamples(30, 3, 1, 50, 0.1);
- generateSamples(30, 1, 0.1, 50, 0.1);
+ private static String formatVector(Vector v) {
+ StringBuilder buf = new StringBuilder();
+ int nzero = 0;
+ Iterator<Element> iterateNonZero = v.iterateNonZero();
+ while (iterateNonZero.hasNext()) {
+ iterateNonZero.next();
+ nzero++;
+ }
+ buf.append("(").append(nzero);
+ buf.append("nz) [");
+ int nextIx = 0;
+ if (v != null) {
+ // handle sparse Vectors gracefully, suppressing zero values
+ for (int i = 0; i < v.size(); i++) {
+ double elem = v.get(i);
+ if (elem == 0.0)
+ continue;
+ if (i > nextIx)
+ buf.append("..{").append(i).append("}=");
+ buf.append(String.format("%.2f", elem)).append(", ");
+ nextIx = i + 1;
+ }
+ }
+ buf.append("]");
+ return buf.toString();
- DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData, new SparseNormalModelDistribution(
- sampleData.get(0)), 1.0, 10, 1, 0);
- List<Model<VectorWritable>[]> result = dc.cluster(10);
- printResults(result, 1);
- assertNotNull(result);
}
private static void printResults(List<Model<VectorWritable>[]> result, int significant) {
@@ -107,22 +132,23 @@
System.out.println();
}
- private static String printSampleParameters(Vector v, double stdDev) {
- StringBuilder buf = new StringBuilder();
- buf.append(" m=[");
- int nextIx = 0;
- if (v != null) {
- // handle sparse Vectors gracefully, suppressing zero values
- for (Iterator<Element> nzElems = v.iterateNonZero(); nzElems.hasNext();) {
- Element elem = nzElems.next();
- if (elem.index() > nextIx)
- buf.append("..{").append(elem.index()).append("}=");
- buf.append(String.format("%.2f", v.get(elem.index()))).append(", ");
- nextIx = elem.index() + 1;
- }
+ public void testDocs() throws Exception {
+ IndexReader reader = IndexReader.open(directory, true);
+ Weight weight = new TFIDF();
+ TermInfo termInfo = new CachedTermInfo(reader, "content", 1, 100);
+ VectorMapper mapper = new TFDFMapper(reader, weight, termInfo);
+ LuceneIterable iterable = new LuceneIterable(reader, "id", "content", mapper);
+
+ for (Vector vector : iterable) {
+ assertNotNull(vector);
+ System.out.println("Vector=" + formatVector(vector));
+ sampleData.add(new VectorWritable(vector));
}
- buf.append("] sd=").append(String.format("%.2f", stdDev)).append('}');
- return buf.toString();
+ DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData, new L1ModelDistribution(sampleData
+ .get(0)), 1.0, 15, 1, 0);
+ List<Model<VectorWritable>[]> result = dc.cluster(10);
+ printResults(result, 0);
+ assertNotNull(result);
}