You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mahout.apache.org by je...@apache.org on 2010/01/19 01:11:43 UTC

svn commit: r900615 - in /lucene/mahout/trunk/core/src: main/java/org/apache/mahout/clustering/dirichlet/models/ test/java/org/apache/mahout/clustering/dirichlet/

Author: jeastman
Date: Tue Jan 19 00:11:42 2010
New Revision: 900615

URL: http://svn.apache.org/viewvc?rev=900615&view=rev
Log:
MAHOUT-251

Created a new SparseNormalModelDistribution to begin tackling the text clustering domain. The first implementation is based on NormalModel and NormalModelDistribution, but adds a filtering step to sampleFromPosterior that sets insignificant mean elements to 0.0 to conserve space.

Adjusted NormalModel.toString to be more SparseVector-friendly and introduced the TestSparseModelClustering tests.

Added:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
    lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
Modified:
    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java

Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=900615&r1=900614&r2=900615&view=diff
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Tue Jan 19 00:11:42 2010
@@ -20,10 +20,12 @@
 import org.apache.mahout.math.SquareRootFunction;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
 
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.util.Iterator;
 
 public class NormalModel implements Model<VectorWritable> {
 
@@ -55,7 +57,7 @@
   int getS0() {
     return s0;
   }
-  
+
   public Vector getMean() {
     return mean;
   }
@@ -64,7 +66,6 @@
     return stdDev;
   }
 
-
   /**
    * TODO: Return a proper sample from the posterior. For now, return an instance with the same parameters
    *
@@ -98,8 +99,7 @@
     mean = s1.divide(s0);
     // compute the average of the component stds
     if (s0 > 1) {
-      Vector std = s2.times(s0).minus(s1.times(s1)).assign(
-          new SquareRootFunction()).divide(s0);
+      Vector std = s2.times(s0).minus(s1.times(s1)).assign(new SquareRootFunction()).divide(s0);
       stdDev = std.zSum() / std.size();
     } else {
       stdDev = Double.MIN_VALUE;
@@ -124,9 +124,15 @@
   public String toString() {
     StringBuilder buf = new StringBuilder();
     buf.append("nm{n=").append(s0).append(" m=[");
+    int nextIx = 0;
     if (mean != null) {
-      for (int i = 0; i < mean.size(); i++) {
-        buf.append(String.format("%.2f", mean.get(i))).append(", ");
+      // handle sparse Vectors gracefully, suppressing zero values
+      for (Iterator<Element> nzElems = mean.iterateNonZero(); nzElems.hasNext();) {
+        Element elem = nzElems.next();
+        if (elem.index() > nextIx)
+          buf.append("..{").append(elem.index()).append("}=");
+        buf.append(String.format("%.2f", mean.get(elem.index()))).append(", ");
+        nextIx = elem.index() + 1;
       }
     }
     buf.append("] sd=").append(String.format("%.2f", stdDev)).append('}');

Added: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java?rev=900615&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java (added)
+++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SparseNormalModelDistribution.java Tue Jan 19 00:11:42 2010
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.dirichlet.models;
+
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * A {@link VectorModelDistribution} of {@link NormalModel}s aimed at sparse, high-cardinality data such as text.
+ * It samples models like a plain normal-model distribution. The difference is that
+ * {@link #sampleFromPosterior(Model[])} resets every mean element whose absolute value is below {@code 0.001}
+ * to 0.0, so that sparse mean vectors do not accumulate insignificant entries.
+ */
+public class SparseNormalModelDistribution extends VectorModelDistribution {
+
+  /** Mean elements whose absolute value is below this threshold are considered insignificant and set to 0.0. */
+  private static final double SIGNIFICANCE_THRESHOLD = 0.001;
+
+  public SparseNormalModelDistribution(VectorWritable modelPrototype) {
+    super(modelPrototype);
+  }
+
+  public SparseNormalModelDistribution() {
+    super();
+  }
+
+  /**
+   * Creates {@code howMany} fresh models. Each is built from {@code prototype.like()} and a second
+   * {@link NormalModel} constructor argument of 1 (presumably the standard deviation; confirm against NormalModel).
+   *
+   * @param howMany number of models to create
+   * @return an array of {@code howMany} NormalModel instances
+   */
+  @Override
+  public Model<VectorWritable>[] sampleFromPrior(int howMany) {
+    Model<VectorWritable>[] result = new NormalModel[howMany];
+    // the prototype does not change between iterations, so fetch it once
+    Vector prototype = getModelPrototype().get();
+    for (int i = 0; i < howMany; i++) {
+      result[i] = new NormalModel(prototype.like(), 1);
+    }
+    return result;
+  }
+
+  /**
+   * Samples each posterior model and trims insignificant elements from the sampled model's mean.
+   *
+   * @param posterior the posterior models; every element must be a {@link NormalModel} (it is cast)
+   * @return one sampled NormalModel per posterior model, in the same order
+   */
+  @Override
+  public Model<VectorWritable>[] sampleFromPosterior(Model<VectorWritable>[] posterior) {
+    Model<VectorWritable>[] result = new NormalModel[posterior.length];
+    for (int i = 0; i < posterior.length; i++) {
+      NormalModel m = ((NormalModel) posterior[i]).sample();
+      // trim insignificant mean elements from the posterior model to save sparse vector space
+      trimInsignificantElements(m.getMean());
+      result[i] = m;
+    }
+    return result;
+  }
+
+  /**
+   * Sets to 0.0 every element of {@code mean} whose absolute value is below {@link #SIGNIFICANCE_THRESHOLD}.
+   * Only non-zero elements are visited. Zero elements are already "trimmed", so this costs nothing extra for them
+   * and avoids calling {@code set} on entries a sparse vector does not store.
+   */
+  private static void trimInsignificantElements(Vector mean) {
+    // collect the indices first so the vector is not modified while its iterator is still in use
+    List<Integer> insignificant = new ArrayList<Integer>();
+    for (Iterator<Element> nzElems = mean.iterateNonZero(); nzElems.hasNext();) {
+      Element elem = nzElems.next();
+      if (Math.abs(elem.get()) < SIGNIFICANCE_THRESHOLD) {
+        insignificant.add(elem.index());
+      }
+    }
+    for (int index : insignificant) {
+      mean.set(index, 0.0d);
+    }
+  }
+}

