You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/01/19 01:11:43 UTC
svn commit: r900615 - in /lucene/mahout/trunk/core/src:
main/java/org/apache/mahout/clustering/dirichlet/models/
test/java/org/apache/mahout/clustering/dirichlet/
Author: jeastman
Date: Tue Jan 19 00:11:42 2010
New Revision: 900615
URL: http://svn.apache.org/viewvc?rev=900615&view=rev
Log:
MAHOUT-251
Created a new SparseNormalModelDistribution to begin attacking the text clustering domain. First implementation is based upon NormalModel and NormalModelDistribution but adds a filtering step to sampleFromPosterior to set insignificant mean elements to 0.0 to conserve space.
Adjusted NormalModel.toString to be more SparseVector-friendly and introduced the TestSparseModelClustering tests.
Added:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
Modified:
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=900615&r1=900614&r2=900615&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Tue Jan 19 00:11:42 2010
@@ -20,10 +20,12 @@
import org.apache.mahout.math.SquareRootFunction;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
+import java.util.Iterator;
public class NormalModel implements Model<VectorWritable> {
@@ -55,7 +57,7 @@
int getS0() {
return s0;
}
-
+
public Vector getMean() {
return mean;
}
@@ -64,7 +66,6 @@
return stdDev;
}
-
/**
* TODO: Return a proper sample from the posterior. For now, return an instance with the same parameters
*
@@ -98,8 +99,7 @@
mean = s1.divide(s0);
// compute the average of the component stds
if (s0 > 1) {
- Vector std = s2.times(s0).minus(s1.times(s1)).assign(
- new SquareRootFunction()).divide(s0);
+ Vector std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
stdDev = std.zSum() / std.size();
} else {
stdDev = Double.MIN_VALUE;
@@ -124,9 +124,15 @@
public String toString() {
StringBuilder buf = new StringBuilder();
buf.append("nm{n=").append(s0).append(" m=[");
+ int nextIx = 0;
if (mean != null) {
- for (int i = 0; i < mean.size(); i++) {
- buf.append(String.format("%.2f", mean.get(i))).append(", ");
+ // handle sparse Vectors gracefully, suppressing zero values
+ for (Iterator<Element> nzElems = mean.iterateNonZero(); nzElems.hasNext();) {
+ Element elem = nzElems.next();
+ if (elem.index() > nextIx)
+ buf.append("..{").append(elem.index()).append("}=");
+ buf.append(String.format("%.2f", mean.get(elem.index()))).append(", ");
+ nextIx = elem.index() + 1;
}
}
buf.append("] sd=").append(String.format("%.2f", stdDev)).append('}');
Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java?rev=900615&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java Tue Jan 19 00:11:42 2010
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.dirichlet.models;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
+import org.apache.mahout.math.VectorWritable;
+
+/**
+ * A ModelDistribution of NormalModels aimed at sparse data such as text. Based on NormalModelDistribution, it
+ * additionally resets insignificant mean elements to 0.0 in sampleFromPosterior so that sparse mean vectors do
+ * not spend storage on near-zero values.
+ */
+public class SparseNormalModelDistribution extends VectorModelDistribution {
+
+  /** Mean elements whose absolute value is below this threshold are reset to 0.0 after posterior sampling. */
+  private static final double MEAN_TRIM_THRESHOLD = 0.001;
+
+  public SparseNormalModelDistribution(VectorWritable modelPrototype) {
+    super(modelPrototype);
+  }
+
+  public SparseNormalModelDistribution() {
+    super();
+  }
+
+  /**
+   * Creates howMany NormalModels, each with a mean obtained from the model prototype's like() and a standard
+   * deviation of 1.
+   */
+  @Override
+  public Model<VectorWritable>[] sampleFromPrior(int howMany) {
+    Model<VectorWritable>[] result = new NormalModel[howMany];
+    // the prototype is loop-invariant; like() is still called per model so each gets its own mean vector
+    Vector prototype = getModelPrototype().get();
+    for (int i = 0; i < howMany; i++) {
+      result[i] = new NormalModel(prototype.like(), 1);
+    }
+    return result;
+  }
+
+  /**
+   * Samples one model from each posterior model, then zeroes every mean element whose absolute value is below
+   * MEAN_TRIM_THRESHOLD to keep sparse mean vectors small.
+   */
+  @Override
+  public Model<VectorWritable>[] sampleFromPosterior(Model<VectorWritable>[] posterior) {
+    Model<VectorWritable>[] result = new NormalModel[posterior.length];
+    for (int i = 0; i < posterior.length; i++) {
+      NormalModel m = ((NormalModel) posterior[i]).sample();
+      trimInsignificantElements(m.getMean());
+      result[i] = m;
+    }
+    return result;
+  }
+
+  /**
+   * Sets every non-zero element of mean whose absolute value is below MEAN_TRIM_THRESHOLD to 0.0. Only the
+   * elements returned by iterateNonZero() are inspected, rather than every index up to size().
+   * NOTE(review): sample() may return a model sharing its mean vector with the posterior model, in which
+   * case that posterior mean is trimmed as well - confirm this is intended.
+   */
+  private static void trimInsignificantElements(Vector mean) {
+    if (mean == null) {
+      return;
+    }
+    // collect the indices first and write afterwards, so the vector is not modified while being iterated
+    List<Integer> insignificant = new ArrayList<Integer>();
+    for (Iterator<Element> nzElems = mean.iterateNonZero(); nzElems.hasNext();) {
+      Element elem = nzElems.next();
+      if (Math.abs(elem.get()) < MEAN_TRIM_THRESHOLD) {
+        insignificant.add(elem.index());
+      }
+    }
+    for (int index : insignificant) {
+      mean.set(index, 0.0);
+    }
+  }
+}
Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java?rev=900615&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java Tue Jan 19 00:11:42 2010
@@ -0,0 +1,149 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.dirichlet;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.SparseNormalModelDistribution;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
+import org.junit.After;
+import org.junit.Before;
+
+/**
+ * Runs DirichletClusterer with a SparseNormalModelDistribution over randomly generated sparse sample vectors.
+ */
+public class TestSparseModelClustering extends MahoutTestCase {
+
+  private List<VectorWritable> sampleData;
+
+  private Random random;
+
+  @Before
+  protected void setUp() throws Exception {
+    super.setUp();
+    random = RandomUtils.getRandom();
+    sampleData = new ArrayList<VectorWritable>();
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData. Each index is independently chosen with
+   * probability pNz; every sample gets a value drawn by UncommonDistributions.rNorm(mx, sd) at each chosen
+   * index and zero elsewhere.
+   *
+   * @param num int number of samples to generate
+   * @param mx double value of the sample mean
+   * @param sd double standard deviation of the samples
+   * @param card int cardinality of the generated sample vectors
+   * @param pNz double probability a sample element is non-zero
+   */
+  private void generateSamples(int num, double mx, double sd, int card, double pNz) {
+    // track the chosen indices explicitly: testing sparse.get(j) > 0.0 would drop them all when mx <= 0
+    List<Integer> nonZeroIndices = new ArrayList<Integer>();
+    Vector sparse = new SequentialAccessSparseVector(card);
+    for (int i = 0; i < card; i++) {
+      if (random.nextDouble() < pNz) {
+        nonZeroIndices.add(i);
+        sparse.set(i, mx);
+      }
+    }
+    System.out.println("Generating " + num + formatSampleParameters(sparse, sd) + " pNz=" + pNz);
+    for (int i = 0; i < num; i++) {
+      Vector v = new SequentialAccessSparseVector(card);
+      for (int j : nonZeroIndices) {
+        v.set(j, UncommonDistributions.rNorm(mx, sd));
+      }
+      sampleData.add(new VectorWritable(v));
+    }
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    // nothing of our own to release, but the inherited tear-down must still run
+    super.tearDown();
+  }
+
+  /** Clusters samples from three sparse generators and prints the significant models of every result entry. */
+  public void testDirichletCluster100s() {
+    System.out.println("testDirichletCluster100s");
+    generateSamples(40, 5, 3, 50, 0.1);
+    generateSamples(30, 3, 1, 50, 0.1);
+    generateSamples(30, 1, 0.1, 50, 0.1);
+
+    DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData,
+        new SparseNormalModelDistribution(sampleData.get(0)), 1.0, 10, 1, 0);
+    List<Model<VectorWritable>[]> result = dc.cluster(10);
+    assertNotNull(result);
+    printResults(result, 1);
+  }
+
+  /**
+   * Prints, for every model array in result, how many models have a count() above significant, followed by
+   * those models.
+   */
+  private static void printResults(List<Model<VectorWritable>[]> result, int significant) {
+    int row = 0;
+    for (Model<VectorWritable>[] r : result) {
+      int sig = 0;
+      for (Model<VectorWritable> model : r) {
+        if (model.count() > significant) {
+          sig++;
+        }
+      }
+      System.out.print("sample[" + row++ + "] (" + sig + ")= ");
+      for (Model<VectorWritable> model : r) {
+        if (model.count() > significant) {
+          System.out.print(model.toString() + ", ");
+        }
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+
+  /**
+   * Formats the non-zero elements of v (a "..{index}=" marker precedes any element that follows a run of zeros)
+   * and the standard deviation, for the "Generating" log line.
+   */
+  private static String formatSampleParameters(Vector v, double stdDev) {
+    StringBuilder buf = new StringBuilder();
+    buf.append(" m=[");
+    if (v != null) {
+      int nextIx = 0;
+      // handle sparse Vectors gracefully, suppressing zero values
+      for (Iterator<Element> nzElems = v.iterateNonZero(); nzElems.hasNext();) {
+        Element elem = nzElems.next();
+        if (elem.index() > nextIx) {
+          buf.append("..{").append(elem.index()).append("}=");
+        }
+        buf.append(String.format("%.2f", elem.get())).append(", ");
+        nextIx = elem.index() + 1;
+      }
+    }
+    // no trailing '}': unlike NormalModel.toString there is no leading "nm{" to close
+    buf.append("] sd=").append(String.format("%.2f", stdDev));
+    return buf.toString();
+  }
+}