Added: lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java
URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java?rev=900615&view=auto
==============================================================================
--- lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java (added)
+++ lucene/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestSparseModelClustering.java Tue Jan 19 00:11:42 2010
@@ -0,0 +1,129 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering.dirichlet;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.mahout.clustering.dirichlet.models.Model;
+import org.apache.mahout.clustering.dirichlet.models.SparseNormalModelDistribution;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.Vector.Element;
+import org.junit.After;
+import org.junit.Before;
+
+public class TestSparseModelClustering extends MahoutTestCase {
+
+  private List<VectorWritable> sampleData;
+
+  private Random random;
+
+  @Before
+  protected void setUp() throws Exception {
+    super.setUp();
+    random = RandomUtils.getRandom();
+    sampleData = new ArrayList<VectorWritable>();
+  }
+
+  /**
+   * Generate random samples and add them to the sampleData.
+   * <p>
+   * First, a random subset of the {@code card} positions is chosen; each position is kept independently with
+   * probability {@code pNz}. Every generated sample is then non-zero only at those positions. Each value is drawn
+   * with {@code UncommonDistributions.rNorm(mx, sd)}.
+   *
+   * @param num int number of samples to generate
+   * @param mx  double value of the sample mean (may be zero or negative)
+   * @param sd  double standard deviation of the samples
+   * @param card int cardinality of the generated sample vectors
+   * @param pNz double probability a sample element is non-zero
+   */
+  private void generateSamples(int num, double mx, double sd, int card, double pNz) {
+    // Remember the chosen positions explicitly. Testing template values for "> 0" would lose them
+    // whenever mx is zero or negative.
+    List<Integer> nonZeroIndices = new ArrayList<Integer>();
+    Vector template = new SequentialAccessSparseVector(card);
+    for (int i = 0; i < card; i++) {
+      if (random.nextDouble() < pNz) {
+        nonZeroIndices.add(i);
+        template.set(i, mx);
+      }
+    }
+    System.out.println("Generating " + num + printSampleParameters(template, sd) + " pNz=" + pNz);
+    for (int i = 0; i < num; i++) {
+      SequentialAccessSparseVector v = new SequentialAccessSparseVector(card);
+      for (int index : nonZeroIndices) {
+        v.set(index, UncommonDistributions.rNorm(mx, sd));
+      }
+      sampleData.add(new VectorWritable(v));
+    }
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    // chain to the base class so any fixture cleanup it performs still runs
+    super.tearDown();
+  }
+
+  public void testDirichletCluster100s() {
+    System.out.println("testDirichletCluster100s");
+    generateSamples(40, 5, 3, 50, 0.1);
+    generateSamples(30, 3, 1, 50, 0.1);
+    generateSamples(30, 1, 0.1, 50, 0.1);
+
+    DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData, new SparseNormalModelDistribution(
+        sampleData.get(0)), 1.0, 10, 1, 0);
+    List<Model<VectorWritable>[]> result = dc.cluster(10);
+    // assert before printing: printResults would throw a NullPointerException on a null result
+    assertNotNull(result);
+    printResults(result, 1);
+  }
+
+  /**
+   * Print, for each sampled iteration, the models whose count exceeds {@code significant}.
+   *
+   * @param result      the per-iteration model arrays returned by the clusterer
+   * @param significant models with {@code count() <= significant} are omitted from the output
+   */
+  private static void printResults(List<Model<VectorWritable>[]> result, int significant) {
+    int row = 0;
+    for (Model<VectorWritable>[] r : result) {
+      int sig = 0;
+      for (Model<VectorWritable> model : r) {
+        if (model.count() > significant) {
+          sig++;
+        }
+      }
+      System.out.print("sample[" + row++ + "] (" + sig + ")= ");
+      for (Model<VectorWritable> model : r) {
+        if (model.count() > significant) {
+          System.out.print(model + ", ");
+        }
+      }
+      System.out.println();
+    }
+    System.out.println();
+  }
+
+  /**
+   * Format the generating parameters of a sample set, e.g. {@code " m=[..{3}=5.00, ..{17}=5.00, ] sd=3.00"}.
+   * Zero elements are suppressed; a gap in the indices is marked with {@code ..{index}=}.
+   *
+   * @param v      the mean template vector (may be null)
+   * @param stdDev the standard deviation used for the samples
+   * @return the formatted parameter string
+   */
+  private static String printSampleParameters(Vector v, double stdDev) {
+    StringBuilder buf = new StringBuilder();
+    buf.append(" m=[");
+    int nextIx = 0;
+    if (v != null) {
+      // handle sparse Vectors gracefully, suppressing zero values
+      for (Iterator<Element> nzElems = v.iterateNonZero(); nzElems.hasNext();) {
+        Element elem = nzElems.next();
+        if (elem.index() > nextIx) {
+          buf.append("..{").append(elem.index()).append("}=");
+        }
+        buf.append(String.format("%.2f", elem.get())).append(", ");
+        nextIx = elem.index() + 1;
+      }
+    }
+    // no trailing '}' here: unlike NormalModel.toString, this string never opens a brace
+    buf.append("] sd=").append(String.format("%.2f", stdDev));
+    return buf.toString();
+  }
+
+}